Commit 58d33d4c authored by wanglch

Initial commit
#!/bin/bash
# Change for multinode config
NNODES=1
NODE_RANK=0
GPUS_PER_NODE=2
MASTER_ADDR=127.0.0.1
MASTER_PORT=29513
export CUDA_VISIBLE_DEVICES=0,1
# GPUS_PER_NODE=1
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
# change --model_name_or_path to your local path of DocOwl1.5-Omni
# batch size = per_device_train_batch_size x GPUS_PER_NODE x NNODES x gradient_accumulation_steps
torchrun $DISTRIBUTED_ARGS mplug_docowl/train/train_docowl.py \
--deepspeed '/home/wanglch/projects/mPLUG-DocOwl1.5-Omni/scripts/zero2.json' \
--model_name_or_path '/home/wanglch/projects/mPLUG-DocOwl1.5-Omni/DocOwl1.5-Omni-base' \
--version v1 \
--data_path '/home/wanglch/projects/mPLUG-DocOwl1.5-Omni/DocLocal4K/mini_imges.jsonl' \
--image_folder '/home/wanglch/projects/mPLUG-DocOwl1.5-Omni/DocLocal4K' \
--image_size 448 \
--crop_anchors 'grid_9' \
--add_global_img True \
--add_textual_crop_indicator True \
--fp16 True \
--output_dir '/home/wanglch/projects/saves/DocOwl1.5/train_multi_dcu' \
--num_train_epochs 10 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 500 \
--save_total_limit 4 \
--learning_rate 1e-4 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 False \
--model_max_length 3600 \
--gradient_checkpointing True \
--tune_vision2text True \
--freeze_vision_model True \
--freeze_backbone True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to tensorboard
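
For reference, the global batch size follows the formula in the comment above; a minimal Python sketch with the values from this 2-GPU script:

per_device_train_batch_size = 1
gpus_per_node = 2
nnodes = 1
gradient_accumulation_steps = 8
global_batch_size = (per_device_train_batch_size * gpus_per_node
                     * nnodes * gradient_accumulation_steps)
print(global_batch_size)  # -> 16
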
#!/bin/bash
# Change for multinode config
NNODES=1
NODE_RANK=0
GPUS_PER_NODE=4
MASTER_ADDR=127.0.0.1
MASTER_PORT=29513
export CUDA_VISIBLE_DEVICES=1,3,4,7
# GPUS_PER_NODE=1
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"
# change --model_name_or_path to your local path of DocOwl1.5-Omni
# batch size = per_device_train_batch_size x GPUS_PER_NODE x NNODES x gradient_accumulation_steps
torchrun $DISTRIBUTED_ARGS ./mplug_docowl/train/train_docowl.py \
--lora_enable True --lora_r 128 --lora_alpha 256 --vision2text_lr 2e-5 \
--deepspeed './scripts/zero2.json' \
--model_name_or_path './DocOwl1.5-Omni-base' \
--version v1 \
--data_path './DocLocal4K/mini_imges.jsonl' \
--image_folder './DocLocal4K/' \
--image_size 448 \
--crop_anchors 'grid_9' \
--add_global_img True \
--add_textual_crop_indicator True \
--fp16 True \
--output_dir './train_multi_dcu' \
--num_train_epochs 10 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 500 \
--save_total_limit 4 \
--learning_rate 1e-4 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--tf32 False \
--model_max_length 3600 \
--gradient_checkpointing True \
--tune_vision2text True \
--freeze_vision_model True \
--freeze_backbone True \
--dataloader_num_workers 4 \
--lazy_preprocess True \
--report_to tensorboard
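
The --lora_enable/--lora_r/--lora_alpha flags above switch this run to LoRA fine-tuning. As a hedged sketch of how such flags typically map to a PEFT adapter config (this assumes train_docowl.py builds something like peft.LoraConfig from them; the target modules and dropout below are illustrative, not taken from this repo):

from peft import LoraConfig

lora_config = LoraConfig(
    r=128,                                 # --lora_r
    lora_alpha=256,                        # --lora_alpha (scaling = alpha / r = 2.0)
    target_modules=["q_proj", "v_proj"],   # hypothetical choice of layers
    lora_dropout=0.05,                     # hypothetical default
    task_type="CAUSAL_LM",
)
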
# Unique model identifier
modelCode=1080
# Model name
modelName=mplug-docowl_pytorch
# Model description
modelDescription=Multimodal OCR large model, deployable on edge devices
# Application scenarios
appScenario=inference,OCR,finance,education,government,research,transportation,broadcast media
# Framework type
frameType=pytorch
"""
A model worker executes the model.
"""
import argparse
import asyncio
import json
import time
import threading
import uuid
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse
import requests
import torch
import uvicorn
from functools import partial
from mplug_docowl.utils import (build_logger, server_error_msg,
pretty_print_semaphore)
from mplug_docowl.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN,WORKER_HEART_BEAT_INTERVAL
from mplug_docowl.conversation import conv_templates, SeparatorStyle
from mplug_docowl.model.builder import load_pretrained_model
from mplug_docowl.mm_utils import load_image_from_base64, process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from mplug_docowl.processor import DocProcessor
from transformers import TextIteratorStreamer
from threading import Thread
from icecream import ic
GB = 1 << 30
worker_id = str(uuid.uuid4())[:6]
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
global_counter = 0
model_semaphore = None
def heart_beat_worker(controller):
while True:
time.sleep(WORKER_HEART_BEAT_INTERVAL)
controller.send_heart_beat()
class ModelWorker:
def __init__(self,
model_path, model_base, model_name,
resolution, anchors, add_global_img,
load_8bit, load_4bit, device):
if model_path.endswith("/"):
model_path = model_path[:-1]
self.model_name = get_model_name_from_path(model_path)
self.device = device
logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
self.tokenizer, self.model, _, self.context_len = load_pretrained_model(
model_path, model_base, self.model_name, load_8bit, load_4bit, device=self.device)
self.resolution=resolution
self.token_num_each_img = (self.resolution/14)*(self.resolution/14)/self.model.get_model().vision2text.conv_patch
self.doc_image_processor = DocProcessor(image_size=resolution, anchors=anchors, add_global_img=add_global_img, add_textual_crop_indicator=True)
self.is_multimodal = True
@torch.inference_mode()
def generate_stream(self, params):
tokenizer, model = self.tokenizer, self.model
prompt = params["prompt"]
ori_prompt = prompt
images = params.get("images", None)
num_image_tokens = 0
if images is not None and len(images) > 0 and self.is_multimodal:
if len(images) > 0:
images = [load_image_from_base64(image) for image in images]
# docowl only support 1 image, so only keep the last image
image = images[-1]
assert prompt.count(DEFAULT_IMAGE_TOKEN) == 1
images, patch_positions, prompt = self.doc_image_processor(images=image, query=prompt)
images = images.to(self.model.device, dtype=torch.float16)
patch_positions = patch_positions.to(self.model.device)
replace_token = DEFAULT_IMAGE_TOKEN
prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
num_image_tokens = prompt.count(replace_token) * (self.token_num_each_img+1)
else:
images = None
patch_positions = None
image_args = {"images": images, "patch_positions":patch_positions}
else:
images = None
image_args = {}
temperature = float(params.get("temperature", 1.0))
top_p = float(params.get("top_p", 1.0))
# max_context_length = getattr(model.config, 'max_position_embeddings', 4096)
max_context_length = 4096
max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
stop_str = params.get("stop", None)
# do_sample = True if temperature > 0.001 else False
do_sample = False
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
ic(max_context_length, input_ids.shape[-1], num_image_tokens, max_new_tokens)
if max_new_tokens < 1:
yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode()
return
thread = Thread(target=model.generate, kwargs=dict(
inputs=input_ids,
do_sample=do_sample,
temperature=temperature,
# top_p=top_p,
max_new_tokens=max_new_tokens,
streamer=streamer,
stopping_criteria=[stopping_criteria],
use_cache=True,
**image_args
))
thread.start()
generated_text = ori_prompt
for new_text in streamer:
generated_text += new_text
if generated_text.endswith(stop_str):
generated_text = generated_text[:-len(stop_str)]
# yield json.dumps({"text": generated_text, "error_code": 0}).encode()
# replace < > to [ ] to avoide <doc>,<md>,<ocr>,<bbox> are removed by web code
yield json.dumps({"text": generated_text.replace('<','[').replace('>',']'), "error_code": 0}).encode()
def generate_stream_gate(self, params):
try:
for x in self.generate_stream(params):
yield x
except ValueError as e:
print("Caught ValueError:", e)
ret = {
"text": server_error_msg,
"error_code": 1,
}
yield json.dumps(ret).encode()
except torch.cuda.CudaError as e:
print("Caught torch.cuda.CudaError:", e)
ret = {
"text": server_error_msg,
"error_code": 1,
}
yield json.dumps(ret).encode()
except Exception as e:
print("Caught Unknown Error", e)
ret = {
"text": server_error_msg,
"error_code": 1,
}
yield json.dumps(ret).encode()
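
A hedged usage sketch for the worker above: generate_stream_gate yields JSON-encoded byte chunks shaped like {"text": ..., "error_code": ...}, driven by a params dict with the keys read in generate_stream. The worker instance and the base64 image string below are placeholders, not values from this repo.

import json

params = {
    "prompt": "USER: <|image|>What is the title? ASSISTANT:",  # must contain exactly one <|image|>
    "images": [img_b64_str],   # placeholder: one base64-encoded image; only the last is kept
    "temperature": 1.0,
    "top_p": 1.0,
    "max_new_tokens": 256,
    "stop": "</s>",
}
for chunk in worker.generate_stream_gate(params):  # worker: a ModelWorker instance
    data = json.loads(chunk.decode())
    if data["error_code"] == 0:
        print(data["text"])
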
from .model import MPLUGDocOwlLlamaForCausalLM
from .processor import DocProcessor
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15
LOGDIR = "./demo_logs"
# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<|image|>"
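
IGNORE_INDEX is -100 because that is the default ignore_index of PyTorch's cross-entropy loss, so label positions masked with it (prompt and image tokens) contribute nothing to the training loss. A minimal illustration with toy logits:

import torch
import torch.nn.functional as F

IGNORE_INDEX = -100
logits = torch.randn(1, 4, 32)                         # (batch, seq_len, vocab)
labels = torch.tensor([[IGNORE_INDEX, IGNORE_INDEX, 5, 7]])
loss = F.cross_entropy(logits.view(-1, 32), labels.view(-1),
                       ignore_index=IGNORE_INDEX)      # only the last two positions count
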
import dataclasses
from enum import auto, Enum
from typing import List, Tuple

from mplug_docowl.constants import DEFAULT_IMAGE_TOKEN


class SeparatorStyle(Enum):
    """Different separator style."""
    SINGLE = auto()
    TWO = auto()
    TWO_NO_SYS = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            # init_msg = init_msg[0].replace("<image>", "").strip()
            # if 'mmtag' in self.version:
            #     messages[0] = (init_role, init_msg)
            #     messages.insert(0, (self.roles[0], "<Image><image></Image>"))
            #     messages.insert(1, (self.roles[1], "Received."))
            # else:
            #     messages[0] = (init_role, "<image>\n" + init_msg)
            init_msg = init_msg[0].replace(DEFAULT_IMAGE_TOKEN, "").strip()
            messages[0] = (init_role, DEFAULT_IMAGE_TOKEN + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO_NO_SYS:
            seps = [self.sep, self.sep2]
            ret = ""
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0:
                        message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        ret += " " + message + " " + self.sep2
                else:
                    ret += ""
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    from PIL import Image
                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":
                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(pil_img.mode, (width, width), background_color)
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(pil_img.mode, (height, height), background_color)
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result
                        image = expand2square(image)
                    elif image_process_mode in ["Default", "Crop"]:
                        pass
                    elif image_process_mode == "Resize336":
                        image = image.resize((336, 336))
                    elif image_process_mode == "Resize":
                        max_hw, min_hw = max(image.size), min(image.size)
                        aspect_ratio = max_hw / min_hw
                        max_len, min_len = 800, 400
                        shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                        longest_edge = int(shortest_edge * aspect_ratio)
                        W, H = image.size
                        if longest_edge != max(image.size):
                            if H > W:
                                H, W = longest_edge, shortest_edge
                            else:
                                H, W = shortest_edge, longest_edge
                            image = image.resize((W, H))
                    else:
                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="PNG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    msg, image, image_process_mode = msg
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<|image|>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_vicuna_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
        ("Assistant",
         "Renewable energy sources are those that can be replenished naturally in a relatively "
         "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
         "Non-renewable energy sources, on the other hand, are finite and will eventually be "
         "depleted, such as coal, oil, and natural gas. Here are some key differences between "
         "renewable and non-renewable energy sources:\n"
         "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
         "energy sources are finite and will eventually run out.\n"
         "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
         "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
         "and other negative effects.\n"
         "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
         "have lower operational costs than non-renewable sources.\n"
         "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
         "locations than non-renewable sources.\n"
         "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
         "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
         "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
         "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

conv_mplug_owl2 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO_NO_SYS,
    sep=" ",
    sep2="</s>",
)

# default_conversation = conv_vicuna_v1
default_conversation = conv_mplug_owl2
conv_templates = {
    "default": conv_vicuna_v0,
    "v0": conv_vicuna_v0,
    "v1": conv_vicuna_v1,
    "vicuna_v1": conv_vicuna_v1,
    "mplug_owl2": conv_mplug_owl2,
}

if __name__ == "__main__":
    print(default_conversation.get_prompt())
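
A quick usage sketch for the templates above (grounded in this file; the question text is illustrative):

conv = conv_templates["mplug_owl2"].copy()
conv.append_message(conv.roles[0], "<|image|>What is the title of this document?")
conv.append_message(conv.roles[1], None)
print(conv.get_prompt())
# -> "USER: <|image|>What is the title of this document? ASSISTANT:"
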
from PIL import Image
from io import BytesIO
import base64
import torch
from transformers import StoppingCriteria
from mplug_docowl.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from icecream import ic


def load_image_from_base64(image):
    return Image.open(BytesIO(base64.b64decode(image)))


def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


def process_images(images, image_processor, model_cfg=None):
    if model_cfg is not None:
        image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    else:
        image_aspect_ratio = 'resize'
    new_images = []
    if image_aspect_ratio == 'pad':
        for image in images:
            image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            new_images.append(image)
    elif image_aspect_ratio == 'resize':
        for image in images:
            max_edge = max(image.size)
            image = image.resize((max_edge, max_edge))
            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            new_images.append(image)
    else:
        return image_processor(images, return_tensors='pt')['pixel_values']
    if all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images


def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    prompt_chunks = [tokenizer(chunk).input_ids if len(chunk) > 0 else [] for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids


def get_model_name_from_path(model_path):
    model_path = model_path.strip("/")
    model_paths = model_path.split("/")
    if model_paths[-1].startswith('checkpoint-'):
        return model_paths[-2] + "_" + model_paths[-1]
    else:
        return model_paths[-1]


class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            if len(cur_keyword_ids) > self.max_keyword_len:
                self.max_keyword_len = len(cur_keyword_ids)
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False
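
A minimal sketch of what tokenizer_image_token (defined above) produces, using a stub tokenizer; a real call would pass the model's tokenizer instead:

from types import SimpleNamespace

class StubTokenizer:
    bos_token_id = 1
    def __call__(self, text):
        # one fake id per whitespace-separated token, after a BOS id
        return SimpleNamespace(input_ids=[1] + [10] * len(text.split()))

ids = tokenizer_image_token("Hello <|image|> world", StubTokenizer())
print(ids)  # [1, 10, -200, 10]: the placeholder became IMAGE_TOKEN_INDEX
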
from .modeling_mplug_docowl import MPLUGDocOwlLlamaForCausalLM
from .configuration_mplug_docowl import MPLUGDocOwlConfig