Commit fe851fbc authored by zhouxiang

Add new files introduced in version 0.2.6

parent e2d98ddc
# Copyright (c) OpenMMLab. All rights reserved.
import collections
from lmdeploy.utils import get_logger
logger = get_logger('lmdeploy')
class UserRequestQueue:
"""Inner group user request queues."""
def __init__(self, group: str, user_id_map: dict):
self.group = group
self.user_queue_map = dict()
self.user_quota_map = dict()
self.user_id_maps = user_id_map
total_quota = 0
for item in user_id_map:
total_quota += item['quota_pct']
for item in user_id_map:
user_id = item['id']
self.user_queue_map[user_id] = collections.deque()
self.user_quota_map[user_id] = item['quota_pct'] / total_quota
def enqueue(self, request_event):
"""Enqueue request to corresponding user queue."""
if request_event[0].user_id in self.user_queue_map:
self.user_queue_map[request_event[0].user_id].append(request_event)
else:
self.user_queue_map['default'].append(request_event)
def empty(self):
"""Whether all user queues are empty."""
for _, req_queue in self.user_queue_map.items():
if len(req_queue) != 0:
return False
return True
def dequeue(self, usage_stats):
"""Dequeue the request to serve."""
uid_to_serve = self.user_to_serve(usage_stats)
if uid_to_serve in self.user_queue_map:
return self.user_queue_map[uid_to_serve].popleft()
return None
def user_to_serve(self, usage_stats):
"""Inner group scheduling.
Find the user to serve from user request queues.
"""
min_usage = 100
uid_to_serve = ''
for uid, req_queue in self.user_queue_map.items():
if len(req_queue) == 0:
continue
# TODO: include token length
# Calculate current user's actual used share and quota share
user_usage, _, group_usage, _ = usage_stats.get_user_usage(
uid, self.group)
actual_share = (user_usage / group_usage) if group_usage > 0 else 0
due_share = self.user_quota_map[uid]
# Serve the user with the relatively least usage share
curr_usage = (actual_share / due_share) if due_share > 0 else 0
if curr_usage == 0:
return uid
if curr_usage < min_usage:
uid_to_serve = uid
min_usage = curr_usage
return uid_to_serve
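For illustration, a minimal sketch of the inner-group rule above: among users with pending requests, the one whose actual share of group usage divided by its quota share is smallest gets served next. The stub usage-stats object below is hypothetical and only mimics the get_user_usage() signature that user_to_serve() relies on.

from types import SimpleNamespace

class _StubUsageStats:
    """Hypothetical stand-in for UsageStats; only get_user_usage() is needed here."""

    def __init__(self, served):
        self.served = served  # uid -> requests served so far in this group

    def get_user_usage(self, uid, group):
        # (user_reqs, user_tokens, group_reqs, group_tokens)
        return self.served.get(uid, 0), 0, sum(self.served.values()), 0

q = UserRequestQueue('Gold', [{'id': 'user_id1', 'quota_pct': 50},
                              {'id': 'user_id2', 'quota_pct': 50},
                              {'id': 'default', 'quota_pct': 0}])
# Give both users one pending request; the payload shape is irrelevant here.
q.enqueue((SimpleNamespace(user_id='user_id1'), None))
q.enqueue((SimpleNamespace(user_id='user_id2'), None))
# user_id1 has consumed 30 of the 40 requests served in the group, user_id2 only 10,
# so with equal quotas user_id2 has the smaller usage ratio and is served next.
print(q.user_to_serve(_StubUsageStats({'user_id1': 30, 'user_id2': 10})))  # user_id2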
{
"enable_user_qos": 1,
"user_groups": ["Platinum", "Gold", "Silver", "Bronze"],
"user_group_map": {
"Platinum": [
{
"id": "user_id0",
"quota_pct": 100
},
{
"id": "default",
"quota_pct": 0
}
],
"Gold": [
{
"id": "user_id1",
"quota_pct": 50
},
{
"id": "user_id2",
"quota_pct": 50
},
{
"id": "default",
"quota_pct": 0
}
],
"Silver": [
{
"id": "user_id3",
"quota_pct": 5
},
{
"id": "default",
"quota_pct": 95
}
],
"Bronze": [
{
"id": "user_id4",
"quota_pct": 30
},
{
"id": "user_id5",
"quota_pct": 30
},
{
"id": "user_id6",
"quota_pct": 40
},
{
"id": "default",
"quota_pct": 0
}
]
}
}
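The JSON above is the qos_tag string consumed by the QosConfig class in qos_engine.py below; within each group the quota_pct values are normalized, and the 'default' entry catches users not listed explicitly. A minimal parsing sketch (the file name is hypothetical; how the API server supplies the string is not part of this commit):

with open('qos_config.json') as f:         # hypothetical file holding the JSON above
    cfg = QosConfig(f.read())
assert cfg.is_qos_enabled                   # enable_user_qos=1 is truthy
print(cfg.user_group_prio)                  # ['Platinum', 'Gold', 'Silver', 'Bronze']
print(cfg.user_id_maps['Silver'])           # user_id3 gets 5%, 'default' traffic gets 95%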
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import json
import threading
import time
from typing import List
from lmdeploy.serve.openai.protocol import (ChatCompletionRequestQos,
CompletionRequestQos,
GenerateRequestQos)
from lmdeploy.serve.qos_engine.inner_group_schd import UserRequestQueue
from lmdeploy.serve.qos_engine.usage_stats import UsageStats
from lmdeploy.utils import get_logger
logger = get_logger('lmdeploy')
class QosConfig:
"""qos config class: parse qosconfig for qos engine."""
def __init__(self, qos_tag=''):
qos_config = json.loads(qos_tag)
self.is_qos_enabled = qos_config.get('enable_user_qos', False)
logger.debug(f'is_qos_enabled: {self.is_qos_enabled}')
if self.is_qos_enabled:
self.user_id_maps = qos_config['user_group_map']
self.user_group_prio = qos_config['user_groups']
logger.debug(f'user_id_maps: {self.user_id_maps}')
logger.debug(f'user_group_prio: {self.user_group_prio}')
class QosEngine:
"""impl for qos engine, docs/en/qos.md."""
def __init__(self, qos_tag='', engine=None, **kwargs) -> None:
self.engine = engine
self.availSlots = engine.instance_num
self._stop_event = threading.Event()
self._dequeue_thread = threading.Thread(target=self._serve,
daemon=True)
self.qos_config = QosConfig(qos_tag)
self.qos_user_group = QosGroupQueue(self.qos_config)
self.usage_stats = UsageStats(
total_duration=60,
buffer_count=6,
start_index=0,
user_groups=self.qos_config.user_group_prio)
self.user_served_reqs = dict()
self._dump_stats_thread = threading.Thread(target=self._dump_stats,
daemon=True)
self.lock = threading.Lock()
self.stats_lock = threading.Lock()
def start(self):
"""start qos engine."""
if self.is_qos_enabled():
self._dequeue_thread.start()
self._dump_stats_thread.start()
def is_qos_enabled(self):
"""check while qos engine is enabled."""
return self.qos_config.is_qos_enabled
async def stop_session(self, session_id: int):
"""Stop a session by a session_id."""
await self.engine.stop_session(session_id)
async def generate(self, request):
"""entry of qos engine generate for three api."""
if isinstance(request, CompletionRequestQos):
if isinstance(request.prompt, str):
request.prompt = [request.prompt]
generators = []
for i in range(len(request.prompt)):
result_generator = self.engine.generate(
request.prompt[i],
request.session_id + i,
True, # always use stream to enable batching
sequence_start=True,
sequence_end=True,
request_output_len=request.max_tokens
if request.max_tokens else 512,
stop=False,
top_p=request.top_p,
temperature=request.temperature,
repetition_penalty=request.repetition_penalty,
ignore_eos=request.ignore_eos,
do_preprocess=False)
generators.append(result_generator)
return generators
elif isinstance(request, GenerateRequestQos):
async_engine = self.engine
sequence_start = async_engine.id2step.get(str(request.session_id),
0) == 0
sequence_end = not request.interactive_mode
generation = async_engine.generate(
request.prompt,
request.session_id,
stream_response=True, # always use stream to enable batching
sequence_start=sequence_start,
sequence_end=sequence_end,
request_output_len=request.request_output_len,
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop,
temperature=request.temperature,
repetition_penalty=request.repetition_penalty,
ignore_eos=request.ignore_eos)
return generation
elif isinstance(request, ChatCompletionRequestQos):
# default chat/completions
result_generator = self.engine.generate(
request.messages,
request.session_id,
True, # always use stream to enable batching
sequence_start=True,
sequence_end=True,
request_output_len=request.max_tokens
if request.max_tokens else 512,
stop=request.stop,
top_p=request.top_p,
top_k=request.top_k,
temperature=request.temperature,
repetition_penalty=request.repetition_penalty,
ignore_eos=request.ignore_eos)
return result_generator
return time.sleep(0.01)
async def generate_with_qos(self, request):
"""called by api server for qos generate."""
if not self.is_qos_enabled():
return await self.generate(request)
# push (request,event) to queue
event = asyncio.Event()
request_event = (request, event)
with self.lock:
self.qos_user_group.enqueue(request_event)
await event.wait()
result_generator = await self.generate(request)
# release self.availSlots resources
with self.lock:
if isinstance(request, CompletionRequestQos) and isinstance(
request.prompt, List):
self.availSlots += len(request.prompt)
else:
self.availSlots += 1
# Update number of served requests for each user
with self.stats_lock:
if request.user_id not in self.user_served_reqs:
self.user_served_reqs[request.user_id] = 1
else:
self.user_served_reqs[request.user_id] += 1
return result_generator
def _serve(self):
"""backend thread for dequeue."""
while not self._stop_event.is_set():
if self.availSlots > 0:
with self.lock:
request_event = self.dequeue(self.usage_stats)
if request_event is not None:
# Update usage_stats
user_group = self.qos_user_group.get_user_group(
request_event[0].user_id)
self.usage_stats.update_usage(request_event[0].user_id,
user_group, 100,
int(time.time()))
if isinstance(request_event[0],
CompletionRequestQos) and isinstance(
request_event[0].prompt, List):
self.availSlots -= len(request_event[0].prompt)
else:
self.availSlots -= 1
request_event[1].set()
logger.debug(
f'Available slot decrease, now: {self.availSlots}')
time.sleep(0)
def _dump_stats(self):
"""dump usage states for debugs."""
ts = 0
while not self._stop_event.is_set():
outdata = ''
with self.stats_lock:
if not self.user_served_reqs:
outdata = 'none'
else:
sorted_uids = sorted(self.user_served_reqs.keys())
for uid in sorted_uids:
outdata += f'{uid} {self.user_served_reqs[uid]} reqs, '
self.user_served_reqs = dict()
logger.info(
f'qos svc running for {ts} seconds, last 20 seconds: {outdata}')
ts += 20
time.sleep(20)
def dequeue(self, usage_stats):
"""dequeue from multiqueue."""
return self.qos_user_group.dequeue(usage_stats)
class QosGroupQueue:
"""create groups for qos outer group schedule."""
def __init__(self, qos_config):
if qos_config is None:
self.user_list = {}
self.queues = {}
else:
self.user_list = qos_config.user_id_maps
self.queues = {}
for user_group in qos_config.user_group_prio:
self.queues[user_group] = UserRequestQueue(
user_group, self.user_list[user_group])
self.user_group_list = list(self.user_list.keys())
self.default_user_group = self.user_group_list[2] if len(
self.user_group_list) >= 3 else 'None'
logger.debug(self.user_list)
logger.debug(self.queues)
logger.debug(self.default_user_group)
def get_user_group(self, user_id):
"""input: user, output user_id"""
for category, users in self.user_list.items():
for user in users:
if user_id == user['id']:
return category
return self.default_user_group
def enqueue(self, request_event):
"""enqueue outer group waiting for schedule."""
user_id = self.get_user_group(request_event[0].user_id)
self.queues[user_id].enqueue(request_event)
def dequeue(self, usage_stats):
"""dequeue outer group schedule."""
for user_group_id, user_group_queue in self.queues.items():
if user_group_queue.empty():
continue
else:
return user_group_queue.dequeue(usage_stats)
return None
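A hedged sketch of how an API-server handler might drive QosEngine; async_engine stands for an already-constructed lmdeploy AsyncEngine, and the handler wiring is an assumption rather than part of this commit:

with open('qos_config.json') as f:                          # hypothetical config file
    qos_engine = QosEngine(qos_tag=f.read(), engine=async_engine)
qos_engine.start()                                          # spawn scheduler and stats threads

async def completions_v1_qos(request: CompletionRequestQos):
    # Waits on an asyncio.Event until the outer/inner-group scheduler grants a slot,
    # then returns the generator(s) produced by the underlying engine.
    return await qos_engine.generate_with_qos(request)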
# Copyright (c) OpenMMLab. All rights reserved.
import threading
from typing import List
class Buffer:
"""Ring buffer for calculate tokens and reqs usage."""
def __init__(self, ts: int, user_groups: List[str]):
self.ts = ts
# Per user usage
self.uid_to_tokens_ps = dict()
self.uid_to_reqs_ps = dict()
# Per group usage
self.group_to_tokens_ps = dict()
self.group_to_reqs_ps = dict()
for group in user_groups:
self.group_to_tokens_ps[group] = 0
self.group_to_reqs_ps[group] = 0
class UsageStats:
"""calculate usage for qos engine for inner group schedule."""
def __init__(self, total_duration: int, buffer_count: int,
start_index: int, user_groups: List[str]):
self.total_duration = total_duration
self.buffer_count = buffer_count
self.start_index = start_index
self.start_ts = int(0)
self.buffer_duration = int(total_duration / buffer_count)
self.circular_buffer = [
Buffer(self.buffer_duration * i, user_groups)
for i in range(buffer_count)
]
self.user_groups = user_groups
self.lock = threading.Lock()
def update_usage(self, uid: str, group: str, out_token_len: int,
req_ts: int):
"""Update UsageStats when a request is returned."""
with self.lock:
intervals = int((req_ts - self.start_ts) / self.buffer_duration)
curr_idx = (self.start_index + intervals) % self.buffer_count
curr_ts = self.start_ts + intervals * self.buffer_duration
# Current request outside the sliding window
if intervals >= self.buffer_count:
reset_buf_cnt = intervals - self.buffer_count
curr_buf_ts = 0
if reset_buf_cnt >= self.buffer_count:
# All buffers are reset
for i in range(1, self.buffer_count):
reset_idx = (curr_idx + i) % self.buffer_count
self.circular_buffer[reset_idx] = Buffer(
req_ts + i * self.buffer_duration,
self.user_groups)
# Update self.start_index
self.start_index = curr_idx
self.start_ts = req_ts
curr_buf_ts = req_ts
else:
# buffers between self.start_index and curr_idx are reset
for i in range(reset_buf_cnt):
reset_idx = (self.start_index + i) % self.buffer_count
reset_ts = self.circular_buffer[
reset_idx].ts + self.total_duration
self.circular_buffer[reset_idx] = Buffer(
reset_ts, self.user_groups)
# Update self.start_index
self.start_index = (curr_idx + 1) % self.buffer_count
self.start_ts = self.circular_buffer[self.start_index].ts
curr_buf_ts = self.circular_buffer[
curr_idx].ts + self.total_duration
# Set corresponding buffer
self.circular_buffer[curr_idx] = Buffer(
curr_buf_ts, self.user_groups)
self.circular_buffer[curr_idx].uid_to_reqs_ps[uid] = 1
self.circular_buffer[curr_idx].uid_to_tokens_ps[
uid] = out_token_len
self.circular_buffer[curr_idx].group_to_reqs_ps[group] = 1
self.circular_buffer[curr_idx].group_to_tokens_ps[
group] = out_token_len
# Otherwise update corresponding buffer
else:
self.circular_buffer[curr_idx].ts = curr_ts
if uid in self.circular_buffer[curr_idx].uid_to_reqs_ps:
self.circular_buffer[curr_idx].uid_to_reqs_ps[uid] += 1
else:
self.circular_buffer[curr_idx].uid_to_reqs_ps[uid] = 1
if uid in self.circular_buffer[curr_idx].uid_to_tokens_ps:
self.circular_buffer[curr_idx].uid_to_tokens_ps[
uid] += out_token_len
else:
self.circular_buffer[curr_idx].uid_to_tokens_ps[
uid] = out_token_len
self.circular_buffer[curr_idx].group_to_reqs_ps[group] += 1
self.circular_buffer[curr_idx].group_to_tokens_ps[
group] += out_token_len
def get_user_usage(self, uid: str, group: str):
"""Calculate usage stats of the given user and group."""
user_req_usage = 0
user_token_usage = 0
group_req_usage = 0
group_token_usage = 0
# TODO: use reader lock
with self.lock:
for i in range(self.buffer_count):
if uid in self.circular_buffer[i].uid_to_reqs_ps:
user_req_usage += self.circular_buffer[i].uid_to_reqs_ps[
uid]
user_token_usage += self.circular_buffer[
i].uid_to_tokens_ps[uid]
group_req_usage += self.circular_buffer[i].group_to_reqs_ps[
group]
group_token_usage += self.circular_buffer[
i].group_to_tokens_ps[group]
return (user_req_usage, user_token_usage, group_req_usage,
group_token_usage)
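A quick worked example of the sliding window, using the same parameters QosEngine passes in (a 60-second window split into 6 buffers of 10 seconds):

stats = UsageStats(total_duration=60, buffer_count=6, start_index=0,
                   user_groups=['Gold'])
stats.update_usage('user_id1', 'Gold', out_token_len=100, req_ts=5)    # falls in bucket 0
stats.update_usage('user_id1', 'Gold', out_token_len=100, req_ts=25)   # falls in bucket 2
# (user_reqs, user_tokens, group_reqs, group_tokens) aggregated over the window
print(stats.get_user_usage('user_id1', 'Gold'))   # (2, 200, 2, 200)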
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Union
import numpy as np
from lmdeploy.serve.async_engine import AsyncEngine
from lmdeploy.vl.constants import IMAGE_DUMMY_TOKEN_INDEX, IMAGE_TOKEN
from lmdeploy.vl.engine import ImageEncoder
from lmdeploy.vl.templates import VLPromptType, get_vl_prompt_template
class VLAsyncEngine(AsyncEngine):
"""Visual Language Async inference engine."""
def __init__(self, model_path: str, **kwargs) -> None:
super().__init__(model_path, **kwargs)
if self.model_name == 'base':
raise RuntimeError(
'please specify chat template as guided in https://lmdeploy.readthedocs.io/en/latest/inference/vl_pipeline.html#set-chat-template' # noqa: E501
)
self.vl_encoder = ImageEncoder(model_path)
self.vl_prompt_template = get_vl_prompt_template(
model_path, self.chat_template, self.model_name)
def _convert_prompts(self,
prompts: Union[VLPromptType, List[Dict],
List[VLPromptType], List[List[Dict]]]):
"""convert prompts to openai format."""
if isinstance(prompts, str) or isinstance(prompts, tuple):
_prompts = self.vl_prompt_template.prompt_to_messages(prompts)
elif isinstance(prompts[0], tuple) or isinstance(prompts[0], str):
_prompts = [
self.vl_prompt_template.prompt_to_messages(x) for x in prompts
]
else:
_prompts = prompts
return _prompts
async def _get_prompt_input(self, prompt: Dict, do_preprocess: bool,
sequence_start: bool):
"""get input_ids, embeddings and offsets."""
if do_preprocess:
decorated = self.vl_prompt_template.messages2prompt(
prompt, sequence_start)
else:
decorated = prompt
segs = decorated.split(IMAGE_TOKEN)
results = {}
input_ids = []
if len(segs) > 1:
images = await self.vl_prompt_template.async_collect_pil_images(
prompt)
features = await self.vl_encoder.async_infer(images)
features = [x.cpu().numpy() for x in features]
input_ids = []
begins = []
ends = []
for i, seg in enumerate(segs):
if i > 0:
image_dim = features[i - 1].shape[0]
begins.append(len(input_ids))
ends.append(begins[-1] + image_dim)
input_ids.extend([IMAGE_DUMMY_TOKEN_INDEX] * image_dim)
seg_ids = self.tokenizer.encode(seg,
add_bos=((i == 0)
and sequence_start))
input_ids.extend(seg_ids)
ranges = np.stack([begins, ends], axis=1).tolist()
results['input_embeddings'] = features
results['input_embedding_ranges'] = ranges
else:
input_ids = self.tokenizer.encode(decorated,
add_bos=sequence_start)
results['input_ids'] = input_ids
results['prompt'] = decorated
return results
def batch_infer(self, prompts: Union[VLPromptType, List[Dict],
List[VLPromptType], List[List[Dict]]],
**kwargs):
"""Inference a batch of prompts."""
prompts = self._convert_prompts(prompts)
return super().batch_infer(prompts, **kwargs)
def stream_infer(self, prompts: Union[VLPromptType, List[Dict],
List[VLPromptType],
List[List[Dict]]], **kwargs):
"""Inference a batch of prompts with stream mode."""
prompts = self._convert_prompts(prompts)
return super().stream_infer(prompts, **kwargs)
def __call__(self, prompts: Union[VLPromptType, List[Dict],
List[VLPromptType], List[List[Dict]]],
**kwargs):
"""Inference a batch of prompts."""
prompts = self._convert_prompts(prompts)
return super().__call__(prompts, **kwargs)
def chat(self, prompts: VLPromptType, **kwargs):
"""chat."""
_prompts = self._convert_prompts(prompts)
sess = super().chat(_prompts, **kwargs)
# recover prompts & history
sess._prompt = prompts
last_round = sess.history[-1]
sess.history[-1] = (prompts, last_round[-1])
return sess
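A hedged usage sketch of the VL engine: it accepts plain strings, (text, image) tuples, or OpenAI-style message lists; the model path and image file below are examples, not values taken from this commit.

from PIL import Image

engine = VLAsyncEngine('liuhaotian/llava-v1.5-7b')     # example LlavaLlamaForCausalLM checkpoint
image = Image.open('demo.jpg')                         # hypothetical local image
outputs = engine(('describe this image', image))       # dispatches through __call__/batch_infer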
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from .base import INPUT_MODELS
from .llama import LlamaModel, LlamaReader
from .llama_awq import ensure_fp16orint32
class InternLM2Reader(LlamaReader):
"""InternLM2 model reader."""
attn_layer_patten = r'model.layers.([0-9]+).'
tok_embeddings_key = 'model.tok_embeddings.weight'
norm_weight_key = 'model.norm.weight'
output_weight_key = 'output.weight'
def __init__(self, new_params: dict, unused_params: dict, last_bin: bool,
model_cfg: dict):
super().__init__(new_params, unused_params, last_bin, model_cfg)
def _attn(self, i: int, kind: str, size_dim: int, dim: int = 0):
"""Get q, k, v, o kind for layer i."""
kv_head_num = self.model_cfg['kv_head_num']
gs = int(self.model_cfg['attn_head_num'] / kv_head_num)
qkv = self.params[f'model.layers.{i}.attention.wqkv.{kind}']
qkv = qkv.view(kv_head_num, gs + 2, 128, -1)
hidden_dim = qkv.shape[-1]
q, k, v = torch.split(qkv, [gs, 1, 1], dim=1)
q = q.reshape(-1, hidden_dim)
k = k.reshape(-1, hidden_dim)
v = v.reshape(-1, hidden_dim)
o = self.params.get(f'model.layers.{i}.attention.wo.{kind}')
return q, k, v, o
def attn(self, i: int):
"""Get q, k, v, o weight for layer i."""
return self._attn(i, 'weight', 0, 0)
def attn_bias(self, i: int):
return (None, ) * 4
def attn_zero(self, i: int):
"""Get q, k, v, o zero point for layer i."""
return (None, ) * 4
def attn_scale(self, i: int):
"""Get q, k, v, o scale for layer i."""
return (None, ) * 4
def attn_norm(self, i: int):
"""Get attn norm for layer i."""
return self.params[f'model.layers.{i}.attention_norm.weight']
def _ffn(self, i: int, kind: str):
"""Get ffn kind for layer i."""
result = []
for key in ['w1', 'w2', 'w3']:
tensor = self.params[f'model.layers.{i}.feed_forward.{key}.{kind}']
result.append(tensor)
return (*result, )
def ffn(self, i: int):
"""Get ffn weight for layer i."""
return self._ffn(i, 'weight')
def ffn_zero(self, i: int):
"""Get ffn zero point for layer i."""
return (None, ) * 3
def ffn_scale(self, i: int):
"""Get ffn scale for layer i."""
return (None, ) * 3
def ffn_norm(self, i: int):
"""Get ffn norm for layer i."""
return self.params[f'model.layers.{i}.ffn_norm.weight']
@INPUT_MODELS.register_module(name='internlm2')
class InternLM2Model(LlamaModel):
"""InternLM2 model in hf format."""
Reader = InternLM2Reader
def __init__(self, model_path: str, tokenizer_path: str, **kwargs):
super().__init__(model_path, tokenizer_path, **kwargs)
class InternLM2AwqReader(InternLM2Reader):
"""read weights from internlm2 awq model."""
def __init__(self, new_params: dict, unused_params: dict, last_bin: bool,
model_cfg: dict):
super().__init__(new_params, unused_params, last_bin, model_cfg)
def _attn(self, i: int, kind: str):
"""Get q, k, v, o qweight for layer i."""
kv_head_num = self.model_cfg['kv_head_num']
gs = int(self.model_cfg['attn_head_num'] / kv_head_num)
qkv = self.params[f'model.layers.{i}.attention.wqkv.{kind}']
hidden_dim = qkv.shape[0]
qkv = qkv.view(hidden_dim, kv_head_num, gs + 2, -1)
q, k, v = torch.split(qkv, [gs, 1, 1], dim=-2)
q = q.reshape(hidden_dim, -1)
k = k.reshape(hidden_dim, -1)
v = v.reshape(hidden_dim, -1)
o = self.params.get(f'model.layers.{i}.attention.wo.{kind}')
return ensure_fp16orint32((q, k, v, o))
def attn(self, i: int):
"""Get q, k, v, o qweight for layer i."""
return self._attn(i, 'qweight')
def attn_zero(self, i: int):
"""Get q, k, v, o qzeros for layer i."""
return self._attn(i, 'qzeros')
def attn_scale(self, i: int):
"""Get q, k, v, o scales for layer i."""
return self._attn(i, 'scales')
def ffn(self, i: int):
"""Get ffn qweight for layer i."""
return ensure_fp16orint32(self._ffn(i, 'qweight'))
def ffn_zero(self, i: int):
"""Get ffn qzeros for layer i."""
return ensure_fp16orint32(self._ffn(i, 'qzeros'))
def ffn_scale(self, i: int):
"""Get ffn scales for layer i."""
return ensure_fp16orint32(self._ffn(i, 'scales'))
@INPUT_MODELS.register_module(name='internlm2-awq')
class InternLM2AwqModel(InternLM2Model):
"""InternLM2 awq model."""
Reader = InternLM2AwqReader
def __init__(self, model_path: str, tokenizer_path: str, **kwargs):
super().__init__(model_path, tokenizer_path, **kwargs)
# Copyright (c) OpenMMLab. All rights reserved.
from transformers import AutoConfig
from lmdeploy.utils import get_logger
logger = get_logger('lmdeploy')
_SUPPORTED_ARCHS = dict(
# baichuan-7b
BaiChuanForCausalLM=True,
# baichuan2-7b, baichuan-13b, baichuan2-13b
BaichuanForCausalLM=True,
# chatglm2-6b, chatglm3-6b
ChatGLMModel=False,
# deepseek-moe
DeepseekForCausalLM=False,
# falcon-7b
FalconForCausalLM=False,
# gemma-7b
GemmaForCausalLM=False,
# internlm
InternLMForCausalLM=True,
# internlm2
InternLM2ForCausalLM=True,
# internlm-xcomposer
InternLMXComposerForCausalLM=True,
# internlm2-xcomposer
InternLM2XComposerForCausalLM=False,
# llama, llama2, alpaca, vicuna, codellama, ultracm, yi,
# deepseek-coder, deepseek-llm
LlamaForCausalLM=True,
# Mistral-7B
MistralForCausalLM=False,
# Mixtral-8x7B
MixtralForCausalLM=False,
# Qwen 7B-72B, Qwen-VL-7B
QWenLMHeadModel=True,
# Qwen1.5 7B-72B
Qwen2ForCausalLM=False,
# llava
LlavaLlamaForCausalLM=True)
def is_supported(model_path: str):
"""Check whether supported by turbomind engine.
Args:
model_path (str): the path of a model.
It could be one of the following options:
- i) A local directory path of a turbomind model which is
converted by the `lmdeploy convert` command or downloaded
from ii) and iii).
- ii) The model_id of a lmdeploy-quantized model hosted
inside a model repo on huggingface.co, such as
"InternLM/internlm-chat-20b-4bit",
"lmdeploy/llama2-chat-70b-4bit", etc.
- iii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
Returns:
support_by_turbomind (bool): Whether input model is supported by turbomind engine
""" # noqa: E501
import os
support_by_turbomind = False
triton_model_path = os.path.join(model_path, 'triton_models')
if os.path.exists(triton_model_path):
support_by_turbomind = True
else:
cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
if hasattr(cfg, 'architectures'):
arch = cfg.architectures[0]
elif hasattr(cfg,
'auto_map') and 'AutoModelForCausalLM' in cfg.auto_map:
arch = cfg.auto_map['AutoModelForCausalLM'].split('.')[-1]
else:
raise RuntimeError(
f'Could not find model architecture from config: {cfg}')
if arch in _SUPPORTED_ARCHS:
support_by_turbomind = _SUPPORTED_ARCHS[arch]
# special cases
if arch == 'BaichuanForCausalLM':
num_attn_head = cfg.num_attention_heads
if num_attn_head == 40:
# baichuan-13B, baichuan2-13B not supported by turbomind
support_by_turbomind = False
return support_by_turbomind
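Usage is a single boolean check, e.g. with one of the model ids mentioned in the docstring above:

if is_supported('internlm/internlm-chat-7b'):
    print('serve with the turbomind engine')
else:
    print('fall back to another backend')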
# Copyright (c) OpenMMLab. All rights reserved.
from .utils import load_image
__all__ = ['load_image']
# Copyright (c) OpenMMLab. All rights reserved.
IMAGE_DUMMY_TOKEN_INDEX = 0
IMAGE_TOKEN = '<IMAGE_TOKEN>'
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import queue
import time
from threading import Thread
from typing import List, Union
from PIL.Image import Image
from lmdeploy.utils import get_logger
from lmdeploy.vl.model.builder import load_vl_model
logger = get_logger('lmdeploy')
class Record:
"""Batching manager."""
def __init__(self):
self.number = []
self.waiting = []
self.done = []
self.res_que = []
self.total = 0
def enqueue(self, images: List[Image], que: Union[queue.Queue,
asyncio.Queue]):
"""add ith request to manager."""
self.number.append(len(images))
self.waiting.extend(images)
self.res_que.append(que)
self.total += len(images)
self.log('received', len(images))
def dequeue(self, max_batch_size):
"""try to dequeue max batch size images."""
inputs = self.waiting[:max_batch_size]
self.waiting = self.waiting[max_batch_size:]
self.total -= len(inputs)
self.log('process', len(inputs))
return inputs
def notify(self):
"""set result for the earliest request once all its images are done."""
if len(self.number) == 0 or self.number[0] > len(self.done):
return False
num_images = self.number.pop(0)
outputs = self.done[:num_images]
self.done = self.done[num_images:]
que = self.res_que.pop(0)
if isinstance(que, queue.Queue):
que.put(outputs)
else:
que._loop.call_soon_threadsafe(que.put_nowait, outputs)
self.log('done', num_images)
return True
def log(self, task: str, num: int):
logger.info(f'ImageEncoder {task} {num} images, '
f'left {self.total} images.')
class ImageEncoder:
"""Image encoder."""
def __init__(self, model_path: str, max_batch_size: int = 16):
self.model = load_vl_model(model_path)
self.max_batch_size = max_batch_size
self.loop = asyncio.new_event_loop()
self.work_thread = self._start_work_thread()
def _start_work_thread(self):
"""internal thread."""
def _work_thread():
asyncio.set_event_loop(self.loop)
self.que = asyncio.Queue()
self.loop.run_until_complete(self._forward_loop())
thread = Thread(target=_work_thread, daemon=True)
thread.start()
return thread
async def _forward_loop(self):
"""working loop to process images."""
logger.info('start ImageEncoder._forward_loop')
record = Record()
while True:
while record.total == 0 or (self.que.qsize() and
record.total < self.max_batch_size):
item = await self.que.get()
record.enqueue(item[0], item[1])
inputs = record.dequeue(self.max_batch_size)
outputs = self.forward(inputs)
record.done.extend(outputs)
while record.notify():
pass
def forward(self, inputs: List[Image]):
"""Model forward."""
time_start = time.perf_counter()
outputs = self.model.forward(inputs)
time_end = time.perf_counter()
logger.info(f'ImageEncoder forward {len(inputs)} images, '
f'cost {time_end - time_start:.3f}s')
return outputs
def infer(self, inputs: List[Image]):
"""infer."""
outputs = queue.Queue()
item = (inputs, outputs)
self.loop.call_soon_threadsafe(self.que.put_nowait, item)
results = outputs.get()
return results
async def async_infer(self, inputs: List[Image]):
"""async infer."""
outputs = asyncio.Queue()
item = (inputs, outputs)
self.loop.call_soon_threadsafe(self.que.put_nowait, item)
results = await outputs.get()
return results
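A minimal sketch of driving the encoder directly (model path and image file are hypothetical; the model must be one that load_vl_model in builder.py recognizes):

from PIL import Image

encoder = ImageEncoder('liuhaotian/llava-v1.5-7b')       # hypothetical model path
features = encoder.infer([Image.open('demo.jpg')])       # blocking call, returns List[torch.Tensor]
# inside a coroutine, the non-blocking variant would be:
#     features = await encoder.async_infer([Image.open('demo.jpg')])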
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABC, abstractmethod
from typing import List
import PIL
import torch
class VisonModel(ABC):
"""Visual model which extract image feature."""
@abstractmethod
def forward(self, images: List[PIL.Image.Image]) -> List[torch.Tensor]:
"""extract image feature.
Args:
images (List[PIL.Image.Image]): input images
Return:
List[torch.Tensor]: extracted image feature for each input image
"""
raise NotImplementedError()
# Copyright (c) OpenMMLab. All rights reserved.
import os
from lmdeploy.utils import get_hf_config_content, get_model
from .llava import LlavaVisionModel
from .qwen import QwenVisionModel
from .yi import YiVisionModel
def load_vl_model(model_path: str):
"""load visual model."""
if not os.path.exists(model_path):
model_path = get_model(model_path)
config = get_hf_config_content(model_path)
arch = config['architectures'][0]
if arch == 'QWenLMHeadModel':
return QwenVisionModel(model_path)
elif arch == 'LlavaLlamaForCausalLM':
projector_type = config.get('mm_projector_type', 'linear')
if '_Norm' in projector_type:
return YiVisionModel(model_path)
else:
return LlavaVisionModel(model_path)
raise ValueError(f'unsupported vl model with arch {arch}')
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from
# https://github.com/haotian-liu/LLaVA.git
import warnings
from typing import List, Union
import torch
from PIL.Image import Image
from lmdeploy.utils import get_logger
from lmdeploy.vl.model.base import VisonModel
from lmdeploy.vl.model.utils import load_model_from_weight_files
logger = get_logger('lmdeploy')
def check_llava_install():
"""check llava install."""
try:
import llava # noqa: F401
except ImportError:
raise ImportError(
'To use LlavaVLModel, please install llava by '
'pip install git+https://github.com/haotian-liu/LLaVA.git')
class LlavaVisionModel(VisonModel):
"""Llava visual model."""
def __init__(self, model_path, device='cuda'):
self.model_path = model_path
self.device = device
self.build_model()
def build_model(self):
"""build model & load weights."""
# check llava install
check_llava_install()
# currently, only support llava llama
from llava.model.language_model.llava_llama import (
LlavaConfig, LlavaLlamaForCausalLM)
self.config = LlavaConfig.from_pretrained(self.model_path)
assert self.config.model_type in ['llava', 'llava_llama'], \
'currently, only support llava llama'
# empty init
with torch.device('meta'), warnings.catch_warnings():
warnings.simplefilter('ignore')
model = LlavaLlamaForCausalLM.from_pretrained(self.model_path)
del model.lm_head
del model.model.embed_tokens
del model.model.layers
del model.model.norm
# load weight
with torch.device(self.device):
model.to_empty(device=self.device)
vision_tower = model.get_vision_tower()
vision_tower.is_loaded = False
vision_tower.load_model()
load_model_from_weight_files(model, self.model_path)
model.eval().half()
self.model = model.model
self.vision_tower = model.model.vision_tower
self.mm_projector = model.model.mm_projector
def encode_images(self, images: torch.Tensor) -> torch.Tensor:
"""encode images."""
image_features = self.vision_tower(images)
image_features = self.mm_projector(image_features)
return image_features
def preprocess(
self,
images: List[Image]) -> Union[torch.Tensor, List[torch.Tensor]]:
"""preprocess."""
# TODO: gpu processor
from llava.mm_utils import process_images
images = [x.convert('RGB') for x in images]
image_processor = self.vision_tower.image_processor
outputs = process_images(images, image_processor, self.config)
return outputs
@torch.no_grad()
def forward(self, images: List[Image]) -> List[torch.Tensor]:
"""forward."""
from llava.model.llava_arch import (get_anyres_image_grid_shape,
unpad_image)
image_sizes = [x.size for x in images]
images = self.preprocess(images)
if isinstance(images, list):
images = [x.to(self.device, dtype=torch.float16) for x in images]
else:
images = images.to(self.device, dtype=torch.float16)
if type(images) is list or images.ndim == 5:
if type(images) is list:
images = [x.unsqueeze(0) if x.ndim == 3 else x for x in images]
concat_images = torch.cat([image for image in images], dim=0)
image_features = self.encode_images(concat_images)
split_sizes = [image.shape[0] for image in images]
image_features = torch.split(image_features, split_sizes, dim=0)
mm_patch_merge_type = getattr(self.config, 'mm_patch_merge_type',
'flat')
image_aspect_ratio = getattr(self.config, 'image_aspect_ratio',
'square')
if mm_patch_merge_type == 'flat':
image_features = [x.flatten(0, 1) for x in image_features]
elif mm_patch_merge_type.startswith('spatial'):
new_image_features = []
for image_idx, image_feature in enumerate(image_features):
if image_feature.shape[0] > 1:
base_image_feature = image_feature[0]
image_feature = image_feature[1:]
height = width = self.vision_tower.num_patches_per_side
assert height * width == base_image_feature.shape[0]
if image_aspect_ratio == 'anyres':
num_patch_width, num_patch_height = \
get_anyres_image_grid_shape(
image_sizes[image_idx],
self.config.image_grid_pinpoints,
self.vision_tower.config.image_size)
image_feature = image_feature.view(
num_patch_height, num_patch_width, height,
width, -1)
else:
raise NotImplementedError
if 'unpad' in mm_patch_merge_type:
image_feature = image_feature.permute(
4, 0, 2, 1, 3).contiguous()
image_feature = image_feature.flatten(1,
2).flatten(
2, 3)
image_feature = unpad_image(
image_feature, image_sizes[image_idx])
image_feature = torch.cat((
image_feature,
self.model.image_newline[:, None, None].expand(
*image_feature.shape[:-1], 1).to(
image_feature.device)),
dim=-1)
image_feature = image_feature.flatten(1,
2).transpose(
0, 1)
else:
image_feature = image_feature.permute(
0, 2, 1, 3, 4).contiguous()
image_feature = image_feature.flatten(0, 3)
image_feature = torch.cat(
(base_image_feature, image_feature), dim=0)
else:
image_feature = image_feature[0]
if 'unpad' in mm_patch_merge_type:
image_feature = torch.cat(
(image_feature,
self.model.image_newline[None].to(
image_feature.device)),
dim=0)
new_image_features.append(image_feature)
image_features = new_image_features
else:
raise ValueError('Unexpected mm_patch_merge_type: '
f'{self.config.mm_patch_merge_type}')
else:
image_features = self.encode_images(images)
image_features = [x for x in image_features]
return image_features
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch
from accelerate import init_empty_weights
from PIL.Image import Image
from transformers import AutoConfig, AutoModelForCausalLM
from lmdeploy.vl.model.base import VisonModel
from lmdeploy.vl.model.utils import load_model_from_weight_files
class QwenVisionModel(VisonModel):
"""Qwen vision model."""
def __init__(self, model_path, device='cuda'):
self.model_path = model_path
self.device = device
self.build_model()
def build_model(self):
with init_empty_weights():
config = AutoConfig.from_pretrained(self.model_path,
trust_remote_code=True)
model = AutoModelForCausalLM.from_config(config,
trust_remote_code=True)
del model.lm_head
for key in ['wte', 'h', 'ln_f']:
setattr(model.transformer, key, None)
with torch.device(self.device):
model.to_empty(device=self.device)
load_model_from_weight_files(model, self.model_path)
self.model = model.transformer.visual
self.model.eval().half()
@torch.no_grad()
def forward(self, images: List[Image]) -> List[torch.Tensor]:
"""forward."""
outputs = [x.convert('RGB') for x in images]
outputs = [self.model.image_transform(x) for x in outputs]
outputs = torch.stack(outputs, dim=0)
outputs = self.model(outputs)
outputs = torch.split(outputs, 1, dim=0)
outputs = [x.squeeze() for x in outputs]
return outputs
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Dict, List
import torch
import torch.nn as nn
from safetensors.torch import load_file
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME
from transformers.utils.hub import get_checkpoint_shard_files
def load_weight_ckpt(ckpt: str) -> Dict[str, torch.Tensor]:
"""load checkpoint."""
if ckpt.endswith('.safetensors'):
return load_file(ckpt)
else:
return torch.load(ckpt)
def get_used_weight_files(folder: str,
state_dict: Dict[str, torch.Tensor]) -> List[str]:
"""get used checkpoint which contains keys in state_dict."""
_index_file = os.path.join(folder, WEIGHTS_INDEX_NAME)
_safe_index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME)
if os.path.exists(_index_file):
index_file = _index_file
elif os.path.exists(_safe_index_file):
index_file = _safe_index_file
else:
raise FileNotFoundError
_, sharded_metadata = get_checkpoint_shard_files(folder, index_file)
potential_keys = set(state_dict.keys())
supplied_keys = set(sharded_metadata['weight_map'].keys())
shared_keys = potential_keys & supplied_keys
valid_files = set(sharded_metadata['weight_map'][k] for k in shared_keys)
return valid_files
def load_model_from_weight_files(model: nn.Module, folder: str) -> None:
"""load nn.Module weight from folder."""
valid_files = get_used_weight_files(folder, model.state_dict())
for file_name in valid_files:
ckpt = os.path.join(folder, file_name)
state_dict = load_weight_ckpt(ckpt)
model.load_state_dict(state_dict, strict=False)
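A hedged sketch of the loading flow these helpers support, mirroring how qwen.py above uses them; the checkpoint folder is hypothetical and must be a sharded checkpoint with a weight-map index file:

import torch
from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

folder = '/path/to/hf/checkpoint'                    # hypothetical local folder
with init_empty_weights():
    config = AutoConfig.from_pretrained(folder, trust_remote_code=True)
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
model.to_empty(device='cuda')                        # materialize meta tensors as empty storage
load_model_from_weight_files(model, folder)          # fill only the shards that hold needed keys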
# Copyright (c) OpenMMLab. All rights reserved.
import os
from contextlib import contextmanager
from typing import MutableSequence
import torch.nn as nn
from lmdeploy.vl.model.llava import LlavaVisionModel, check_llava_install
_model_path = None
def _build_vision_projector(config, delay_load=False, **kwargs):
"""build yi projector."""
# copy from https://github.com/01-ai/Yi/blob/main/VL/llava/model/multimodal_projector/builder.py # noqa: E501
projector_type = getattr(config, 'mm_projector_type', 'linear')
if projector_type == 'linear':
return nn.Linear(config.mm_hidden_size, config.hidden_size)
import re
use_norm = False
if '_Norm' in projector_type:
use_norm = True
projector_type = projector_type.replace('_Norm', '')
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
if mlp_gelu_match:
mlp_depth = int(mlp_gelu_match.group(1))
if use_norm:
modules = [
nn.Linear(config.mm_hidden_size, config.hidden_size),
nn.LayerNorm(config.hidden_size),
]
else:
modules = [nn.Linear(config.mm_hidden_size, config.hidden_size)]
for _ in range(1, mlp_depth):
modules.append(nn.GELU())
if use_norm:
modules.append(
nn.Linear(config.hidden_size, config.hidden_size))
modules.append(nn.LayerNorm(config.hidden_size))
else:
modules.append(
nn.Linear(config.hidden_size, config.hidden_size))
return nn.Sequential(*modules)
if projector_type == 'identity':
return nn.Identity()
raise ValueError(f'Unknown projector type: {projector_type}')
def _build_vision_tower(vision_tower_cfg, **kwargs):
"""build yi vision tower."""
cfg = vision_tower_cfg
vision_tower = getattr(cfg, 'mm_vision_tower',
getattr(cfg, 'vision_tower', None))
if os.path.exists(os.path.join(_model_path, vision_tower)):
vision_tower = os.path.join(_model_path, vision_tower)
from llava.model.multimodal_encoder.clip_encoder import CLIPVisionTower
is_absolute_path_exists = os.path.exists(vision_tower)
if is_absolute_path_exists or vision_tower.startswith(
'openai') or vision_tower.startswith(
'laion') or 'ShareGPT4V' in vision_tower:
return CLIPVisionTower(vision_tower, args=vision_tower_cfg, **kwargs)
raise ValueError(f'Unknown vision tower: {vision_tower}')
def _set_function(old_func, new_func):
import gc
refs = gc.get_referrers(old_func)
obj_id = id(old_func)
for ref in refs:
if isinstance(ref, dict):
for x, y in ref.items():
if id(y) == obj_id:
ref[x] = new_func
elif isinstance(ref, MutableSequence):
for i, v in enumerate(ref):
if id(v) == obj_id:
ref[i] = new_func
@contextmanager
def init_yi_model():
import llava # noqa: F401
old_projector = eval(
'llava.model.multimodal_projector.builder.build_vision_projector')
_set_function(old_projector, _build_vision_projector)
old_vision_tower = eval(
'llava.model.multimodal_encoder.builder.build_vision_tower')
_set_function(old_vision_tower, _build_vision_tower)
yield
_set_function(_build_vision_projector, old_projector)
_set_function(_build_vision_tower, old_vision_tower)
@contextmanager
def disable_transformers_logging():
import transformers
from transformers.utils import logging
previous_level = logging.get_verbosity()
logging.set_verbosity(transformers.logging.ERROR)
yield
logging.set_verbosity(previous_level)
class YiVisionModel(LlavaVisionModel):
"""Yi visual model."""
def __init__(self, model_path, device='cuda'):
self.model_path = model_path
self.device = device
self.build_model()
def build_model(self):
"""build model & load weights."""
check_llava_install()
global _model_path
_model_path = self.model_path
with init_yi_model(), disable_transformers_logging():
super().build_model()
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
from typing import Dict, List, Tuple, Union
import PIL
from lmdeploy.model import BaseModel
from lmdeploy.utils import get_hf_config_content
from lmdeploy.vl.constants import IMAGE_TOKEN
from lmdeploy.vl.utils import encode_image_base64, load_image
VLPromptType = Union[str, Tuple[str, PIL.Image.Image],
Tuple[str, List[PIL.Image.Image]]]
class VLChatTemplateWrapper:
"""vl chat template wrapper."""
def __init__(self, chat_template: BaseModel):
self.chat_template = chat_template
def prompt_to_messages(self, prompt: VLPromptType):
"""convert prompt to GTP4V format."""
messages = {
'role': 'user',
'content': [{
'type': 'text',
'text': '',
}]
}
if isinstance(prompt, str):
messages['content'][0]['text'] = prompt
else:
prompt, images = prompt
if not isinstance(images, list):
images = [images]
messages['content'][0]['text'] = prompt
for image in images:
if isinstance(image, str):
image = load_image(image)
image_base64_data = encode_image_base64(image)
item = {
'type': 'image_url',
'image_url': {
'url': f'data:image/jpeg;base64,{image_base64_data}'
}
}
messages['content'].append(item)
return [messages]
async def async_collect_pil_images(
self, messages: Dict) -> List[PIL.Image.Image]:
"""collect image from messages."""
images = []
for message in messages:
role = message['role']
content = message['content']
if role != 'user' or isinstance(content, str):
continue
for item in content:
if item['type'] != 'image_url':
continue
url = item['image_url']['url']
images.append(url)
def _inner_call(i, images):
url = images[i]
images[i] = load_image(url)
await asyncio.gather(*[
asyncio.get_event_loop().run_in_executor(
None, _inner_call, i, images) for i in range(len(images))
])
return images
def append_image_token(self, prompt, num_images: int):
"""append image token to user prompt."""
return IMAGE_TOKEN * num_images + '\n' + prompt
def convert_messages(self, messages, sequence_start=True):
"""convert GPT4V message format to GPT4 text format."""
new_messages = []
for message in messages:
role = message['role']
content = message['content']
if role != 'user' or isinstance(content, str):
new_messages.append(message)
continue
num_images = 0
for item in content:
if item['type'] == 'image_url':
num_images += 1
elif item['type'] == 'text':
prompt = item['text']
new_item = {
'role': 'user',
'content': self.append_image_token(prompt, num_images)
}
new_messages.append(new_item)
return new_messages
def messages2prompt(self, messages, sequence_start=True) -> str:
"""convert messages to decorated prompt."""
if isinstance(messages, str):
return self.chat_template.messages2prompt(messages, sequence_start)
new_messages = self.convert_messages(messages, sequence_start)
return self.chat_template.messages2prompt(new_messages, sequence_start)
class LlavaVLChatTemplateWrapper(VLChatTemplateWrapper):
"""Llava vl chat template."""
pass
class YiVLChatTemplateWrapper(VLChatTemplateWrapper):
"""Yi vl chat template."""
pass
class QwenVLChatTemplateWrapper(VLChatTemplateWrapper):
"""Qwen vl chat template."""
def append_image_token(self, prompt, num_images: int):
"""append image tokens to user prompt."""
res = ''
for i in range(num_images):
res += f'Picture {str(i)}:{IMAGE_TOKEN}\n'
res = res + prompt
return res
def get_vl_prompt_template(model_path: str, chat_template: BaseModel,
model_name: str) -> VLChatTemplateWrapper:
"""get vision language prompt template."""
if model_name == 'yi-vl':
return YiVLChatTemplateWrapper(chat_template)
config = get_hf_config_content(model_path)
arch = config['architectures'][0]
if arch == 'QWenLMHeadModel':
return QwenVLChatTemplateWrapper(chat_template)
elif arch == 'LlavaLlamaForCausalLM':
return LlavaVLChatTemplateWrapper(chat_template)
raise ValueError(f'unsupported vl_prompt_template with arch {arch}')
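A sketch of the round trip the wrapper performs, assuming chat_template is a chat template object from lmdeploy.model; the model path, model_name value, and image file are hypothetical (model_name is only checked against 'yi-vl' here):

from PIL import Image

wrapper = get_vl_prompt_template('liuhaotian/llava-v1.5-7b', chat_template, 'llava')
messages = wrapper.prompt_to_messages(('describe this image', Image.open('demo.jpg')))
# messages[0]['content'] holds one text item plus one base64-encoded image_url item
prompt = wrapper.messages2prompt(messages)   # plain-text prompt with IMAGE_TOKEN inserted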
# Copyright (c) OpenMMLab. All rights reserved.
import base64
from io import BytesIO
from typing import Union
import requests
from PIL import Image
def encode_image_base64(image: Image.Image) -> str:
"""encode image to base64 format."""
buffered = BytesIO()
image.save(buffered, format='PNG')
return base64.b64encode(buffered.getvalue()).decode('utf-8')
def load_image_from_base64(image: Union[bytes, str]) -> Image.Image:
"""load image from base64 format."""
return Image.open(BytesIO(base64.b64decode(image)))
def load_image(image_url: str) -> Image.Image:
"""load image from url, local path or openai GPT4V."""
headers = {
'User-Agent':
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 '
'(KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'
}
if image_url.startswith('http'):
response = requests.get(image_url, headers=headers)
response.raise_for_status()
# Open the image using PIL
img = Image.open(BytesIO(response.content))
elif image_url.startswith('data:image'):
img = load_image_from_base64(image_url.split(',')[1])
else:
img = Image.open(image_url)
return img
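These helpers compose; a small round-trip sketch with a hypothetical local file:

from PIL import Image

img = Image.open('demo.jpg')                                    # hypothetical local image
data_uri = f'data:image/png;base64,{encode_image_base64(img)}'  # PNG bytes, see encode_image_base64
restored = load_image(data_uri)                                 # decoded back into a PIL image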