Commit b81b2f59 authored by wanglch's avatar wanglch
Browse files

Initial commit

parent f7c86e68
../../.github/CONTRIBUTING.md
\ No newline at end of file
from argparse import ArgumentParser
from pathlib import Path
import copy
import gradio as gr
import os
import re
import secrets
import tempfile
from PIL import Image
from monkey_model.modeling_monkey import MonkeyLMHeadModel
from monkey_model.tokenization_qwen import QWenTokenizer
from monkey_model.configuration_monkey import MonkeyConfig
import shutil
from pathlib import Path
import json
DEFAULT_CKPT_PATH = '/home/zhangli/demo/'
BOX_TAG_PATTERN = r"<box>([\s\S]*?)</box>"
PUNCTUATION = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
title_markdown = ("""
# Welcome to Monkey
Hello! I'm Monkey, a Large Language and Vision Assistant. Before talking to me, please read the **Operation Guide** and **Terms of Use**.
你好!我是Monkey,一个大型语言和视觉助理。在与我交谈之前,请阅读**操作指南**和**使用条款**。
## Operation Guide 操作指南
Click the **Upload** button to upload an image. Then, you can get Monkey's answer in two ways:点击**Upload**上传图像。你可以通过两种方式得到Monkey的回答:
- Click the **Generate** and Monkey will generate a description of the image. 点击**Generate**,Monkey将生成图像的描述。
- Enter the question in the dialog box, click the **Submit**, and Monkey will answer the question based on the image. 在对话框中输入问题,点击**Submit**,Monkey会根据图片回答问题。
- Click **Clear History** to clear the current image and Q&A content.点击**Clear History**,清除当前图片和问答内容。
> Note: Monkey does not have a multi-round dialogue function. Perhaps we will further develop its capabilities in the future. 注意:Monkey没有多轮对话功能,或许我们在未来会进一步开发它的能力。
> Monkey支持中文,但使用英文提问会比使用中文效果明显好.""")
policy_markdown = ("""
## Terms of Use
By using this service, users are required to agree to the following terms:
- Monkey is for research use only and unauthorized commercial use is prohibited. For any query, please contact the author.
- Monkey's generation capabilities are limited, so we recommend that users do not rely entirely on its answers.
- Monkey's security measures are limited, so we cannot guarantee that the output is completely appropriate. We strongly recommend that users do not intentionally guide Monkey to generate harmful content, including hate speech, discrimination, violence, pornography, deception, etc.
""")
def _get_args():
parser = ArgumentParser()
parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
help="Checkpoint name or path, default to %(default)r")
parser.add_argument("--cpu-only", action="store_true", help="Run demo with CPU only")
parser.add_argument("--share", action="store_true", default=False,
help="Create a publicly shareable link for the interface.")
parser.add_argument("--inbrowser", action="store_true", default=False,
help="Automatically launch the interface in a new tab on the default browser.")
parser.add_argument("--server-port", type=int, default=8000,
help="Demo server port.")
parser.add_argument("--server-name", type=str, default="127.0.0.1",
help="Demo server name.")
args = parser.parse_args()
return args
def _load_model_tokenizer(args):
tokenizer = QWenTokenizer.from_pretrained(
args.checkpoint_path, trust_remote_code=True)
if args.cpu_only:
device_map = "cpu"
else:
device_map = "cuda"
model = MonkeyLMHeadModel.from_pretrained(
args.checkpoint_path,
device_map=device_map,
trust_remote_code=True,
).eval()
# model.generation_config = GenerationConfig.from_pretrained(
# args.checkpoint_path, trust_remote_code=True, resume_download=True,
# )
tokenizer.padding_side = 'left'
tokenizer.pad_token_id = tokenizer.eod_id
return model, tokenizer
def _parse_text(text):
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split("`")
if count % 2 == 1:
lines[i] = f'<pre><code class="language-{items[-1]}">'
else:
lines[i] = f"<br></code></pre>"
else:
if i > 0:
if count % 2 == 1:
line = line.replace("`", r"\`")
line = line.replace("<", "&lt;")
line = line.replace(">", "&gt;")
line = line.replace(" ", "&nbsp;")
line = line.replace("*", "&ast;")
line = line.replace("_", "&lowbar;")
line = line.replace("-", "&#45;")
line = line.replace(".", "&#46;")
line = line.replace("!", "&#33;")
line = line.replace("(", "&#40;")
line = line.replace(")", "&#41;")
line = line.replace("$", "&#36;")
lines[i] = "<br>" + line
text = "".join(lines)
return text
def _launch_demo(args, model, tokenizer):
def predict(_chatbot, task_history):
chat_query = _chatbot[-1][0]
query = task_history[-1][0]
question = _parse_text(query)
print("User: " + _parse_text(query))
full_response = ""
img_path = _chatbot[0][0][0]
try:
Image.open(img_path)
except:
response = "Please upload a picture."
_chatbot[-1] = (_parse_text(chat_query), response)
full_response = _parse_text(response)
task_history[-1] = (query, full_response)
print("Monkey: " + _parse_text(full_response))
return _chatbot
query = f'<img>{img_path}</img> {question} Answer: '
print(query)
input_ids = tokenizer(query, return_tensors='pt', padding='longest')
attention_mask = input_ids.attention_mask
input_ids = input_ids.input_ids
pred = model.generate(
input_ids=input_ids.cuda(),
attention_mask=attention_mask.cuda(),
do_sample=False,
num_beams=1,
max_new_tokens=512,
min_new_tokens=1,
length_penalty=1,
num_return_sequences=1,
output_hidden_states=True,
use_cache=True,
pad_token_id=tokenizer.eod_id,
eos_token_id=tokenizer.eod_id,
)
response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
_chatbot[-1] = (_parse_text(chat_query), response)
full_response = _parse_text(response)
task_history[-1] = (query, full_response)
print("Monkey: " + _parse_text(full_response))
return _chatbot
def caption(_chatbot, task_history):
query = "Generate the detailed caption in English:"
chat_query = "Generate the detailed caption in English:"
question = _parse_text(query)
print("User: " + _parse_text(query))
full_response = ""
try:
img_path = _chatbot[0][0][0]
Image.open(img_path)
except:
response = "Please upload a picture."
_chatbot.append((None, response))
full_response = _parse_text(response)
task_history.append((None, full_response))
print("Monkey: " + _parse_text(full_response))
return _chatbot
img_path = _chatbot[0][0][0]
query = f'<img>{img_path}</img> {chat_query} '
print(query)
input_ids = tokenizer(query, return_tensors='pt', padding='longest')
attention_mask = input_ids.attention_mask
input_ids = input_ids.input_ids
pred = model.generate(
input_ids=input_ids.cuda(),
attention_mask=attention_mask.cuda(),
do_sample=True,
temperature=0.7,
max_new_tokens=250,
min_new_tokens=1,
length_penalty=1,
num_return_sequences=1,
output_hidden_states=True,
use_cache=True,
pad_token_id=tokenizer.eod_id,
eos_token_id=tokenizer.eod_id,
)
response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
_chatbot.append((None, response))
full_response = _parse_text(response)
task_history.append((None, full_response))
print("Monkey: " + _parse_text(full_response))
return _chatbot
def add_text(history, task_history, text):
task_text = text
if len(text) >= 2 and text[-1] in PUNCTUATION and text[-2] not in PUNCTUATION:
task_text = text[:-1]
history = history + [(_parse_text(text), None)]
task_history = task_history + [(task_text, None)]
print(history, task_history, text)
return history, task_history, ""
def add_file(history, task_history, file):
history = [((file.name,), None)]
task_history = [((file.name,), None)]
print(history, task_history, file)
return history, task_history
def reset_user_input():
return gr.update(value="")
def reset_state(task_history):
task_history.clear()
return []
with gr.Blocks() as demo:
gr.Markdown(title_markdown)
chatbot = gr.Chatbot(label='Monkey', elem_classes="control-height", height=600,avatar_images=("https://ooo.0x0.ooo/2023/11/09/OehsLx.png","https://ooo.0x0.ooo/2023/11/09/OehGBC.png"),layout="bubble",bubble_full_width=False,show_copy_button=True)
query = gr.Textbox(lines=1, label='Input')
task_history = gr.State([])
with gr.Row():
empty_bin = gr.Button("Clear History (清空)")
submit_btn = gr.Button("Submit (提问)")
generate_btn_en = gr.Button("Generate")
addfile_btn = gr.UploadButton("Upload (上传图片)", file_types=["image"])
submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history]).then(
predict, [chatbot, task_history], [chatbot], show_progress=True
)
generate_btn_en.click(caption, [chatbot, task_history], [chatbot], show_progress=True)
submit_btn.click(reset_user_input, [], [query])
empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True,scroll_to_output=True)
gr.Markdown(policy_markdown)
demo.queue().launch(
server_name="0.0.0.0",
server_port=7681
)
def main():
args = _get_args()
model, tokenizer = _load_model_tokenizer(args)
_launch_demo(args, model, tokenizer)
if __name__ == '__main__':
main()
...@@ -5,6 +5,7 @@ from monkey_model.modeling_textmonkey import TextMonkeyLMHeadModel ...@@ -5,6 +5,7 @@ from monkey_model.modeling_textmonkey import TextMonkeyLMHeadModel
from monkey_model.tokenization_qwen import QWenTokenizer from monkey_model.tokenization_qwen import QWenTokenizer
from monkey_model.configuration_monkey import MonkeyConfig from monkey_model.configuration_monkey import MonkeyConfig
from argparse import ArgumentParser from argparse import ArgumentParser
import torch
def _get_args(): def _get_args():
parser = ArgumentParser() parser = ArgumentParser()
...@@ -21,7 +22,7 @@ def _get_args(): ...@@ -21,7 +22,7 @@ def _get_args():
return args return args
args = _get_args() args = _get_args()
checkpoint_path = args.checkpoint_path checkpoint_path = args.checkpoint_path
device_map = "cuda" device_map = "auto"
# Create model # Create model
config = MonkeyConfig.from_pretrained( config = MonkeyConfig.from_pretrained(
checkpoint_path, checkpoint_path,
...@@ -73,7 +74,7 @@ def inference(input_str, input_image): ...@@ -73,7 +74,7 @@ def inference(input_str, input_image):
pred = model.generate( pred = model.generate(
input_ids=input_ids.cuda(), input_ids=input_ids.cuda(),
attention_mask=attention_mask.cuda(), attention_mask=attention_mask.cuda(),
do_sample=False, do_sample=True,
num_beams=1, num_beams=1,
max_new_tokens=2048, max_new_tokens=2048,
min_new_tokens=1, min_new_tokens=1,
...@@ -86,7 +87,7 @@ def inference(input_str, input_image): ...@@ -86,7 +87,7 @@ def inference(input_str, input_image):
) )
response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=False).strip() response = tokenizer.decode(pred[0][input_ids.size(1):].cpu(), skip_special_tokens=False).strip()
image = Image.open(input_image).convert("RGB").resize((1000,1000)) image = Image.open(input_image).convert("RGB").resize((1000,1000))
font = ImageFont.truetype('NimbusRoman-Regular.otf', 22) font = ImageFont.load_default() # 使用系统默认字体
bboxes = re.findall(r'<box>(.*?)</box>', response, re.DOTALL) bboxes = re.findall(r'<box>(.*?)</box>', response, re.DOTALL)
refs = re.findall(r'<ref>(.*?)</ref>', response, re.DOTALL) refs = re.findall(r'<ref>(.*?)</ref>', response, re.DOTALL)
if len(refs)!=0: if len(refs)!=0:
......
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10
ENV DEBIAN_FRONTEND=noninteractive
COPY requirements.txt requirements.txt
RUN pip install -r requirements.txt -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
transformers==4.32.0
accelerate
tiktoken
einops
einops_exts
transformers_stream_generator==0.0.4
scipy
pillow
tensorboard
matplotlib
deepspeed
gradio
peft
\ No newline at end of file
{
"fp16": {
"enabled": "auto",
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": "auto"
},
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"scheduler": {
"type": "WarmupLR",
"params": {
"warmup_min_lr": "auto",
"warmup_max_lr": "auto",
"warmup_num_steps": "auto"
}
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "none",
"pin_memory": true
},
"offload_param": {
"device": "none",
"pin_memory": true
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"steps_per_print": 100,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}
#!/bin/bash #!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1 export CUDA_DEVICE_MAX_CONNECTIONS=1
DIR=`pwd` DIR=`pwd`
GPUS_PER_NODE=8 CUDA_VISIBLE_DEVICES=2,3
GPUS_PER_NODE=1
NNODES=1 NNODES=1
NODE_RANK=0 NODE_RANK=0
MASTER_ADDR=localhost MASTER_ADDR=localhost
MASTER_PORT=6001 MASTER_PORT=29502
MODEL="/home/wanglch/projects/TextMonkey/TextMonkey_base" # We use the first version of Qwen-VL
MODEL="Qwen/Qwen-VL" # We use the first version of Qwen-VL
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations. # ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
# See the section for finetuning in README for more information. # See the section for finetuning in README for more information.
DATA="pathto/data" DATA="/home/wanglch/projects/TextMonkey/Monkey/data/data.json"
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE \ DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \ --nnodes $NNODES \
...@@ -20,12 +22,11 @@ DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE \ ...@@ -20,12 +22,11 @@ DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE \
--master_port $MASTER_PORT" --master_port $MASTER_PORT"
torchrun $DISTRIBUTED_ARGS finetune_multitask_dialouge_doc.py\ torchrun $DISTRIBUTED_ARGS /home/wanglch/projects/TextMonkey/Monkey/finetune_multitask_dialouge_doc.py\
--model_name_or_path $MODEL \ --model_name_or_path $MODEL \
--data_path $DATA \ --data_path $DATA \
--bf16 True \
--fix_vit True \ --fix_vit True \
--output_dir output_model \ --output_dir /home/wanglch/projects/saves/TextMonkey/Train_multi_dcu \
--num_train_epochs 1 \ --num_train_epochs 1 \
--per_device_train_batch_size 2 \ --per_device_train_batch_size 2 \
--per_device_eval_batch_size 1 \ --per_device_eval_batch_size 1 \
...@@ -44,7 +45,7 @@ torchrun $DISTRIBUTED_ARGS finetune_multitask_dialouge_doc.py\ ...@@ -44,7 +45,7 @@ torchrun $DISTRIBUTED_ARGS finetune_multitask_dialouge_doc.py\
--model_max_length 2048 \ --model_max_length 2048 \
--gradient_checkpointing \ --gradient_checkpointing \
--lazy_preprocess True \ --lazy_preprocess True \
--deepspeed finetune/ds_config_zero2.json \ --deepspeed /home/wanglch/projects/TextMonkey/Monkey/finetune/ds_config_zero2.json \
--image_size 896 \ --image_size 896 \
--image_width 896 \ --image_width 896 \
--image_height 896 \ --image_height 896 \
......
#!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1
DIR=`pwd`
CUDA_VISIBLE_DEVICES=3,5,6,7
GPUS_PER_NODE=4
NNODES=1
NODE_RANK=0
MASTER_ADDR=localhost
MASTER_PORT=29517
MODEL="/home/wanglch/projects/TextMonkey/TextMonkey_base" # We use the first version of Qwen-VL
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
# See the section for finetuning in README for more information.
DATA="/home/wanglch/projects/TextMonkey/Monkey/data/data.json"
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT"
torchrun $DISTRIBUTED_ARGS /home/wanglch/projects/TextMonkey/Monkey/finetune_multitask_dialouge_doc.py\
--model_name_or_path $MODEL \
--data_path $DATA \
--fp16 True \
--fix_vit True \
--output_dir /home/wanglch/projects/saves/TextMonkey/Train_multi_gpu \
--num_train_epochs 2 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 10 \
--learning_rate 1e-5 \
--weight_decay 0.1 \
--adam_beta2 0.95 \
--warmup_ratio 0.02 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--report_to "none" \
--model_max_length 2048 \
--gradient_checkpointing \
--lazy_preprocess True \
--deepspeed /home/wanglch/projects/TextMonkey/Monkey/finetune/ds_config_zero2.json \
--image_size 896 \
--image_width 896 \
--image_height 896 \
--add_window true \
--use_global true \
--resampler true \
--use_lora True \
--remain 512
...@@ -24,7 +24,7 @@ IGNORE_TOKEN_ID = LabelSmoother.ignore_index ...@@ -24,7 +24,7 @@ IGNORE_TOKEN_ID = LabelSmoother.ignore_index
@dataclass @dataclass
class ModelArguments: class ModelArguments:
model_name_or_path: Optional[str] = field(default="Qwen/Qwen-7B") model_name_or_path: Optional[str] = field(default="/home/wanglch/projects/TextMonkey/TextMonkey_base")
@dataclass @dataclass
...@@ -330,7 +330,7 @@ def train(): ...@@ -330,7 +330,7 @@ def train():
# Set RoPE scaling factor # Set RoPE scaling factor
config = MonkeyConfig.from_pretrained( config = MonkeyConfig.from_pretrained(
"monkey_model", "/home/wanglch/projects/TextMonkey/TextMonkey_base",
cache_dir=training_args.cache_dir, cache_dir=training_args.cache_dir,
trust_remote_code=True, trust_remote_code=True,
) )
...@@ -362,7 +362,7 @@ def train(): ...@@ -362,7 +362,7 @@ def train():
) )
tokenizer = QWenTokenizer.from_pretrained( tokenizer = QWenTokenizer.from_pretrained(
"monkey_model", "/home/wanglch/projects/TextMonkey/TextMonkey_base",
cache_dir=training_args.cache_dir, cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length, model_max_length=training_args.model_max_length,
padding_side="right", padding_side="right",
...@@ -402,21 +402,23 @@ def train(): ...@@ -402,21 +402,23 @@ def train():
model.lm_head.requires_grad_(False) model.lm_head.requires_grad_(False)
if training_args.use_lora: if training_args.use_lora:
if lora_args.q_lora or "chat" in model_args.model_name_or_path.lower(): model.transformer.requires_grad_(False)
modules_to_save = None model.lm_head.requires_grad_(False)
else: model.transformer.visual.requires_grad_(False)
modules_to_save = []
lora_config = LoraConfig( if hasattr(model.transformer.visual, 'attn_pool'):
r=lora_args.lora_r, model.transformer.visual.attn_pool.requires_grad_(True)
lora_alpha=lora_args.lora_alpha, # only keep the gradient of lora and resampler module
target_modules=lora_args.lora_target_modules, for k, v in model.named_parameters():
lora_dropout=lora_args.lora_dropout, if "lora" in k:
bias=lora_args.lora_bias, v.requires_grad_(True)
task_type="CAUSAL_LM", for k, v in model.named_parameters():
modules_to_save=modules_to_save # This argument serves for adding new tokens. if "window_attention" in k:
) v.requires_grad_(True)
model = get_peft_model(model, lora_config) if training_args.fix_llm and hasattr(model, 'transformer') and hasattr(model.transformer, 'h'):
model.transformer.h.requires_grad_(False)
model.transformer.wte.requires_grad_(False)
if training_args.gradient_checkpointing: if training_args.gradient_checkpointing:
model.enable_input_require_grads() model.enable_input_require_grads()
......
#!/bin/bash
export CUDA_DEVICE_MAX_CONNECTIONS=1
DIR=`pwd`
CUDA_VISIBLE_DEVICES=2,3
GPUS_PER_NODE=2
NNODES=1
NODE_RANK=0
MASTER_ADDR=localhost
MASTER_PORT=29519
# ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations.
# See the section for finetuning in README for more information.
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT"
torchrun $DISTRIBUTED_ARGS /home/wanglch/projects/TextMonkey/Monkey/finetune_multitask_dialouge_doc.py \
--model_name_or_path /home/wanglch/projects/TextMonkey/TextMonkey_base \
--data_path /home/wanglch/projects/TextMonkey/Monkey/data/data.json \
--fp16 True \
--fix_vit True \
--fix_llm True \
--output_dir /home/wanglch/projects/saves/TextMonkey/Train_multi_dcu \
--num_train_epochs 2 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 8 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 10 \
--learning_rate 1e-5 \
--weight_decay 0.1 \
--adam_beta2 0.95 \
--warmup_ratio 0.02 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--report_to "none" \
--model_max_length 2048 \
--gradient_checkpointing \
--lazy_preprocess True \
--deepspeed /home/wanglch/projects/TextMonkey/Monkey/finetune/ds_config_zero2.json \
--image_size 896 \
--image_width 896 \
--image_height 896 \
--add_window true \
--use_global true \
--resampler true \
--use_lora True \
--remain 512
Question,Input Image,flag,username,timestamp
Read all the text in the image.,/home/wanglch/projects/TextMonkey/Monkey/flagged/Input Image/b445973ef3af610d7fd782de98c9024c87fa6089/tmpzy2ualgz.jpg,,,2024-06-25 08:04:24.668143
OCR with grounding,/home/wanglch/projects/TextMonkey/Monkey/flagged/Input Image/a1dd742177e84e78950310040d01161b92346156/tmpvn3mfx6n.jpg,,,2024-06-25 08:09:46.739416
OCR with grounding,/home/wanglch/projects/TextMonkey/Monkey/flagged/Input Image/9bdfec8fa848fbd1b77f92775f7fa599ad5b8bb6/tmpyiolfgnr.jpg,,,2024-06-25 08:15:40.416458
OCR with grounding,/home/wanglch/projects/TextMonkey/Monkey/flagged/Input Image/b58a30506cef8141a652187c8226fd11ef31c2b2/tmp0xwrftr0.png,,,2024-06-25 08:19:13.998547
Read all the text in the image.,/home/wanglch/projects/TextMonkey/Monkey/flagged/Input Image/e1893e80604de8d93761c9c086931017491e6206/tmpyoeac1nf.jpg,,,2024-06-25 09:36:28.584080
ocr这张图片的文字信息,并以json格式返回,/home/wanglch/projects/TextMonkey/Monkey/flagged/Input Image/69feaa2646e24ab2e5f57b9fe172462aa9ed5fca/tmpyggw5lgf.jpg,,,2024-06-25 09:42:04.567537
这是什么,/home/wanglch/projects/TextMonkey/Monkey/flagged/Input Image/456dce7b71ced3cc7333304165fb1d64f76a20bf/tmph__9e7j9.jpg,,,2024-06-25 09:45:15.589882
ocr这张图片的文字信息,/home/wanglch/projects/TextMonkey/Monkey/flagged/Input Image/3a31676bd1d1d8674d4ddfec28ce9fbb88817864/tmpk_5ah8qz.jpg,,,2024-06-25 09:49:44.657171
Read all the text in the image,/home/wanglch/projects/TextMonkey/Monkey/flagged/Input Image/00b0f410c05a56bea4ef0f10faac96f02c510e85/tmpdgp55_39.jpg,,,2024-06-25 09:53:14.966818
Read all the text in the image,/home/wanglch/projects/TextMonkey/Monkey/flagged/Input Image/fc2458835e72faaaba997e88edf0ce68033d3f81/tmp7m9oees7.jpg,,,2024-06-25 09:53:16.438519
OCR with grounding:,/home/wanglch/projects/TextMonkey/Monkey/flagged/Input Image/088a928701d360040ccb3973ef23255894f686ca/tmpnwp_sbet.jpg,,,2024-06-25 09:56:51.982646
Read all the text in the image.,/home/wanglch/projects/TextMonkey/Monkey/flagged/Input Image/12662d0af4fe2d7db3a9d18bfb4196260efe36db/tmpo17kbd8t.jpg,,,2024-06-25 10:02:04.521209
OCR with grounding,/home/wanglch/projects/TextMonkey/Monkey/flagged/Input Image/c1f53d5707fd1db35a0583cee3c86a419db02a20/tmpl8qyqanw.jpg,,,2024-06-25 10:02:52.364145
Read all the text in the image,/home/wanglch/projects/TextMonkey/Monkey/flagged/Input Image/f40f171969a76d98e17b86ca3745df2e0db81f58/tmpzn5c5psa.jpg,,,2024-06-25 10:03:28.890421
ocr这张图片中的文字信息,/home/wanglch/projects/TextMonkey/Monkey/flagged/Input Image/ae81de303bdc16ea3756317e5d41889a414679e1/tmpq1ofrhsz.jpg,,,2024-06-25 10:06:39.108823
Read all the text in the image.,/home/wanglch/projects/TextMonkey/Monkey/flagged/Input Image/34e919ffd1a492e7545d772efcbada1725595604/tmpsu3amqf2.jpg,,,2024-06-25 10:09:11.675886
Read all the text in the image.,/home/wanglch/projects/TextMonkey/Monkey/flagged/Input Image/5dde3e1670ee1dd00d565a8ec98cc2e952b0c149/tmp72is7k_3.jpg,,,2024-06-25 10:12:07.051628
Read all the text in the image.,/home/wanglch/projects/TextMonkey/Monkey/flagged/Input Image/ed60ab97135937aaf18ce3e64355e74115d427ee/tmp_y9evvhz.jpg,,,2024-06-25 10:14:38.290775
ocr收款人信息,出票金额,实际结算金额,申请人,出票行信息,并以json格式返回,/home/wanglch/projects/TextMonkey/Monkey/flagged/Input Image/ce8dd2355cfe0b828cbcb29c9736d91d96f068e2/tmp6d4d59vr.jpg,,,2024-06-26 08:45:54.525564
读取收款人信息,出票金额,实际结算金额,申请人,出票行信息,并以json格式返回,/home/wanglch/projects/TextMonkey/Monkey/flagged/Input Image/9180b70e5022e2382c3203f619b44bce77e2d8a6/tmpsrt32sg9.jpg,,,2024-06-26 08:47:00.371397
收款人信息,出票金额,实际结算金额,申请人,出票行信息,/home/wanglch/projects/TextMonkey/Monkey/flagged/Input Image/1f5481e4a02e0cb417fdf97a4d452e12166922a3/tmp4dmuc013.jpg,,,2024-06-26 08:47:25.486834
申请人是谁,/home/wanglch/projects/TextMonkey/Monkey/flagged/Input Image/4a04c4b87418b2cedb358296193cc28f170809ab/tmpi7ky4d4u.jpg,,,2024-06-26 08:51:40.970529
出票行是?,/home/wanglch/projects/TextMonkey/Monkey/flagged/Input Image/48620f449363a74816cca64204acfaed17f2896f/tmpvye_ndzz.jpg,,,2024-06-26 08:52:30.040664
Read all the text in the image,/home/wanglch/projects/TextMonkey/Monkey/flagged/Input Image/82cfba4aa08b13088c780cb23c995debf0716685/tmp1xxxovg8.jpg,,,2024-06-26 08:53:07.991437
出票日期是什么时候,/home/wanglch/projects/TextMonkey/Monkey/flagged/Input Image/dbf2f3ea4e5202857c9aad2b993d575f552a048f/tmphahxx1_w.jpg,,,2024-06-26 09:02:41.378048
# 模型唯一标识
modelCode = 742
# 模型名称
modelName=text-monkey_pytorch
# 模型描述
modelDescription=多模态OCR大模型
# 应用场景
appScenario=推理,训练,对话问答,金融,教育,政府,交通
# 框架类型
frameType=pytorch
{
"architectures": [
"MonkeyLMHeadModel"
],
"attn_dropout_prob": 0.0,
"auto_map": {
"AutoConfig": "configuration_qwen.QWenConfig",
"AutoModelForCausalLM": "modeling_monkey.MonkeyLMHeadModel"
},
"bf16": true,
"emb_dropout_prob": 0.0,
"fp16": false,
"fp32": false,
"hidden_size": 4096,
"initializer_range": 0.02,
"intermediate_size": 22016,
"kv_channels": 128,
"layer_norm_epsilon": 1e-06,
"max_position_embeddings": 8192,
"model_type": "monkey",
"no_bias": true,
"num_attention_heads": 32,
"num_hidden_layers": 32,
"onnx_safe": null,
"rotary_emb_base": 10000,
"rotary_pct": 1.0,
"scale_attn_weights": true,
"seq_length": 2048,
"tie_word_embeddings": false,
"tokenizer_type": "QWenTokenizer",
"torch_dtype": "bfloat16",
"transformers_version": "4.32.0",
"use_cache": false,
"use_dynamic_ntk": true,
"use_flash_attn": false,
"use_logn_attn": true,
"visual": {
"heads": 16,
"image_size": 896,
"image_start_id": 151857,
"layers": 48,
"mlp_ratio": 4.9231,
"output_dim": 4096,
"patch_size": 14,
"width": 1664,
"lora_repeat_num":4
},
"vocab_size": 151936
}
\ No newline at end of file
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from transformers import PretrainedConfig
class MonkeyConfig(PretrainedConfig):
model_type = "monkey"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=151936,
hidden_size=4096,
num_hidden_layers=32,
num_attention_heads=32,
emb_dropout_prob=0.0,
attn_dropout_prob=0.0,
layer_norm_epsilon=1e-6,
initializer_range=0.02,
max_position_embeddings=8192,
scale_attn_weights=True,
use_cache=True,
bf16=False,
fp16=False,
fp32=False,
kv_channels=128,
rotary_pct=1.0,
rotary_emb_base=10000,
use_dynamic_ntk=True,
use_logn_attn=True,
use_flash_attn="auto",
intermediate_size=22016,
no_bias=True,
tie_word_embeddings=False,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.emb_dropout_prob = emb_dropout_prob
self.attn_dropout_prob = attn_dropout_prob
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.scale_attn_weights = scale_attn_weights
self.use_cache = use_cache
self.max_position_embeddings = max_position_embeddings
self.bf16 = bf16
self.fp16 = fp16
self.fp32 = fp32
self.kv_channels = kv_channels
self.rotary_pct = rotary_pct
self.rotary_emb_base = rotary_emb_base
self.use_dynamic_ntk = use_dynamic_ntk
self.use_logn_attn = use_logn_attn
self.use_flash_attn = use_flash_attn
self.no_bias = no_bias
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs
)
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from transformers import PretrainedConfig
class QWenConfig(PretrainedConfig):
model_type = "monkey"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=151936,
hidden_size=4096,
num_hidden_layers=32,
num_attention_heads=32,
emb_dropout_prob=0.0,
attn_dropout_prob=0.0,
layer_norm_epsilon=1e-6,
initializer_range=0.02,
max_position_embeddings=8192,
scale_attn_weights=True,
use_cache=True,
bf16=False,
fp16=False,
fp32=False,
kv_channels=128,
rotary_pct=1.0,
rotary_emb_base=10000,
use_dynamic_ntk=True,
use_logn_attn=True,
use_flash_attn="auto",
intermediate_size=22016,
no_bias=True,
tie_word_embeddings=False,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.emb_dropout_prob = emb_dropout_prob
self.attn_dropout_prob = attn_dropout_prob
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.scale_attn_weights = scale_attn_weights
self.use_cache = use_cache
self.max_position_embeddings = max_position_embeddings
self.bf16 = bf16
self.fp16 = fp16
self.fp32 = fp32
self.kv_channels = kv_channels
self.rotary_pct = rotary_pct
self.rotary_emb_base = rotary_emb_base
self.use_dynamic_ntk = use_dynamic_ntk
self.use_logn_attn = use_logn_attn
self.use_flash_attn = use_flash_attn
self.no_bias = no_bias
super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs
)
import importlib
import math
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.cuda.amp import autocast
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
from transformers.generation.logits_process import LogitsProcessorList
if TYPE_CHECKING:
from transformers.generation.streamers import BaseStreamer
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
try:
from einops import rearrange
except ImportError:
rearrange = None
from torch import nn
from monkey_model.modeling_qwen import QWenModel,QWenPreTrainedModel,QWenLMHeadModel
SUPPORT_CUDA = torch.cuda.is_available()
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
logger = logging.get_logger(__name__)
class MonkeyModel(QWenModel):
def __init__(self, config):
super().__init__(config)
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
if past_key_values is None and torch.any(input_ids == self.config.visual['image_start_id']):
bos_pos = torch.where(input_ids == self.config.visual['image_start_id'])
eos_pos = torch.where(input_ids == self.config.visual['image_start_id'] + 1)
assert (bos_pos[0] == eos_pos[0]).all()
img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
images = []
for i, a, b in img_pos:
image = input_ids[i][a + 1 : b - 1].tolist()
image = image[ : image.index(self.config.visual['image_start_id'] + 2)]
images.append(bytes(image).decode('utf-8'))
windows,images_448 = self.visual.encode(images)
patch_list = []
lora_idx = 0
for col in windows:
for image_patch in col:
patch_list.append(self.visual(image_patch,idx=lora_idx))
lora_idx += 1
global_feat = self.visual(images_448)
local_feat = torch.cat(patch_list,dim=1)
images = torch.cat([local_feat,global_feat],dim=1)
assert images.shape[0] == len(images)
else:
images = None
return super().forward(input_ids,
past_key_values,
attention_mask,
token_type_ids,
position_ids,
head_mask,inputs_embeds,
encoder_hidden_states,
encoder_attention_mask,
use_cache,
output_attentions,
output_hidden_states,
return_dict,
images)
class MonkeyLMHeadModel(QWenLMHeadModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"]
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"]
def __init__(self, config):
super().__init__(config)
assert (
config.bf16 + config.fp16 + config.fp32 <= 1
), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"
autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0
if autoset_precision:
if SUPPORT_BF16:
logger.warn(
"The model is automatically converting to bf16 for faster inference. "
"If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
)
config.bf16 = True
elif SUPPORT_FP16:
logger.warn(
"The model is automatically converting to fp16 for faster inference. "
"If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
)
config.fp16 = True
else:
config.fp32 = True
if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster")
if config.fp32:
if SUPPORT_BF16:
logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
elif SUPPORT_FP16:
logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
self.transformer = MonkeyModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
if config.bf16:
self.transformer.bfloat16()
self.lm_head.bfloat16()
if config.fp16:
self.transformer.half()
self.lm_head.half()
self.post_init()
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import importlib
import math
from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.cuda.amp import autocast
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedTokenizer, GenerationConfig, StoppingCriteriaList
from transformers.generation.logits_process import LogitsProcessorList
if TYPE_CHECKING:
from transformers.generation.streamers import BaseStreamer
from transformers.generation.utils import GenerateOutput
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
try:
from einops import rearrange
except ImportError:
rearrange = None
from torch import nn
SUPPORT_CUDA = torch.cuda.is_available()
SUPPORT_BF16 = SUPPORT_CUDA and torch.cuda.is_bf16_supported()
SUPPORT_FP16 = SUPPORT_CUDA and torch.cuda.get_device_capability(0)[0] >= 7
from .configuration_qwen import QWenConfig
from .qwen_generation_utils import (
HistoryType,
make_context,
decode_tokens,
get_stop_words_ids,
StopWordsLogitsProcessor,
)
from .visual import VisionTransformer
SUPPORT_TORCH2 = hasattr(torch, '__version__') and int(torch.__version__.split(".")[0]) >= 2
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "qwen"
_CONFIG_FOR_DOC = "QWenConfig"
QWen_PRETRAINED_MODEL_ARCHIVE_LIST = ["qwen-7b"]
_ERROR_BAD_CHAT_FORMAT = """\
We detect you are probably using the pretrained model (rather than chat model) for chatting, since the chat_format in generation_config is not "chatml".
If you are directly using the model downloaded from Huggingface, please make sure you are using our "Qwen/Qwen-7B-Chat" Huggingface model (rather than "Qwen/Qwen-7B") when you call model.chat().
我们检测到您可能在使用预训练模型(而非chat模型)进行多轮chat,因为您当前在generation_config指定的chat_format,并未设置为我们在对话中所支持的"chatml"格式。
如果您在直接使用我们从Huggingface提供的模型,请确保您在调用model.chat()时,使用的是"Qwen/Qwen-7B-Chat"模型(而非"Qwen/Qwen-7B"预训练模型)。
"""
_SENTINEL = object()
_ERROR_STREAM_IN_CHAT = """\
Pass argument `stream` to model.chat() is buggy, deprecated, and marked for removal. Please use model.chat_stream(...) instead of model.chat(..., stream=True).
向model.chat()传入参数stream的用法可能存在Bug,该用法已被废弃,将在未来被移除。请使用model.chat_stream(...)代替model.chat(..., stream=True)。
"""
apply_rotary_emb_func = None
rms_norm = None
# use flash attnetion, if your machine do not support it, you can close it
use_flash_attention = True
def _import_flash_attn():
global apply_rotary_emb_func, rms_norm, flash_attn_unpadded_func
try:
from flash_attn.layers.rotary import apply_rotary_emb_func as __apply_rotary_emb_func
apply_rotary_emb_func = __apply_rotary_emb_func
except ImportError:
logger.warn(
"Warning: import flash_attn rotary fail, please install FlashAttention rotary to get higher efficiency "
"https://github.com/Dao-AILab/flash-attention/tree/main/csrc/rotary"
)
# try:
# from flash_attn.ops.rms_norm import rms_norm as __rms_norm
# rms_norm = __rms_norm
# except ImportError:
# logger.warn(
# "Warning: import flash_attn rms_norm fail, please install FlashAttention layer_norm to get higher efficiency "
# "https://github.com/Dao-AILab/flash-attention/tree/main/csrc/layer_norm"
# )
try:
import flash_attn
if not hasattr(flash_attn, '__version__'):
from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
else:
if int(flash_attn.__version__.split(".")[0]) >= 2:
from flash_attn.flash_attn_interface import flash_attn_varlen_func as __flash_attn_unpadded_func
else:
from flash_attn.flash_attn_interface import flash_attn_unpadded_func as __flash_attn_unpadded_func
flash_attn_unpadded_func = __flash_attn_unpadded_func
except ImportError:
logger.warn(
"Warning: import flash_attn fail, please install FlashAttention to get higher efficiency "
"https://github.com/Dao-AILab/flash-attention"
)
class FlashSelfAttention(torch.nn.Module):
def __init__(
self,
causal=False,
softmax_scale=None,
attention_dropout=0.0,
):
super().__init__()
assert flash_attn_unpadded_func is not None, (
"Please install FlashAttention first, " "e.g., with pip install flash-attn"
)
assert (
rearrange is not None
), "Please install einops first, e.g., with pip install einops"
self.causal = causal
self.softmax_scale = softmax_scale
self.dropout_p = attention_dropout
def unpad_input(self, hidden_states, attention_mask):
valid_mask = attention_mask.squeeze(1).squeeze(1).eq(0)
seqlens_in_batch = valid_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(valid_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
hidden_states = hidden_states[indices]
return hidden_states, indices, cu_seqlens, max_seqlen_in_batch
def pad_input(self, hidden_states, indices, batch, seqlen):
output = torch.zeros(batch * seqlen, *hidden_states.shape[1:], device=hidden_states.device,
dtype=hidden_states.dtype)
output[indices] = hidden_states
return rearrange(output, '(b s) ... -> b s ...', b=batch)
def forward(self, q, k, v, attention_mask=None):
assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
assert all((i.is_cuda for i in (q, k, v)))
batch_size, seqlen_q = q.shape[0], q.shape[1]
seqlen_k = k.shape[1]
seqlen_out = seqlen_q
q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
cu_seqlens_q = torch.arange(
0,
(batch_size + 1) * seqlen_q,
step=seqlen_q,
dtype=torch.int32,
device=q.device,
)
if batch_size > 1 and attention_mask is not None:
k, indices_k, cu_seqlens_k, seqlen_k = self.unpad_input(k, attention_mask)
if q.size(0) == v.size(0):
q = q[indices_k]
cu_seqlens_q = cu_seqlens_k
seqlen_q = seqlen_k
v = v[indices_k]
else:
cu_seqlens_k = torch.arange(
0,
(batch_size + 1) * seqlen_k,
step=seqlen_k,
dtype=torch.int32,
device=q.device,
)
if self.training:
assert seqlen_k == seqlen_q
is_causal = self.causal
dropout_p = self.dropout_p
else:
is_causal = seqlen_q == seqlen_k
dropout_p = 0
output = flash_attn_unpadded_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
seqlen_q,
seqlen_k,
dropout_p,
softmax_scale=self.softmax_scale,
causal=is_causal,
)
if batch_size > 1 and attention_mask is not None and seqlen_q == seqlen_k:
output = self.pad_input(output, indices_k, batch_size, seqlen_out)
else:
new_shape = (batch_size, output.shape[0] // batch_size) + output.shape[1:]
output = output.view(new_shape)
return output
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class QWenAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
self.seq_length = config.seq_length
self.hidden_size = config.hidden_size
self.split_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.scale_attn_weights = True
self.projection_size = config.kv_channels * config.num_attention_heads
assert self.projection_size % config.num_attention_heads == 0
self.hidden_size_per_attention_head = (
self.projection_size // config.num_attention_heads
)
self.c_attn = nn.Linear(config.hidden_size, 3 * self.projection_size)
self.c_proj = nn.Linear(
config.hidden_size, self.projection_size, bias=not config.no_bias
)
self.is_fp32 = not (config.bf16 or config.fp16)
self.bf16 = config.bf16
self.use_dynamic_ntk = config.use_dynamic_ntk
self.use_logn_attn = config.use_logn_attn
logn_list = [
math.log(i, self.seq_length) if i > self.seq_length else 1
for i in range(1, 32768)
]
self.logn_tensor = torch.tensor(logn_list)[None, :, None, None]
self.attn_dropout = nn.Dropout(config.attn_dropout_prob)
if use_flash_attention:
_import_flash_attn()
self.core_attention_flash = FlashSelfAttention(causal=True, attention_dropout=0)
def _attn(self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))
if self.scale_attn_weights:
attn_weights = attn_weights / torch.full(
[],
value.size(-1) ** 0.5,
dtype=attn_weights.dtype,
device=attn_weights.device,
)
query_length, key_length = query.size(-2), key.size(-2)
# causal_mask = self.bias[
# :, :, key_length - query_length : key_length, :key_length
# ]
# mask_value = torch.finfo(attn_weights.dtype).min
# mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(
# attn_weights.device
# )
# attn_weights = torch.where(
# causal_mask, attn_weights.to(attn_weights.dtype), mask_value
# )
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights = attn_weights.type(value.dtype)
attn_weights = self.attn_dropout(attn_weights)
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
attn_output = attn_output.transpose(1, 2)
return attn_output, attn_weights
def _upcast_and_reordered_attn(
self, query, key, value, registered_causal_mask, attention_mask=None, head_mask=None
):
bsz, num_heads, q_seq_len, dk = query.size()
_, _, k_seq_len, _ = key.size()
attn_weights = torch.empty(
bsz * num_heads,
q_seq_len,
k_seq_len,
dtype=torch.float32,
device=query.device,
)
scale_factor = 1.0
if self.scale_attn_weights:
scale_factor /= float(value.size(-1)) ** 0.5
with autocast(enabled=False):
q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(
-1, dk, k_seq_len
)
attn_weights = torch.baddbmm(
attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor
)
attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = registered_causal_mask[
:, :, key_length - query_length : key_length, :key_length
]
mask_value = torch.finfo(attn_weights.dtype).min
mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(
attn_weights.device
)
attn_weights = torch.where(causal_mask, attn_weights, mask_value)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if attn_weights.dtype != torch.float32:
raise RuntimeError(
"Error with upcasting, attn_weights does not have dtype torch.float32"
)
attn_weights = attn_weights.type(value.dtype)
attn_weights = self.attn_dropout(attn_weights)
if head_mask is not None:
attn_weights = attn_weights * head_mask
attn_output = torch.matmul(attn_weights, value)
return attn_output, attn_weights
def _split_heads(self, tensor, num_heads, attn_head_size):
new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
tensor = tensor.view(new_shape)
return tensor
def _merge_heads(self, tensor, num_heads, attn_head_size):
tensor = tensor.contiguous()
new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,)
return tensor.view(new_shape)
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
rotary_pos_emb: Optional[List[torch.Tensor]] = None,
registered_causal_mask: Optional[torch.Tensor] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
):
mixed_x_layer = self.c_attn(hidden_states)
query, key, value = mixed_x_layer.split(self.split_size, dim=2)
query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)
if rotary_pos_emb is not None:
cur_len = query.shape[1]
rotary_pos_emb = [i[:, -cur_len:, :, :] for i in rotary_pos_emb]
rotary_pos_emb = (rotary_pos_emb,) * 2
q_pos_emb, k_pos_emb = rotary_pos_emb
# Slice the pos emb for current inference
query = apply_rotary_pos_emb(query, q_pos_emb)
key = apply_rotary_pos_emb(key, k_pos_emb)
if layer_past is not None:
past_key, past_value = layer_past[0], layer_past[1]
key = torch.cat((past_key, key), dim=1)
value = torch.cat((past_value, value), dim=1)
if use_cache:
present = (key, value)
else:
present = None
if self.use_logn_attn and not self.training:
if self.logn_tensor.device != query.device or self.logn_tensor.dtype != query.dtype:
self.logn_tensor = self.logn_tensor.to(query.device).type_as(query)
seq_start = key.size(1) - query.size(1)
seq_end = key.size(1)
logn_tensor = self.logn_tensor[:, seq_start:seq_end, :, :]
query = query * logn_tensor.expand_as(query)
if self.training and SUPPORT_TORCH2 and use_flash_attention:
attn_output = self.core_attention_flash(query,key,value)
attn_weight = None
else:
query = query.permute(0, 2, 1, 3)
key = key.permute(0, 2, 1, 3)
value = value.permute(0, 2, 1, 3)
attn_output, attn_weight = self._attn(
query, key, value, registered_causal_mask, attention_mask, head_mask
)
context_layer = self._merge_heads(
attn_output, self.num_heads, self.head_dim
)
attn_output = self.c_proj(context_layer)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weight,)
return outputs
class QWenMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.w1 = nn.Linear(
config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias
)
self.w2 = nn.Linear(
config.hidden_size, config.intermediate_size // 2, bias=not config.no_bias
)
ff_dim_in = config.intermediate_size // 2
self.c_proj = nn.Linear(ff_dim_in, config.hidden_size, bias=not config.no_bias)
def forward(self, hidden_states):
a1 = self.w1(hidden_states)
a2 = self.w2(hidden_states)
intermediate_parallel = a1 * F.silu(a2)
output = self.c_proj(intermediate_parallel)
return output
class QWenBlock(nn.Module):
def __init__(self, config):
super().__init__()
hidden_size = config.hidden_size
self.bf16 = config.bf16
self.ln_1 = RMSNorm(
hidden_size,
eps=config.layer_norm_epsilon,
)
self.attn = QWenAttention(config)
self.ln_2 = RMSNorm(
hidden_size,
eps=config.layer_norm_epsilon,
)
self.mlp = QWenMLP(config)
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
rotary_pos_emb: Optional[List[torch.Tensor]] = None,
registered_causal_mask: Optional[torch.Tensor] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
):
layernorm_output = self.ln_1(hidden_states)
attn_outputs = self.attn(
layernorm_output,
rotary_pos_emb,
registered_causal_mask=registered_causal_mask,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attn_output = attn_outputs[0]
outputs = attn_outputs[1:]
residual = hidden_states
layernorm_input = attn_output + residual
layernorm_output = self.ln_2(layernorm_input)
residual = layernorm_input
mlp_output = self.mlp(layernorm_output)
hidden_states = residual + mlp_output
if use_cache:
outputs = (hidden_states,) + outputs
else:
outputs = (hidden_states,) + outputs[1:]
return outputs
class QWenPreTrainedModel(PreTrainedModel):
config_class = QWenConfig
base_model_prefix = "transformer"
is_parallelizable = False
supports_gradient_checkpointing = True
_no_split_modules = ["QWenBlock"]
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
def _init_weights(self, module):
'''
There is no need to re_init
'''
return
"""Initialize the weights."""
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, RMSNorm):
module.weight.data.fill_(1.0)
for name, p in module.named_parameters():
if name == "c_proj.weight":
p.data.normal_(
mean=0.0,
std=(
self.config.initializer_range
/ math.sqrt(2 * self.config.num_hidden_layers)
),
)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, QWenModel):
module.gradient_checkpointing = value
class QWenModel(QWenPreTrainedModel):
_keys_to_ignore_on_load_missing = ["attn.masked_bias"]
def __init__(self, config):
super().__init__(config)
self.vocab_size = config.vocab_size
self.num_hidden_layers = config.num_hidden_layers
self.embed_dim = config.hidden_size
self.gradient_checkpointing = False
self.use_dynamic_ntk = config.use_dynamic_ntk
self.seq_length = config.seq_length
self.wte = nn.Embedding(self.vocab_size, self.embed_dim)
self.drop = nn.Dropout(config.emb_dropout_prob)
if config.rotary_pct == 1.0:
self.rotary_ndims = None
else:
assert config.rotary_pct < 1
self.rotary_ndims = int(
config.kv_channels * config.rotary_pct
)
dim = (
self.rotary_ndims
if self.rotary_ndims is not None
else config.kv_channels
)
self.rotary_emb = RotaryEmbedding(dim, base=config.rotary_emb_base)
self.use_flash_attn = config.use_flash_attn
self.is_fp32 = not (config.bf16 or config.fp16)
self.registered_causal_mask = None
# if (
# self.use_flash_attn
# and flash_attn_unpadded_func is not None
# and not self.is_fp32
# ):
# self.registered_causal_mask = None
# else:
# max_positions = config.max_position_embeddings
# self.register_buffer(
# "registered_causal_mask",
# torch.tril(
# torch.ones((max_positions, max_positions), dtype=torch.bool)
# ).view(1, 1, max_positions, max_positions),
# persistent=False,
# )
self.h = nn.ModuleList(
[
QWenBlock(
config
)
for i in range(config.num_hidden_layers)
]
)
self.ln_f = RMSNorm(
self.embed_dim,
eps=config.layer_norm_epsilon,
)
self.visual = VisionTransformer(**config.visual)
self.post_init()
def get_input_embeddings(self):
return self.wte
def set_input_embeddings(self, new_embeddings):
self.wte = new_embeddings
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape,
inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
inputs_embeds.device
)
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
images=None
):
if images is None:
if past_key_values is None and torch.any(input_ids == self.config.visual['image_start_id']):
bos_pos = torch.where(input_ids == self.config.visual['image_start_id'])
eos_pos = torch.where(input_ids == self.config.visual['image_start_id'] + 1)
assert (bos_pos[0] == eos_pos[0]).all()
img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
images = []
for i, a, b in img_pos:
image = input_ids[i][a + 1 : b - 1].tolist()
image = image[ : image.index(self.config.visual['image_start_id'] + 2)]
images.append(bytes(image).decode('utf-8'))
images = self.visual.encode(images)
assert images.shape[0] == len(images)
else:
images = None
else:
bos_pos = torch.where(input_ids == self.config.visual['image_start_id'])
eos_pos = torch.where(input_ids == self.config.visual['image_start_id'] + 1)
assert (bos_pos[0] == eos_pos[0]).all()
img_pos = torch.stack((bos_pos[0], bos_pos[1], eos_pos[1]), dim=1)
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time"
)
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
batch_size = input_ids.shape[0]
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
batch_size = inputs_embeds.shape[0]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
device = input_ids.device if input_ids is not None else inputs_embeds.device
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if position_ids is not None:
position_ids = position_ids.view(-1, input_shape[-1])
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(
past_length,
input_shape[-1] + past_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
encoder_attention_mask = None
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_length
)
hidden_states = inputs_embeds
kv_seq_len = hidden_states.size()[1]
if past_key_values[0] is not None:
# past key values[0][0] shape: bs * seq_len * head_num * dim
kv_seq_len += past_key_values[0][0].shape[1]
if (
self.use_dynamic_ntk
and kv_seq_len == hidden_states.size()[1]
and not self.training
):
context_value = math.log(kv_seq_len / self.seq_length, 2) + 1
ntk_alpha = 2 ** math.ceil(context_value) - 1
ntk_alpha = max(ntk_alpha, 1)
else:
ntk_alpha = self.rotary_emb._ntk_alpha_cached
rotary_pos_emb = self.rotary_emb(kv_seq_len, ntk_alpha=ntk_alpha)
for idx in range(len(rotary_pos_emb)):
rotary_pos_emb[idx] = rotary_pos_emb[idx].to(hidden_states.device)
hidden_states = self.drop(hidden_states)
if images is not None:
for idx, (i, a, b) in enumerate(img_pos):
hidden_states[i][a + 1 : b] = images[idx]
output_shape = input_shape + (hidden_states.size(-1),)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
rotary_pos_emb,
self.registered_causal_mask,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
)
else:
outputs = block(
hidden_states,
layer_past=layer_past,
rotary_pos_emb=rotary_pos_emb,
registered_causal_mask=self.registered_causal_mask,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
hidden_states = self.ln_f(hidden_states)
hidden_states = hidden_states.view(output_shape)
# Add last hidden state
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v for v in [hidden_states, presents, all_hidden_states] if v is not None
)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=presents,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
)
class QWenLMHeadModel(QWenPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.rotary_emb\.inv_freq"]
_keys_to_ignore_on_load_unexpected = [r"h\.\d+\.attn\.masked_bias"]
def __init__(self, config):
super().__init__(config)
assert (
config.bf16 + config.fp16 + config.fp32 <= 1
), "Only one of \"bf16\", \"fp16\", \"fp32\" can be true"
autoset_precision = config.bf16 + config.fp16 + config.fp32 == 0
if autoset_precision:
if SUPPORT_BF16:
logger.warn(
"The model is automatically converting to bf16 for faster inference. "
"If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
)
config.bf16 = True
elif SUPPORT_FP16:
logger.warn(
"The model is automatically converting to fp16 for faster inference. "
"If you want to disable the automatic precision, please manually add bf16/fp16/fp32=True to \"AutoModelForCausalLM.from_pretrained\"."
)
config.fp16 = True
else:
config.fp32 = True
if config.bf16 and SUPPORT_CUDA and not SUPPORT_BF16:
logger.warn("Your device does NOT seem to support bf16, you can switch to fp16 or fp32 by by passing fp16/fp32=True in \"AutoModelForCausalLM.from_pretrained\".")
if config.fp16 and SUPPORT_CUDA and not SUPPORT_FP16:
logger.warn("Your device does NOT support faster inference with fp16, please switch to fp32 which is likely to be faster")
if config.fp32:
if SUPPORT_BF16:
logger.warn("Your device support faster inference by passing bf16=True in \"AutoModelForCausalLM.from_pretrained\".")
elif SUPPORT_FP16:
logger.warn("Your device support faster inference by passing fp16=True in \"AutoModelForCausalLM.from_pretrained\".")
self.transformer = QWenModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
if config.bf16:
self.transformer.bfloat16()
self.lm_head.bfloat16()
if config.fp16:
self.transformer.half()
self.lm_head.half()
self.post_init()
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs
):
token_type_ids = kwargs.get("token_type_ids", None)
if past_key_values:
input_ids = input_ids[:, -1].unsqueeze(-1)
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
attention_mask = kwargs.get("attention_mask", None)
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)
else:
position_ids = None
if inputs_embeds is not None and past_key_values is None:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
model_inputs.update(
{
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
"token_type_ids": token_type_ids,
}
)
return model_inputs
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
labels = labels.to(lm_logits.device)
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=lm_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
)
@staticmethod
def _reorder_cache(
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
return tuple(
tuple(
past_state.index_select(0, beam_idx.to(past_state.device))
for past_state in layer_past
)
for layer_past in past_key_values
)
def chat(
self,
tokenizer: PreTrainedTokenizer,
query: str,
history: Optional[HistoryType],
system: str = "You are a helpful assistant.",
append_history: bool = True,
stream: Optional[bool] = _SENTINEL,
stop_words_ids: Optional[List[List[int]]] = None,
generation_config: Optional[GenerationConfig] = None,
**kwargs,
) -> Tuple[str, HistoryType]:
generation_config = generation_config if generation_config is not None else self.generation_config
assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
if history is None:
history = []
if stop_words_ids is None:
stop_words_ids = []
max_window_size = kwargs.get('max_window_size', None)
if max_window_size is None:
max_window_size = generation_config.max_window_size
raw_text, context_tokens = make_context(
tokenizer,
query,
history=history,
system=system,
max_window_size=max_window_size,
chat_format=generation_config.chat_format,
)
stop_words_ids.extend(get_stop_words_ids(
generation_config.chat_format, tokenizer
))
input_ids = torch.tensor([context_tokens]).to(self.device)
outputs = self.generate(
input_ids,
stop_words_ids=stop_words_ids,
return_dict_in_generate=False,
generation_config=generation_config,
**kwargs,
)
response = decode_tokens(
outputs[0],
tokenizer,
raw_text_len=len(raw_text),
context_length=len(context_tokens),
chat_format=generation_config.chat_format,
verbose=False,
errors='replace'
)
if append_history:
history.append((query, response))
return response, history
def chat_pretrain(
self,
tokenizer: PreTrainedTokenizer,
query: str,
history: Optional[HistoryType],
system: str = "You are a helpful assistant.",
append_history: bool = False,
stream: Optional[bool] = _SENTINEL,
stop_words_ids: Optional[List[List[int]]] = None,
generation_config: Optional[GenerationConfig] = None,
**kwargs,
) -> Tuple[str, HistoryType]:
generation_config = generation_config if generation_config is not None else self.generation_config
assert stream is _SENTINEL, _ERROR_STREAM_IN_CHAT
if history is None:
history = []
if stop_words_ids is None:
stop_words_ids = []
max_window_size = kwargs.get('max_window_size', None)
if max_window_size is None:
max_window_size = generation_config.max_window_size
raw_text, context_tokens = make_context(
tokenizer,
query,
history=history,
system=system,
max_window_size=max_window_size,
chat_format=generation_config.chat_format,
)
stop_words_ids.extend(get_stop_words_ids(
generation_config.chat_format, tokenizer
))
input_ids = torch.tensor([context_tokens]).to(self.device)
outputs = self.generate(
input_ids,
stop_words_ids=stop_words_ids,
return_dict_in_generate=False,
generation_config=generation_config,
**kwargs,
)
response = decode_tokens(
outputs[0],
tokenizer,
raw_text_len=len(raw_text),
context_length=len(context_tokens),
chat_format=generation_config.chat_format,
verbose=False,
errors='replace'
)
if append_history:
history.append((query, response))
return response, history
def chat_stream(
self,
tokenizer: PreTrainedTokenizer,
query: str,
history: Optional[HistoryType],
system: str = "You are a helpful assistant.",
stop_words_ids: Optional[List[List[int]]] = None,
logits_processor: Optional[LogitsProcessorList] = None,
generation_config: Optional[GenerationConfig] = None,
**kwargs,
) -> Generator[str, Any, None]:
generation_config = generation_config if generation_config is not None else self.generation_config
assert generation_config.chat_format == 'chatml', _ERROR_BAD_CHAT_FORMAT
if history is None:
history = []
if stop_words_ids is None:
stop_words_ids = []
max_window_size = kwargs.get('max_window_size', None)
if max_window_size is None:
max_window_size = generation_config.max_window_size
raw_text, context_tokens = make_context(
tokenizer,
query,
history=history,
system=system,
max_window_size=max_window_size,
chat_format=generation_config.chat_format,
)
stop_words_ids.extend(get_stop_words_ids(
generation_config.chat_format, tokenizer
))
if stop_words_ids is not None:
stop_words_logits_processor = StopWordsLogitsProcessor(
stop_words_ids=stop_words_ids,
eos_token_id=generation_config.eos_token_id,
)
if logits_processor is None:
logits_processor = LogitsProcessorList([stop_words_logits_processor])
else:
logits_processor.append(stop_words_logits_processor)
input_ids = torch.tensor([context_tokens]).to(self.device)
from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig
self.__class__.generate_stream = NewGenerationMixin.generate
self.__class__.sample_stream = NewGenerationMixin.sample_stream
stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True)
def stream_generator():
outputs = []
for token in self.generate_stream(
input_ids,
return_dict_in_generate=False,
generation_config=stream_config,
logits_processor=logits_processor,
seed=-1,
**kwargs):
outputs.append(token.item())
yield tokenizer.decode(outputs, skip_special_tokens=True, errors='ignore')
return stream_generator()
def generate(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[
Callable[[int, torch.Tensor], List[int]]
] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
generation_config = generation_config if generation_config is not None else self.generation_config
# Process stop_words_ids.
stop_words_ids = kwargs.pop("stop_words_ids", None)
if stop_words_ids is None and generation_config is not None:
stop_words_ids = getattr(generation_config, "stop_words_ids", None)
if stop_words_ids is None:
stop_words_ids = getattr(generation_config, "stop_words_ids", None)
if stop_words_ids is not None:
stop_words_logits_processor = StopWordsLogitsProcessor(
stop_words_ids=stop_words_ids,
eos_token_id=generation_config.eos_token_id,
)
if logits_processor is None:
logits_processor = LogitsProcessorList([stop_words_logits_processor])
else:
logits_processor.append(stop_words_logits_processor)
return super().generate(
inputs,
generation_config=generation_config,
logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
synced_gpus=synced_gpus,
assistant_model=assistant_model,
streamer=streamer,
**kwargs,
)
class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, base=10000):
super().__init__()
self.dim = dim
self.base = base
self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
if importlib.util.find_spec("einops") is None:
raise RuntimeError("einops is required for Rotary Embedding")
self._rotary_pos_emb_cache = None
self._seq_len_cached = 0
self._ntk_alpha_cached = 1.0
def update_rotary_pos_emb_cache(self, max_seq_len, offset=0, ntk_alpha=1.0):
seqlen = max_seq_len + offset
if seqlen > self._seq_len_cached or ntk_alpha != self._ntk_alpha_cached:
base = self.base * ntk_alpha ** (self.dim / (self.dim - 2))
self.inv_freq = 1.0 / (
base
** (
torch.arange(0, self.dim, 2, device=self.inv_freq.device).float()
/ self.dim
)
)
self._seq_len_cached = max(2 * seqlen, 16)
self._ntk_alpha_cached = ntk_alpha
seq = torch.arange(self._seq_len_cached, device=self.inv_freq.device)
freqs = torch.outer(seq.type_as(self.inv_freq), self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
from einops import rearrange
emb = rearrange(emb, "n d -> 1 n 1 d")
cos, sin = emb.cos(), emb.sin()
self._rotary_pos_emb_cache = [cos, sin]
def forward(self, max_seq_len, offset=0, ntk_alpha=1.0):
self.update_rotary_pos_emb_cache(max_seq_len, offset, ntk_alpha)
cos, sin = self._rotary_pos_emb_cache
return [cos[:, offset : offset + max_seq_len], sin[:, offset : offset + max_seq_len]]
def _rotate_half(x):
from einops import rearrange
x = rearrange(x, "... (j d) -> ... j d", j=2)
x1, x2 = x.unbind(dim=-2)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(t, freqs):
cos, sin = freqs
if apply_rotary_emb_func is not None and t.is_cuda:
t_ = t.float()
cos = cos.squeeze(0).squeeze(1)[:, : cos.shape[-1] // 2]
sin = sin.squeeze(0).squeeze(1)[:, : sin.shape[-1] // 2]
output = apply_rotary_emb_func(t_, cos, sin).type_as(t)
return output
else:
rot_dim = freqs[0].shape[-1]
cos, sin = freqs
t_, t_pass_ = t[..., :rot_dim], t[..., rot_dim:]
t_ = t_.float()
t_pass_ = t_pass_.float()
t_ = (t_ * cos) + (_rotate_half(t_) * sin)
return torch.cat((t_, t_pass_), dim=-1).type_as(t)
class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
if rms_norm is not None and x.is_cuda:
return rms_norm(x, self.weight, self.eps)
else:
output = self._norm(x.float()).type_as(x)
return output * self.weight
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment