Commit fe851fbc authored by zhouxiang

Supplement newly added files for version 0.2.6

parent e2d98ddc
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed._tensor import DeviceMesh
from ..dist_utils import (colwise_parallelize_linear_fn,
rowwise_parallelize_linear_fn)
from ..kernels import apply_rotary_pos_emb, fill_kv_cache, paged_attention_fwd
class PatchedQwen2Attention(nn.Module):
@classmethod
def _distribute_partition_fn(cls, mod_name: str, mod: nn.Module,
device_mesh: DeviceMesh):
"""Distribution partition callback."""
if mod_name in ['q_proj', 'k_proj', 'v_proj']:
colwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
elif mod_name in ['o_proj']:
rowwise_parallelize_linear_fn(mod,
device_mesh=device_mesh,
to_local=True)
@classmethod
def _distribute_output_fn(cls, outputs, device_mesh: DeviceMesh):
"""Distribution output hook."""
dist.all_reduce(outputs[0])
return outputs
def _contiguous_batching_forward_impl(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
world_size: int = 1,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""Rewrite implementation of forward.
Add continuous batching support. Add paged attention support. TP
support.
"""
context = self.context.context
kv_seq_length = context.kv_seq_length
q_seq_length = context.q_seq_length
q_start_loc = context.q_start_loc
block_offsets = context.block_offsets
max_q_seq_length = context.max_q_seq_length
max_kv_seq_length = context.max_kv_seq_length
num_heads = self.num_heads // world_size
num_kv_heads = self.num_key_value_heads // world_size
head_dim = self.head_dim
hidden_size = num_heads * head_dim
def __qkv_proj(hidden_states):
"""qkv proj."""
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
return query_states, key_states, value_states
def __rotary_emb_fn(query_states, key_states, value_states):
if hasattr(self, 'rotary_emb'):
cos, sin = self.rotary_emb(value_states,
seq_len=max_kv_seq_length)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids,
context.position_ids_1d)
return query_states, key_states, value_states
query_states, key_states, value_states = __qkv_proj(hidden_states)
query_states = query_states.view(-1, num_heads, head_dim)
key_states = key_states.view(-1, num_kv_heads, head_dim)
value_states = value_states.view(-1, num_kv_heads, head_dim)
query_states, key_states, value_states = __rotary_emb_fn(
query_states, key_states, value_states)
fill_kv_cache(
key_states,
value_states,
past_key_value[0],
past_key_value[1],
q_start_loc,
q_seq_length,
kv_seq_length=kv_seq_length,
max_q_seq_length=max_q_seq_length,
block_offsets=block_offsets,
)
attn_output = query_states
use_sliding_windows = (getattr(self.config, 'sliding_window', None)
is not None and self.config.use_sliding_window)
window_size = self.config.sliding_window
if not use_sliding_windows:
window_size = -1
paged_attention_fwd(
query_states,
past_key_value[0],
past_key_value[1],
attn_output,
block_offsets,
q_start_loc=q_start_loc,
q_seqlens=q_seq_length,
kv_seqlens=kv_seq_length,
max_seqlen=max_q_seq_length,
window_size=window_size,
)
attn_output = attn_output.reshape(*hidden_states.shape[:-1],
hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
Optional[Tuple[torch.Tensor]]]:
"""Rewrite of forward."""
world_size = 1
if dist.is_initialized():
world_size = dist.get_world_size()
return self._contiguous_batching_forward_impl(
hidden_states,
position_ids,
past_key_value,
world_size=world_size,
)
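# Illustrative sketch: a minimal, unfused reference of the rotary position
# embedding applied to the query/key states above, assuming the standard
# "rotate-half" formulation used by Llama/Qwen-style models. The fused
# `apply_rotary_pos_emb` kernel imported from `..kernels` is expected to be
# numerically equivalent but operates on the flattened
# (num_tokens, num_heads, head_dim) layout.
def _rotate_half_reference(x: torch.Tensor) -> torch.Tensor:
    """Swap and negate the two halves of the last dimension."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)
def _apply_rotary_reference(q: torch.Tensor, k: torch.Tensor,
                            cos: torch.Tensor, sin: torch.Tensor):
    """Apply rotary embedding: x * cos + rotate_half(x) * sin."""
    q_embed = q * cos + _rotate_half_reference(q) * sin
    k_embed = k * cos + _rotate_half_reference(k) * sin
    return q_embed, k_embed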
# Copyright (c) OpenMMLab. All rights reserved.
from .scheduler import Scheduler
__all__ = ['Scheduler']
# Copyright (c) OpenMMLab. All rights reserved.
from ...config import CacheConfig
from .base_block_manager import BaseBlockManager
from .default_block_manager import DefaultBlockManager
from .window_block_manager import WindowBlockManager
def build_block_manager(cache_config: CacheConfig) -> BaseBlockManager:
"""build block manager.
Args:
cache_config (CacheConfig): cache_config.
"""
num_cpu_blocks = cache_config.num_cpu_blocks
num_gpu_blocks = cache_config.num_gpu_blocks
window_size = cache_config.window_size
if window_size < 0:
return DefaultBlockManager(num_gpu_blocks, num_cpu_blocks)
else:
return WindowBlockManager(num_gpu_blocks,
num_cpu_blocks,
window_size=window_size)
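# Illustrative sketch: a minimal usage of the dispatch above. Constructing a
# real `CacheConfig` is outside the scope of this file, so a plain namespace
# carrying only the three attributes read by `build_block_manager` stands in
# for it (an assumption made purely for illustration).
def _example_build_block_manager():
    from types import SimpleNamespace
    cfg = SimpleNamespace(num_cpu_blocks=16,
                          num_gpu_blocks=64,
                          window_size=-1)
    assert isinstance(build_block_manager(cfg), DefaultBlockManager)
    cfg.window_size = 256
    assert isinstance(build_block_manager(cfg), WindowBlockManager)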
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Union
import numpy as np
from ...adapter.adapter import SchedulerAdapter
from ...messages import SchedulerSequence
class LogicalMemory:
"""Logical memory blocks."""
def __init__(self, num_blocks: int) -> None:
self._num_blocks = num_blocks
self.phy_map: np.ndarray = np.zeros(self._num_blocks, dtype=np.int64)
def get_physical_blocks(self, logical_address: np.ndarray):
"""get physical address."""
if isinstance(logical_address,
np.ndarray) and len(logical_address) == 0:
return np.empty((0, ), dtype=np.int64)
return self.phy_map[logical_address]
def num_blocks(self):
"""get num blocks."""
return self._num_blocks
class PhysicalMemory:
"""physical memory blocks."""
def __init__(self, num_cpu_blocks: int, num_gpu_blocks: int) -> None:
self._num_cpu_blocks = num_cpu_blocks
self._num_gpu_blocks = num_gpu_blocks
self._num_blocks = num_cpu_blocks + num_gpu_blocks
self.ref_count: np.ndarray = np.zeros((self._num_blocks, ),
dtype=np.int64)
self.swap_count: np.ndarray = np.zeros((self._num_blocks, ),
dtype=np.int64)
def num_cpu_blocks(self):
"""get num cpu blocks."""
return self._num_cpu_blocks
def num_gpu_blocks(self):
"""get num gpu blocks."""
return self._num_gpu_blocks
class PhysicalAllocator:
"""The physical block allocator.
The allocator won't allocate real memory. It is used to support block
manager.
"""
def __init__(self,
memory: PhysicalMemory,
num_blocks: int,
offset: int = 0):
self._mem = memory
self._num_blocks = num_blocks
self._offset = offset
self._free_blocks = np.arange(num_blocks, dtype=np.int64) + offset
self._free_count = num_blocks
def allocate(self, num_blocks: int):
"""Allocate block from block pool."""
if self.get_num_free_blocks() >= num_blocks:
num_used = self._num_blocks - self._free_count
blocks = self._free_blocks[num_used:num_used + num_blocks]
self._mem.ref_count.put(blocks, 1)
self._free_count -= num_blocks
return blocks
else:
            raise MemoryError('Not enough free memory blocks.')
def free(self, blocks: np.ndarray):
"""Free block to block pool."""
np.add.at(self._mem.ref_count, blocks, -1)
ref_count = self.get_ref_count(blocks)
freed_blocks = blocks[ref_count == 0]
num_freed_blocks = len(freed_blocks)
if num_freed_blocks > 0:
num_used = self._num_blocks - self._free_count
self._free_blocks[num_used -
num_freed_blocks:num_used] = freed_blocks
self._free_count += num_freed_blocks
return freed_blocks
def get_num_free_blocks(self):
"""Get numbers of free blocks."""
return self._free_count
def get_ref_count(self, blocks: np.ndarray):
"""get ref count."""
return self._mem.ref_count[blocks]
class LogicalAllocator:
"""The logical block allocator."""
def __init__(self, num_cpu_blocks: int, num_gpu_blocks: int) -> None:
self._log_mem = LogicalMemory(num_cpu_blocks + num_gpu_blocks)
self._phy_mem = PhysicalMemory(num_cpu_blocks, num_gpu_blocks)
self._cpu_mem_offset = num_gpu_blocks
self._gpu_allocator = PhysicalAllocator(self._phy_mem, num_gpu_blocks,
0)
self._cpu_allocator = PhysicalAllocator(self._phy_mem, num_cpu_blocks,
self._cpu_mem_offset)
num_blocks = self._log_mem.num_blocks()
self._num_blocks = num_blocks
self._free_blocks = np.arange(num_blocks)
self._free_count = num_blocks
def get_phy_allocator(self, device: str):
"""get allocator."""
if device == 'gpu':
return self._gpu_allocator
elif device == 'cpu':
return self._cpu_allocator
else:
raise ValueError(f'Unsupported device: {device}')
def allocate(self, num_blocks: int, device: str = 'gpu'):
"""allocate logical blocks."""
if num_blocks == 0:
return np.empty((0, ), dtype=np.int64)
phy_allocator = self.get_phy_allocator(device)
logical_enable = self.get_num_free_blocks() >= num_blocks
physical_enable = phy_allocator.get_num_free_blocks() >= num_blocks
if logical_enable and physical_enable:
num_used = self._num_blocks - self._free_count
blocks = self._free_blocks[num_used:num_used + num_blocks]
phy_blocks = phy_allocator.allocate(num_blocks)
self._log_mem.phy_map.put(blocks, phy_blocks)
self._free_count -= num_blocks
return blocks.copy()
else:
            raise MemoryError('Not enough free memory blocks.')
def free(self, blocks: np.ndarray):
"""Free logical block."""
phy_blocks = self.get_physical_blocks(blocks)
cpu_blocks = phy_blocks[phy_blocks >= self._cpu_mem_offset]
gpu_blocks = phy_blocks[phy_blocks < self._cpu_mem_offset]
if len(cpu_blocks) > 0:
self._cpu_allocator.free(cpu_blocks)
if len(gpu_blocks) > 0:
self._gpu_allocator.free(gpu_blocks)
ref_count = self._phy_mem.ref_count[phy_blocks]
freed_blocks = blocks[ref_count == 0]
num_freed_blocks = len(freed_blocks)
if num_freed_blocks > 0:
num_used = self._num_blocks - self._free_count
self._free_blocks[num_used -
num_freed_blocks:num_used] = freed_blocks
self._free_count += num_freed_blocks
def get_num_free_blocks(self):
"""Get numbers of free blocks."""
return self._free_count
def get_physical_blocks(self, blocks: np.ndarray):
"""get physical address."""
return self._log_mem.get_physical_blocks(blocks)
def get_ref_count(self, blocks: np.ndarray):
"""get ref count."""
phy_blocks = self.get_physical_blocks(blocks)
return self._phy_mem.ref_count[phy_blocks]
def add_ref_count(self, blocks: np.ndarray, value: np.ndarray):
"""update ref count."""
phy_blocks = self.get_physical_blocks(blocks)
np.add.at(self._phy_mem.ref_count, phy_blocks, value)
def cpu_mem_offset(self):
"""get cpu mem offset in unified physical memory."""
return self._cpu_mem_offset
def count_cpu_blocks(self, blocks: np.ndarray):
"""count cpu blocks."""
phy_blocks = self.get_physical_blocks(blocks)
return np.count_nonzero(phy_blocks >= self.cpu_mem_offset())
def count_gpu_blocks(self, blocks: np.ndarray):
"""count gpu blocks."""
phy_blocks = self.get_physical_blocks(blocks)
return np.count_nonzero(phy_blocks < self.cpu_mem_offset())
def update_phy_map(self, log_blocks: np.ndarray, phy_blocks: np.ndarray):
"""update physical map."""
assert len(phy_blocks) == len(log_blocks)
self._log_mem.phy_map.put(log_blocks, phy_blocks)
def on_device(self, blocks: np.ndarray, device: str):
"""blocks on given device."""
if len(blocks) == 0:
return False
# TODO: check all blocks
cpu_mem_offset = self.cpu_mem_offset()
phy_blocks = self.get_physical_blocks(blocks[:1])
if phy_blocks[0] < cpu_mem_offset:
phy_device = 'gpu'
else:
phy_device = 'cpu'
return device == phy_device
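# Illustrative sketch: a tiny end-to-end exercise of the allocator above.
# Logical blocks are handed out, mapped to physical GPU blocks, reference
# counted, and finally returned to the free pool. The block counts are
# arbitrary and chosen only for the example.
def _example_logical_allocator():
    allocator = LogicalAllocator(num_cpu_blocks=2, num_gpu_blocks=4)
    blocks = allocator.allocate(2, device='gpu')  # logical block ids
    phy_blocks = allocator.get_physical_blocks(blocks)  # physical block ids
    assert allocator.count_gpu_blocks(blocks) == 2
    assert (allocator.get_ref_count(blocks) == 1).all()
    allocator.free(blocks)  # ref count drops to 0, blocks rejoin the pool
    assert allocator.get_num_free_blocks() == 6
    return phy_blocks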
BlockTable = np.ndarray
class BaseBlockManager:
"""ABC of block manager.
Args:
num_gpu_blocks (int): number of gpu blocks.
num_cpu_blocks (int): number of cpu blocks.
"""
def __init__(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
self.num_gpu_blocks = num_gpu_blocks
self.num_cpu_blocks = num_cpu_blocks
self.allocator = LogicalAllocator(num_cpu_blocks, num_gpu_blocks)
self.block_tables: Dict[int, BlockTable] = {}
@classmethod
def num_required_blocks(cls, obj: Union[SchedulerSequence,
SchedulerAdapter]):
"""get num required blocks."""
raise NotImplementedError('Not implemented.')
@classmethod
def last_block_size(cls, seq: SchedulerSequence) -> int:
"""get last block size."""
raise NotImplementedError('Not implemented.')
def can_allocate(self, msg: SchedulerSequence):
"""Return if physical block can be allocated for given message."""
raise NotImplementedError('Not implemented.')
def allocate_msg(self, msg: SchedulerSequence):
"""Allocate physical blocks for given message according to logical
blocks."""
raise NotImplementedError('Not implemented.')
def allocate_adapter(self, adapter: SchedulerAdapter):
"""Allocate cpu blocks for given adapter."""
raise NotImplementedError('Not implemented.')
def free(self, msg: SchedulerSequence):
"""Free all physical blocks allocated for the session."""
raise NotImplementedError('Not implemented.')
def can_append_slot(self, msg: SchedulerSequence):
"""Return true if the message can append new slot."""
raise NotImplementedError('Not implemented.')
def append_slot(self, msg: SchedulerSequence):
"""Append new slot to message."""
raise NotImplementedError('Not implemented.')
def can_fork(self, from_msg: SchedulerSequence):
"""Return true if blocks can be folked."""
raise NotImplementedError('Not implemented.')
def fork(self, from_msg: SchedulerSequence, to_msg: SchedulerSequence):
"""Fork new message."""
raise NotImplementedError('Not implemented.')
def try_swap_out(self, msg: Union[SchedulerSequence, SchedulerAdapter]):
"""Try swap msg out."""
raise NotImplementedError('Not implemented.')
def try_swap_in(self, msg: Union[SchedulerSequence, SchedulerAdapter]):
"""Try swap msg in."""
raise NotImplementedError('Not implemented.')
def get_block_table(self, msg: Union[SchedulerSequence, SchedulerAdapter]):
"""Get the block table of given msg.
Args:
msg (SchedulerSequence): The msg to get block table.
"""
logical_blocks = msg.logical_blocks
return self.allocator.get_physical_blocks(
logical_blocks.get_real_blocks())
def allocate(self, data: Union[SchedulerSequence, SchedulerAdapter]):
"""allocate stuff."""
if isinstance(data, SchedulerSequence):
return self.allocate_msg(data)
elif isinstance(data, SchedulerAdapter):
return self.allocate_adapter(data)
else:
raise TypeError(f'Unsupported allocate type: {type(data)}')
def get_num_free_gpu_blocks(self) -> int:
"""Get number of free gpu blocks."""
return self.allocator.get_phy_allocator('gpu').get_num_free_blocks()
def get_num_free_cpu_blocks(self) -> int:
"""Get number of free cpu blocks."""
return self.allocator.get_phy_allocator('cpu').get_num_free_blocks()
def on_device(self, msg: SchedulerSequence, device: str):
allocator = self.allocator
logical_blocks = msg.logical_blocks
return allocator.on_device(logical_blocks.get_real_blocks(), device)
# Copyright (c) OpenMMLab. All rights reserved.
# modify from: https://github.com/vllm-project/vllm
from typing import Dict, Union
import numpy as np
from ...adapter.adapter import ADAPTER_MANAGER, SchedulerAdapter
from ...messages import SchedulerSequence
from .base_block_manager import BaseBlockManager
def _div_up(x, n):
"""perform div up."""
return (x + n - 1) // n
BlockTable = np.ndarray
class DefaultBlockManager(BaseBlockManager):
"""Manage the usage of blocks, generate block tables.
Args:
num_gpu_blocks (int): number of gpu blocks.
num_cpu_blocks (int): number of cpu blocks.
"""
@classmethod
def num_required_blocks(cls, obj: Union[SchedulerSequence,
SchedulerAdapter]):
"""get num required blocks."""
if isinstance(obj, SchedulerSequence):
num_tokens = obj.num_all_tokens()
num_all_blocks = _div_up(num_tokens, obj.block_size)
return num_all_blocks - len(obj.logical_blocks)
else:
if obj.is_actived():
return 0
else:
return obj.rank * len(obj.target_modules)
@classmethod
def last_block_size(cls, seq: SchedulerSequence) -> int:
"""get last block size."""
num_blocks = len(seq.logical_blocks)
if num_blocks == 0:
return 0
elif num_blocks * seq.block_size < seq.history_len:
return seq.block_size
return seq.history_len % seq.block_size
def can_allocate(self, msg: SchedulerSequence):
"""Return if physical block can be allocated for given message."""
num_required_blocks = self.num_required_blocks(msg)
num_free_phy = self.get_num_free_gpu_blocks()
if msg.adapter_name is not None:
adapter = ADAPTER_MANAGER.get_adapter(msg.adapter_name)
num_required_blocks += self.num_required_blocks(adapter)
return num_required_blocks <= num_free_phy
def allocate_msg(self, msg: SchedulerSequence):
"""Allocate physical blocks for given message according to logical
blocks."""
logical_blocks = msg.logical_blocks
num_required_blocks = self.num_required_blocks(msg)
if num_required_blocks > 0:
blocks = self.allocator.allocate(num_required_blocks, 'gpu')
logical_blocks.append(blocks)
def allocate_adapter(self, adapter: SchedulerAdapter):
"""Allocate cpu blocks for given adapter."""
num_required_blocks = self.num_required_blocks(adapter)
if num_required_blocks > 0:
blocks = self.allocator.allocate(num_required_blocks, 'cpu')
adapter.logical_blocks.append(blocks)
def free(self, msg: SchedulerSequence):
"""Free all physical blocks allocated for the session."""
self.allocator.free(msg.logical_blocks.get_real_blocks())
msg.logical_blocks.reset()
def can_append_slot(self, msg: SchedulerSequence):
"""Return true if the message can append new slot."""
return self.can_allocate(msg)
def append_slot(self, msg: SchedulerSequence):
"""Append new slot to message."""
return self.allocate(msg)
def can_fork(self, from_msg: SchedulerSequence):
"""Return true if blocks can be folked."""
logical_blocks = from_msg.logical_blocks
if self.last_block_size(from_msg) == from_msg.block_size:
return True
cpu_mem_offset = self.allocator.cpu_mem_offset()
phy_block = self.allocator.get_physical_blocks(logical_blocks[-1])
if phy_block < cpu_mem_offset:
device = 'gpu'
else:
device = 'cpu'
phy_allocator = self.allocator.get_phy_allocator(device)
return phy_allocator.get_num_free_blocks() >= 1
def fork(self, from_msg: SchedulerSequence, to_msg: SchedulerSequence):
"""Fork new message."""
        def _copy_last_block(logical_blocks, copy_map):
cpu_mem_offset = self.allocator.cpu_mem_offset()
phy_block = self.allocator.get_physical_blocks(logical_blocks[-1])
if phy_block < cpu_mem_offset:
device = 'gpu'
else:
device = 'cpu'
block = self.allocator.allocate(1, device)
new_phy_block = self.allocator.get_physical_blocks(block[0])
copy_map[phy_block] = new_phy_block
return block[0]
logical_blocks = from_msg.logical_blocks
copy_map: Dict[int, int] = dict()
if self.last_block_size(from_msg) == from_msg.block_size:
self.allocator.add_ref_count(logical_blocks, 1)
else:
new_logical_blocks = logical_blocks.clone()
self.allocator.add_ref_count(new_logical_blocks[:-1], 1)
            block = _copy_last_block(logical_blocks, copy_map)
new_logical_blocks[-1] = block
to_msg.logical_blocks = new_logical_blocks
return copy_map
def try_swap_out(self, msg: Union[SchedulerSequence, SchedulerAdapter]):
"""Try swap msg out."""
swap_map = dict()
logical_blocks = msg.logical_blocks
cpu_mem_offset = self.allocator.cpu_mem_offset()
phy_blocks = self.allocator.get_physical_blocks(logical_blocks)
cpu_allocator = self.allocator.get_phy_allocator('cpu')
gpu_allocator = self.allocator.get_phy_allocator('gpu')
def _can_swap():
"""check swap."""
if len(logical_blocks) == 0:
return False
            # all blocks of a sequence must be on the same device
if phy_blocks[0] >= cpu_mem_offset:
return False
            # not enough free cpu blocks
num_free = self.get_num_free_cpu_blocks()
if num_free < len(phy_blocks):
return False
            # don't swap sequences with multiple references
ref_count = gpu_allocator.get_ref_count(phy_blocks)
if np.count_nonzero(ref_count != 1) > 0:
return False
return True
def _do_swap():
"""perform swap."""
new_blocks = cpu_allocator.allocate(len(logical_blocks))
old_blocks = phy_blocks
swap_map = dict(zip(old_blocks, new_blocks - self.num_gpu_blocks))
gpu_allocator.free(old_blocks)
self.allocator.update_phy_map(logical_blocks.get_real_blocks(),
new_blocks)
if isinstance(msg, SchedulerAdapter):
msg.active(False)
return True, swap_map
if not _can_swap():
return False, swap_map
else:
return _do_swap()
def try_swap_in(self, msg: Union[SchedulerSequence, SchedulerAdapter]):
"""Try swap msg in."""
swap_map = dict()
logical_blocks = msg.logical_blocks
cpu_mem_offset = self.allocator.cpu_mem_offset()
phy_blocks = self.allocator.get_physical_blocks(logical_blocks)
cpu_allocator = self.allocator.get_phy_allocator('cpu')
gpu_allocator = self.allocator.get_phy_allocator('gpu')
def _can_swap():
"""check swap."""
if len(logical_blocks) == 0:
return False
            # all blocks of a sequence must be on the same device
if phy_blocks[0] < cpu_mem_offset:
return False
            # not enough free gpu blocks
num_free = self.get_num_free_gpu_blocks()
if num_free < len(phy_blocks):
return False
            # don't swap sequences with multiple references
ref_count = cpu_allocator.get_ref_count(phy_blocks)
if np.count_nonzero(ref_count != 1) > 0:
return False
return True
def _do_swap():
"""perform swap."""
new_blocks = gpu_allocator.allocate(len(logical_blocks))
old_blocks = phy_blocks
swap_map = dict(zip(old_blocks - self.num_gpu_blocks, new_blocks))
cpu_allocator.free(old_blocks)
self.allocator.update_phy_map(logical_blocks.get_real_blocks(),
new_blocks)
if isinstance(msg, SchedulerAdapter):
msg.active(True)
return True, swap_map
if not _can_swap():
return False, swap_map
else:
return _do_swap()
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union
import numpy as np
from ...adapter.adapter import ADAPTER_MANAGER, SchedulerAdapter
from ...block import LogicalTokenBlocks
from ...messages import SchedulerSequence
from .default_block_manager import DefaultBlockManager
BlockTable = np.ndarray
def _div_up(x, n):
"""perform div up."""
return (x + n - 1) // n
def _last_block_size(history_len: int, block_size: int):
"""last block size."""
last = history_len % block_size
last = last if last != 0 else block_size
return last
def _num_blocks_to_drop(seq: SchedulerSequence, window_size: int):
"""num blocks to free."""
if seq.history_len <= window_size:
return 0
block_size = seq.block_size
history_len = seq.history_len
num_blocks = len(seq.logical_blocks)
win_start_block_id = (history_len - window_size) // block_size
win_end_block_id = (history_len - 1) // block_size
num_win_blocks = win_end_block_id - win_start_block_id + 1
return max(0, num_blocks - num_win_blocks)
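# Illustrative sketch: a worked example of the drop count above. A real
# `SchedulerSequence` is not constructed here; a namespace carrying only the
# attributes this helper reads (`history_len`, `block_size`, `logical_blocks`)
# stands in for it, purely for illustration.
def _example_num_blocks_to_drop():
    from types import SimpleNamespace
    # 80 cached tokens in 5 blocks of 16; a 32-token window spans the last
    # 2 blocks, so the 3 oldest blocks can be dropped (or reused).
    seq = SimpleNamespace(history_len=80,
                          block_size=16,
                          logical_blocks=[0] * 5)
    assert _num_blocks_to_drop(seq, window_size=32) == 3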
class WindowBlockManager(DefaultBlockManager):
"""Manage the usage of blocks, generate block tables.
Args:
num_gpu_blocks (int): number of gpu blocks.
num_cpu_blocks (int): number of cpu blocks.
"""
def __init__(self, num_gpu_blocks: int, num_cpu_blocks: int,
window_size: int):
super().__init__(num_gpu_blocks, num_cpu_blocks)
assert window_size > 0, ('expect window size > 0, '
f'but get window_size = {window_size}')
self.window_size = window_size
@classmethod
def num_required_blocks(cls, obj: Union[SchedulerSequence,
SchedulerAdapter]):
"""get num required blocks."""
def __num_req_seq(seq: SchedulerSequence):
"""get num required seq blocks."""
block_size = seq.block_size
lb_tokens = cls.last_block_size(seq)
lb_remain_tokens = 0
if len(seq.logical_blocks) > 0:
lb_remain_tokens = block_size - lb_tokens
num_input_tokens = len(seq.token_ids)
num_req_tokens = max(0, num_input_tokens - lb_remain_tokens)
return _div_up(num_req_tokens, block_size)
def __num_req_adapter(adapter: SchedulerAdapter):
"""get num required adapter blocks."""
if adapter.is_actived():
return 0
else:
return adapter.rank * len(adapter.target_modules)
if isinstance(obj, SchedulerSequence):
return __num_req_seq(obj)
else:
return __num_req_adapter(obj)
@classmethod
def last_block_size(cls, seq: SchedulerSequence) -> int:
"""get last block size."""
num_blocks = len(seq.logical_blocks)
if num_blocks == 0:
return 0
return _last_block_size(seq.history_len, seq.block_size)
def can_allocate(self, msg: SchedulerSequence):
"""Return if physical block can be allocated for given message."""
num_drop_blocks = _num_blocks_to_drop(msg, self.window_size)
num_required_blocks = self.num_required_blocks(msg)
num_free_phy = self.get_num_free_gpu_blocks()
if msg.adapter_name is not None:
adapter = ADAPTER_MANAGER.get_adapter(msg.adapter_name)
num_required_blocks += self.num_required_blocks(adapter)
return num_required_blocks <= num_free_phy + num_drop_blocks
def allocate_msg(self, msg: SchedulerSequence):
"""Allocate physical blocks for given message according to logical
blocks."""
logical_blocks = msg.logical_blocks
def __get_droped_blocks(num_drop_blocks):
"""get dropped blocks."""
nonlocal logical_blocks
droped_blocks = None
if num_drop_blocks > 0:
remain_blocks = logical_blocks[num_drop_blocks:]
droped_blocks = logical_blocks[:num_drop_blocks]
logical_blocks = LogicalTokenBlocks(remain_blocks)
msg.logical_blocks = logical_blocks
return droped_blocks
def __reuse_droped_blocks(num_required_blocks, num_drop_blocks,
droped_blocks):
"""reuse dropped blocks."""
num_used_blocks = min(num_drop_blocks - num_required_blocks,
num_required_blocks)
if num_used_blocks > 0:
reused_blocks = droped_blocks[:num_used_blocks]
else:
reused_blocks = droped_blocks
logical_blocks.append(reused_blocks)
if num_used_blocks > 0:
droped_blocks = droped_blocks[num_used_blocks:]
else:
num_used_blocks = num_drop_blocks
droped_blocks = None
num_required_blocks = num_required_blocks - num_used_blocks
return num_required_blocks, droped_blocks
num_drop_blocks = _num_blocks_to_drop(msg, self.window_size)
num_required_blocks = self.num_required_blocks(msg)
droped_blocks = __get_droped_blocks(num_drop_blocks)
if num_required_blocks > 0:
if num_drop_blocks > 0:
num_required_blocks, droped_blocks = __reuse_droped_blocks(
num_required_blocks, num_drop_blocks, droped_blocks)
if num_required_blocks > 0:
blocks = self.allocator.allocate(num_required_blocks, 'gpu')
logical_blocks.append(blocks)
# drop unused blocks
if droped_blocks is not None:
self.allocator.free(droped_blocks)
def allocate_adapter(self, adapter: SchedulerAdapter):
"""Allocate cpu blocks for given adapter."""
num_required_blocks = self.num_required_blocks(adapter)
if num_required_blocks > 0:
blocks = self.allocator.allocate(num_required_blocks, 'cpu')
adapter.logical_blocks.append(blocks)
def free(self, msg: SchedulerSequence):
"""Free all physical blocks allocated for the session."""
self.allocator.free(msg.logical_blocks.get_real_blocks())
msg.logical_blocks.reset()
def can_append_slot(self, msg: SchedulerSequence):
"""Return true if the message can append new slot."""
return self.can_allocate(msg)
def append_slot(self, msg: SchedulerSequence):
"""Append new slot to message."""
return self.allocate(msg)
# Copyright (c) OpenMMLab. All rights reserved.
from .copy_eviction_helper import CopyEvictionHelper
from .recompute_eviction_helper import RecomputeEvictionHelper
__all__ = ['CopyEvictionHelper', 'RecomputeEvictionHelper']
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List
from ...messages import SchedulerSequence
from ..block_manager import BaseBlockManager
SeqList = List[SchedulerSequence]
class BaseEvictionHelper:
"""Base eviction helper."""
def __init__(self, block_manager: BaseBlockManager):
self.block_manager: BaseBlockManager = block_manager
def need_swap_in(self, seq: SchedulerSequence):
"""sequence need swap in."""
raise NotImplementedError('Not implemented.')
def try_swap_out(self, seq: SchedulerSequence, swap_out_map: Dict[int,
int]):
"""try swap out."""
raise NotImplementedError('Not implemented.')
def try_swap_in(self, seq: SchedulerSequence, swap_in_map: Dict[int, int]):
"""try swap in."""
raise NotImplementedError('Not implemented.')
def try_swap_out_seqs(self, seqs: SeqList, swap_out_map: Dict[int, int]):
"""try swap sequence out."""
for seq in reversed(seqs):
if self.try_swap_out(seq, swap_out_map):
return True
return False
def try_swap_out_unused(self, hanging: SeqList, waiting: SeqList,
swap_out_map: Dict[int, int]):
"""try swap out hanging and waiting sequence."""
if self.try_swap_out_seqs(hanging, swap_out_map):
return True
else:
return self.try_swap_out_seqs(waiting, swap_out_map)
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict
from ...messages import SchedulerSequence
from .base_eviction_helper import BaseEvictionHelper
class CopyEvictionHelper(BaseEvictionHelper):
"""Copy to host memory eviction."""
def __init__(self, block_manager):
super().__init__(block_manager)
def need_swap_in(self, seq: SchedulerSequence):
"""sequence need swap in."""
return self.block_manager.on_device(seq, 'cpu')
def try_swap_out(self, seq: SchedulerSequence, swap_out_map: Dict[int,
int]):
"""try swap out."""
success, swap_map = self.block_manager.try_swap_out(seq)
if success:
swap_out_map.update(swap_map)
return success
def try_swap_in(self, seq: SchedulerSequence, swap_in_map: Dict[int, int]):
"""try swap in."""
success, swap_map = self.block_manager.try_swap_in(seq)
if success:
swap_in_map.update(swap_map)
return success
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict
from ...messages import SchedulerSequence
from .base_eviction_helper import BaseEvictionHelper
class RecomputeEvictionHelper(BaseEvictionHelper):
"""recompute eviction."""
def __init__(self, block_manager):
super().__init__(block_manager)
def need_swap_in(self, seq: SchedulerSequence):
"""sequence need swap in."""
return False
def swap_in(self, seq: SchedulerSequence, swap_in_map: Dict[int, int]):
"""sequence swap in."""
self.block_manager.allocate(seq)
def swap_out(self, seq: SchedulerSequence, swap_out_map: Dict[int, int]):
"""sequence swap out."""
self.block_manager.free(seq)
seq.set_step(0)
seq.logical_blocks.reset()
def try_swap_out(self, seq: SchedulerSequence, swap_out_map: Dict[int,
int]):
"""try swap out."""
if seq.history_len > 0:
self.swap_out(seq, swap_out_map)
return True
else:
return False
def try_swap_in(self, seq: SchedulerSequence, swap_in_map: Dict[int, int]):
"""try swap in."""
if self.block_manager.can_allocate(seq):
self.swap_in(seq, swap_in_map)
return True
else:
return False
# Copyright (c) OpenMMLab. All rights reserved.
# modify from: https://github.com/vllm-project/vllm
from collections import OrderedDict
from dataclasses import dataclass
from typing import Dict, List, Set, Union
from lmdeploy.utils import get_logger, logging_timer
from ..adapter.adapter import ADAPTER_MANAGER, SchedulerAdapter
from ..config import CacheConfig, SchedulerConfig
from ..messages import MessageStatus, SchedulerSequence, SchedulerSession
from .block_manager import DefaultBlockManager as BlockManager
from .block_manager import build_block_manager
logger = get_logger('lmdeploy')
SeqList = List[SchedulerSequence]
AdapterList = List[SchedulerAdapter]
def _find_seq_with_session_id(group: SeqList, session_id: int):
return [seq for seq in group if seq.session_id == session_id]
@dataclass
class SchedulerOutput:
"""Output of schedule."""
running: SeqList
swap_in_map: Dict[int, int]
swap_out_map: Dict[int, int]
copy_map: Dict[int, int]
adapters: AdapterList
class Scheduler:
"""Tools to schedule next step.
Args:
scheduler_config (SchedulerConfig): The config of scheduler.
cache_config (CacheConfig): The config of cache info.
"""
def __init__(self, scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> None:
self.scheduler_config = scheduler_config
self.cache_config = cache_config
self.waiting: SeqList = []
self.running: SeqList = []
self.hanging: SeqList = []
self.sessions: Dict[int, SchedulerSession] = OrderedDict()
self.actived_adapters: Set[str] = set()
self.block_manager = build_block_manager(cache_config)
self.eviction_helper = self.build_eviction_helper(
self.scheduler_config.eviction_type, self.block_manager)
    def build_eviction_helper(self, eviction_type: str,
                              block_manager: BlockManager):
if eviction_type == 'copy':
from .eviction_helper import CopyEvictionHelper
return CopyEvictionHelper(block_manager)
elif eviction_type == 'recompute':
from .eviction_helper import RecomputeEvictionHelper
return RecomputeEvictionHelper(block_manager)
else:
raise TypeError(f'Unknown eviction type: {eviction_type}')
def _set_message_status(self, message: SchedulerSequence,
status: MessageStatus):
"""Set status of message.
Args:
message (SchedulerSequence): message to setup status.
status (MessageStatus): New message status.
"""
message.status = status
def add_session(self, session_id: int):
"""Add new session.
Args:
session_id (int): New session id.
"""
assert session_id not in self.sessions
session = SchedulerSession(session_id, self.cache_config.block_size)
self.sessions[session_id] = session
return session
def add_sequence(self, seq: SchedulerSequence):
"""Add sequence.
Args:
seq (SchedulerSequence): New sequence.
"""
assert (seq.session_id
in self.sessions), f'Unknown session id {seq.session_id}'
# push message to waiting queue
self._set_message_status(seq, MessageStatus.WAITING)
self.waiting.append(seq)
def add_adapter(self, adapter_path: str, adapter_name: str):
"""Add adapter.
Args:
adapter_path (str): The path of adapter.
adapter_name (str): The name of the adapter.
"""
adapter = ADAPTER_MANAGER.add_adapter_from_pretrained(
adapter_path, adapter_name=adapter_name)
self.block_manager.allocate_adapter(adapter)
block_table = self.block_manager.get_block_table(
adapter) - self.block_manager.num_gpu_blocks
return adapter.build_weight_map(block_table)
@logging_timer('SchedulePrefilling', logger)
def _schedule_prefill(self):
"""Schedule for prefilling."""
max_batches = self.scheduler_config.max_batches - len(self.running)
block_manager = self.block_manager
eviction_helper = self.eviction_helper
swap_out_map: Dict[int, int] = dict()
swap_in_map: Dict[int, int] = dict()
copy_map: Dict[int, int] = dict()
running: SeqList = []
required_adapters = set(seq.adapter_name for seq in self.running)
max_adapters = self.scheduler_config.max_active_adapters - len(
required_adapters)
def _to_running(seq: SchedulerSequence):
"""to running."""
self._set_message_status(seq, MessageStatus.RUNNING)
running.append(seq)
def _evict_until_can_append(seq: SchedulerSequence):
"""evict until can append."""
while eviction_helper.try_swap_out_unused(self.hanging,
self.waiting[1:],
swap_out_map):
if block_manager.can_append_slot(seq):
return True
return False
def _reorder_waiting():
"""reorder waiting."""
self.waiting = sorted(self.waiting,
key=lambda seq: seq.arrive_time)
def _active_adapter(adapter_name):
"""active adapter of a seq."""
if adapter_name is None:
required_adapters.add(adapter_name)
return
if adapter_name not in required_adapters:
adapter = ADAPTER_MANAGER.get_adapter(adapter_name)
if not adapter.is_actived():
success, tmp_map = self.block_manager.try_swap_in(adapter)
assert success
swap_in_map.update(tmp_map)
required_adapters.add(adapter_name)
def _deactive_adapter(adapter_name):
"""deactive_adapter."""
if adapter_name is None:
return
adapter = ADAPTER_MANAGER.get_adapter(adapter_name)
if adapter.is_actived():
success, tmp_map = self.block_manager.try_swap_out(adapter)
assert success
swap_out_map.update(tmp_map)
if len(running) >= max_batches or len(self.waiting) == 0:
return running, swap_in_map, swap_out_map, copy_map
_reorder_waiting()
while len(self.waiting) > 0 and len(running) < max_batches:
seq = self.waiting[0]
# limit number of adapters
if len(required_adapters) >= max_adapters:
if seq.adapter_name not in required_adapters:
break
if not block_manager.can_allocate(seq):
if not _evict_until_can_append(seq):
break
if eviction_helper.need_swap_in(seq):
if not eviction_helper.try_swap_in(seq, swap_in_map):
break
# allocate session memory
block_manager.allocate(seq)
_active_adapter(seq.adapter_name)
self.waiting.pop(0)
_to_running(seq)
deactive_adapters = self.actived_adapters.difference(required_adapters)
for adapter_name in deactive_adapters:
_deactive_adapter(adapter_name)
self.actived_adapters = required_adapters
self.running += running
return running, swap_in_map, swap_out_map, copy_map
@logging_timer('ScheduleDecoding', logger)
def _schedule_decoding(self):
"""schedule decoding."""
assert len(self.running) != 0
block_manager = self.block_manager
eviction_helper = self.eviction_helper
swap_out_map: Dict[int, int] = dict()
swap_in_map: Dict[int, int] = dict()
copy_map: Dict[int, int] = dict()
running: SeqList = []
def _to_running(seq: SchedulerSequence):
"""to running."""
self._set_message_status(seq, MessageStatus.RUNNING)
running.append(seq)
def _try_append_slot(seq):
"""try append slot."""
if self.block_manager.num_required_blocks(seq) == 0:
_to_running(seq)
return True
if block_manager.can_append_slot(seq):
block_manager.append_slot(seq)
_to_running(seq)
return True
return False
def _evict_until_can_append(seq: SchedulerSequence):
"""evict until can append."""
while eviction_helper.try_swap_out_unused(self.hanging,
self.waiting,
swap_out_map):
if block_manager.can_append_slot(seq):
return True
return False
# 1. running
for seq in self.running:
# token + 1
if len(seq.logical_blocks) > self.block_manager.num_gpu_blocks:
# Reach max gpu cache size.
logger.warning(f'session[{seq.session_id}] '
f'sequence[{seq.seq_id}] '
'reach max gpu size.')
self._set_message_status(seq, MessageStatus.ABORTED)
self.block_manager.free(seq)
if not _try_append_slot(seq):
# try free unused cache from waiting
if _evict_until_can_append(seq):
_try_append_slot(seq)
else:
# move to waiting
self._set_message_status(seq, MessageStatus.WAITING)
self.waiting.insert(0, seq)
self.running = running
return running, swap_in_map, swap_out_map, copy_map
@classmethod
def _get_adapter_list(cls, adapter_names: List[str]):
adapters = [
ADAPTER_MANAGER.get_adapter(name) for name in adapter_names
]
return adapters
def schedule(self, is_prefill: bool):
"""Schedule inputs for next steps."""
if is_prefill:
output = self._schedule_prefill()
else:
output = self._schedule_decoding()
running, swap_in_map, swap_out_map, copy_map = output
adapters = self._get_adapter_list(self.actived_adapters)
return SchedulerOutput(running=running,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map,
copy_map=copy_map,
adapters=adapters)
def _set_session_status(self, session_id: int, status: MessageStatus):
"""Setup the status of session.
Args:
session_id (int): The session id.
status (MessageStatus): New status.
"""
assert session_id in self.sessions
session = self.sessions[session_id]
session.status = status
running_seq = _find_seq_with_session_id(self.running, session_id)
waiting_seq = _find_seq_with_session_id(self.waiting, session_id)
hanging_seq = _find_seq_with_session_id(self.hanging, session_id)
for seq in running_seq + waiting_seq + hanging_seq:
seq.status = status
def stop_session(self, session_id: int):
"""Stop session.
Args:
session_id (int): The session id.
"""
self._set_session_status(session_id, MessageStatus.STOPPED)
def end_session(self, session_id: int):
"""End session.
Args:
session_id (int): The session id.
"""
self._set_session_status(session_id, MessageStatus.ENDED)
def has_unfinished(self):
"""Check if there are any unfinished message."""
return self.waiting or self.running
def has_running(self):
return len(self.running) > 0
def _remove_sequence(self, seq: SchedulerSequence):
"""Remove sequence(unsafe)
Args:
seq (SchedulerSequence): sequence to remove
"""
self.block_manager.free(seq)
seq.session.sequences.pop(seq.seq_id)
def update(self):
"""Update scheduler status after one step.
A full step inference should include:
0. end unused sequence
1. schedule the running sequence
2. forward with the running sequence
3. update scheduler status
"""
seq_to_remove = []
session_id_to_remove = set()
def _update_queue(group: SeqList, expect_status: MessageStatus):
for seq in group:
if seq.status == expect_status:
continue
if seq.status == MessageStatus.WAITING:
self.waiting.append(seq)
if seq.status == MessageStatus.STOPPED:
self.hanging.append(seq)
# remove stopped session
if seq.status == MessageStatus.ENDED:
seq_to_remove.append(seq)
return [seq for seq in group if seq.status == expect_status]
self.running = _update_queue(self.running, MessageStatus.RUNNING)
self.waiting = _update_queue(self.waiting, MessageStatus.WAITING)
self.hanging = _update_queue(self.hanging, MessageStatus.STOPPED)
for session_id, session in self.sessions.items():
if session.status == MessageStatus.ENDED:
session_id_to_remove.add(session_id)
# remove seqs
for seq in seq_to_remove:
self._remove_sequence(seq)
# remove sessions
for session_id in session_id_to_remove:
self.sessions.pop(session_id)
def get_block_tables(self, seqs: Union[SeqList, AdapterList]):
"""get block table of the sequences."""
return [self.block_manager.get_block_table(seq) for seq in seqs]
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os
import random
from lmdeploy.messages import EngineGenerationConfig, PytorchEngineConfig
from lmdeploy.model import MODELS
from lmdeploy.tokenizer import Tokenizer
from . import engine as tm
os.environ['TM_LOG_LEVEL'] = 'ERROR'
class LLM(object):
"""LLM."""
def __init__(self,
model_path: str,
model_name: str,
tp: int = 1,
max_session_len=40000) -> None:
self.tokenizer = Tokenizer(model_path)
self.tm_model = tm.Engine(model_path,
engine_config=PytorchEngineConfig(
tp=tp,
session_len=max_session_len,
max_batch_size=64,
max_prefill_token_num=8192,
),
trust_remote_code=True)
self.generator = self.tm_model.create_instance()
self.model = MODELS.get(model_name)()
seed = random.getrandbits(64)
self.gen_config = EngineGenerationConfig(
max_new_tokens=32,
top_k=40,
top_p=0.8,
temperature=0.8,
repetition_penalty=1.0,
ignore_eos=False,
random_seed=seed,
)
self.session_id = 1
def say(self, question: str):
"""say."""
prompt = self.model.get_prompt(question, True)
input_ids = self.tokenizer.encode(prompt)
_, token_ids, __ = self.generator.infer(session_id=self.session_id,
input_ids=input_ids,
gen_config=self.gen_config)
response = self.tokenizer.decode(token_ids)
self.generator.end(self.session_id)
self.session_id += 1
return response
def tokenize(self, question: str):
"""tokenize."""
prompt = self.model.get_prompt(question, True)
return self.tokenizer.encode(prompt)
def valid_str(string, coding='utf-8'):
"""decode text according to its encoding type."""
invalid_chars = [b'\xef\xbf\xbd']
bstr = bytes(string, coding)
for invalid_char in invalid_chars:
bstr = bstr.replace(invalid_char, b'')
ret = bstr.decode(encoding=coding, errors='ignore')
return ret
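# Illustrative sketch: `valid_str` strips the UTF-8 replacement character
# (U+FFFD, encoded as b'\xef\xbf\xbd') that decoding may emit for incomplete
# or invalid byte sequences.
def _example_valid_str():
    assert valid_str('pass\ufffdkey') == 'passkey'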
def parse_config():
"""parse arguments."""
parser = argparse.ArgumentParser(description='arg parser')
parser.add_argument(
'--model_path',
type=str,
default='/models/openbuddy-llama2-13b-v8.1-fp16',
help='LLM path, use /models/openbuddy-llama2-13b-v8.1-fp16 by default')
parser.add_argument('--model_name',
type=str,
default='llama2',
help='LLM type name, use llama2 by default')
parser.add_argument('--max_tokens',
type=int,
default=50000,
help='maximum token length for evaluation')
parser.add_argument('--interval',
type=int,
default=1024,
help='interval for evaluation')
parser.add_argument('--num_tests',
type=int,
default=1,
help='number of repeat testing for each length')
args = parser.parse_args()
return args
# copy from https://github.com/dvlab-research/LongLoRA/blob/main/passkey_retrivial.py # noqa: E501
def generate_prompt_landmark(n_garbage=60000, seed=666):
"""Generates a text file and inserts an passkey at a random position."""
from numpy import random as nprandom
rnd_state = nprandom.get_state()
nprandom.seed(seed)
n_garbage_prefix = nprandom.randint(0, n_garbage)
n_garbage_suffix = n_garbage - n_garbage_prefix
task_description = 'There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.' # noqa: E501
garbage = 'The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.' # noqa: E501
garbage_num = n_garbage // (len(garbage) + 1) + 1
garbage_inf = ' '.join([garbage] * garbage_num)
assert len(garbage_inf) >= n_garbage
garbage_prefix = garbage_inf[:n_garbage_prefix]
garbage_suffix = garbage_inf[:n_garbage_suffix]
pass_key = nprandom.randint(1, 50000)
information_line = f'The pass key is {pass_key}. Remember it. {pass_key} is the pass key.' # noqa: E501
final_question = 'What is the pass key? The pass key is'
lines = [
task_description,
garbage_prefix,
information_line,
garbage_suffix,
final_question,
]
nprandom.set_state(rnd_state)
return '\n'.join(lines), str(pass_key)
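# Illustrative sketch: the prompt generator above is deterministic for a fixed
# seed, and the pass key is embedded inside the garbage text. The sizes below
# are small and chosen only for the example.
def _example_generate_prompt_landmark():
    prompt, pass_key = generate_prompt_landmark(n_garbage=1000, seed=0)
    assert pass_key in prompt  # the key is hidden somewhere in the text
    again, again_key = generate_prompt_landmark(n_garbage=1000, seed=0)
    assert (prompt, pass_key) == (again, again_key)  # reproducible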
def main(args):
"""main."""
# Load model and tokenizer
llm = LLM(model_path=args.model_path,
model_name=args.model_name,
max_session_len=args.max_tokens)
    all_accuracies = {}
# This is a rough ratio to control the number of texts and tokens
for val in range(4096, args.max_tokens, args.interval):
n_garbage = int(3.75 * val // 1024 * 1024)
assert n_garbage > 0
passed_tests = 0
total_tokens = 0
for j in range(args.num_tests):
question, pass_key = generate_prompt_landmark(n_garbage=n_garbage,
seed=(val + j))
response = llm.say(question)
if pass_key in response:
passed_tests += 1
total_tokens += len(llm.tokenize(question=question))
avg_tokens = total_tokens // args.num_tests
accuracy = passed_tests / args.num_tests
print('accuracy on the token length %d is %f' % (avg_tokens, accuracy))
        all_accuracies[str(avg_tokens)] = accuracy
    print('accuracies over tokens', all_accuracies)
if __name__ == '__main__':
args = parse_config()
main(args)
# Copyright (c) OpenMMLab. All rights reserved.
from transformers import AutoConfig
from lmdeploy.utils import get_logger
logger = get_logger('lmdeploy')
_SUPPORTED_ARCHS = dict(
# baichuan-7b
BaiChuanForCausalLM=False,
# baichuan2-7b, baichuan-13b, baichuan2-13b
BaichuanForCausalLM=True,
# chatglm2-6b, chatglm3-6b
ChatGLMModel=True,
# deepseek-moe
DeepseekForCausalLM=True,
# falcon-7b
FalconForCausalLM=True,
# gemma-7b
GemmaForCausalLM=True,
# internlm
InternLMForCausalLM=True,
# internlm2
InternLM2ForCausalLM=True,
# internlm-xcomposer
InternLMXComposerForCausalLM=False,
# internlm2-xcomposer
InternLM2XComposerForCausalLM=False,
# llama, llama2, alpaca, vicuna, codellama, ultracm, yi,
# deepseek-coder, deepseek-llm
LlamaForCausalLM=True,
# Mistral-7B
MistralForCausalLM=True,
# Mixtral-8x7B
MixtralForCausalLM=True,
# Qwen 7B-72B, Qwen-VL-7B
QWenLMHeadModel=False,
# Qwen1.5 7B-72B
Qwen2ForCausalLM=True,
)
def is_supported(model_path: str):
"""Check whether supported by pytorch engine.
Args:
model_path (str): the path of a model.
It could be one of the following options:
- i) A local directory path of a turbomind model which is
converted by `lmdeploy convert` command or download from
ii) and iii).
- ii) The model_id of a lmdeploy-quantized model hosted
inside a model repo on huggingface.co, such as
"InternLM/internlm-chat-20b-4bit",
"lmdeploy/llama2-chat-70b-4bit", etc.
- iii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "internlm/internlm-chat-7b",
"Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
Returns:
support_by_torch (bool): Whether input model is supported by pytorch engine
""" # noqa: E501
import os
support_by_torch = False
triton_model_path = os.path.join(model_path, 'triton_models')
if os.path.exists(triton_model_path):
        logger.warning(f'{model_path} seems to be a turbomind workspace, '
                       'which can only be run with the turbomind engine.')
else:
cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
if hasattr(cfg, 'architectures'):
arch = cfg.architectures[0]
elif hasattr(cfg,
'auto_map') and 'AutoModelForCausalLM' in cfg.auto_map:
arch = cfg.auto_map['AutoModelForCausalLM'].split('.')[-1]
else:
raise RuntimeError(
f'Could not find model architecture from config: {cfg}')
if arch in _SUPPORTED_ARCHS:
support_by_torch = _SUPPORTED_ARCHS[arch]
# special cases
if arch == 'BaichuanForCausalLM':
# baichuan-13B not supported by pytorch
if cfg.num_attention_heads == 40 and cfg.vocab_size == 64000:
support_by_torch = False
return support_by_torch
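# Illustrative sketch: `is_supported` needs a real model directory because it
# reads the HuggingFace config, but the architecture table can be queried
# directly. For example, Qwen1.5 checkpoints report `Qwen2ForCausalLM` in
# their `config.json`, while the original Qwen reports `QWenLMHeadModel`.
def _example_arch_lookup():
    assert _SUPPORTED_ARCHS.get('Qwen2ForCausalLM', False) is True
    assert _SUPPORTED_ARCHS.get('QWenLMHeadModel', False) is False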
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import torch
def continuous_tensor(inputs: torch.Tensor, seq_length: torch.LongTensor):
"""Convert batched tensor to continuous tensor.
Args:
inputs (Tensor): batched tensor.
seq_length (Tensor): length of each sequence.
Return:
        Tensor: continuous tensor.
"""
assert inputs.dim() > 1
if inputs.size(1) == 1:
return inputs.reshape(1, -1)
inputs = [inp[:slen] for inp, slen in zip(inputs, seq_length)]
inputs = torch.cat(inputs).unsqueeze(0)
return inputs
def batch_tensor(inputs: torch.Tensor, seq_length: torch.LongTensor):
"""Convert continuoused tensor to batched tensor.
Args:
        inputs (Tensor): continuous tensor.
seq_length (Tensor): length of each sequence.
Return:
Tensor: batched tensor.
"""
from torch.nn.utils.rnn import pad_sequence
end_loc = seq_length.cumsum(0)
start_loc = end_loc - seq_length
inputs = [inputs[0, sloc:eloc] for sloc, eloc in zip(start_loc, end_loc)]
inputs = pad_sequence(inputs, batch_first=True)
return inputs
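# Illustrative sketch: a round trip through the two layout helpers above. A
# right-padded batch is flattened into a single continuous sequence and then
# restored; the values are arbitrary, and the padding must be zero for the
# restored batch to compare equal (pad_sequence pads with zeros).
def _example_layout_round_trip():
    batched = torch.tensor([[1, 2, 3], [4, 5, 0]])  # batch of 2 sequences
    seq_length = torch.tensor([3, 2])
    flat = continuous_tensor(batched, seq_length)  # shape (1, 5)
    assert flat.tolist() == [[1, 2, 3, 4, 5]]
    restored = batch_tensor(flat, seq_length)  # shape (2, 3)
    assert torch.equal(restored, batched)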
def page_cache(paged_cache: torch.Tensor,
batched_cache: torch.Tensor,
cache_length: torch.Tensor,
block_offsets: torch.Tensor,
permute_head: bool = True):
"""Convert batched cache to paged cache.
Args:
paged_cache (Tensor): Output paged cache.
batched_cache (Tensor): Input batched cache.
cache_length (Tensor): length of the cache.
        block_offsets (Tensor): Offset of each block.
        permute_head (bool): Whether the batched cache is laid out as
            (batch, heads, seq, dim) and needs permuting to
            (batch, seq, heads, dim) before paging.
    """
assert block_offsets.dim() == 2
block_size = paged_cache.size(1)
batch_size = batched_cache.size(0)
if permute_head:
batched_cache = batched_cache.permute(0, 2, 1, 3)
for b_idx in range(batch_size):
cache_len = cache_length[b_idx]
b_cache = batched_cache[b_idx]
block_off = block_offsets[b_idx]
block_off_idx = 0
for s_start in range(0, cache_len, block_size):
s_end = min(s_start + block_size, cache_len)
s_len = s_end - s_start
b_off = block_off[block_off_idx]
paged_cache[b_off, :s_len] = b_cache[s_start:s_end]
block_off_idx += 1
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, List, Tuple
import torch
from .layout_convert import continuous_tensor, page_cache
def make_model_inputs(input_ids: torch.Tensor,
block_offsets: torch.Tensor,
seq_length: torch.Tensor = None,
history_length: List[int] = None):
"""make model inputs."""
from lmdeploy.pytorch.engine.model_agent import ModelInputs
batch_size = input_ids.size(0)
max_seq_len = input_ids.size(1)
if seq_length is None:
max_seq_len = input_ids.size(1)
seq_length = torch.full((batch_size, ), max_seq_len)
input_ids = continuous_tensor(input_ids, seq_length)
if history_length is None:
history_length = [0] * batch_size
else:
assert len(history_length) == len(seq_length)
is_decoding = input_ids.size(0) == batch_size
q_start_loc = seq_length.cumsum(0) - seq_length
mask_range = torch.arange(max_seq_len)[None, :]
attention_mask = (mask_range < seq_length[:, None]).long()
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids += position_ids.new_tensor(history_length).unsqueeze(-1)
if isinstance(history_length, torch.Tensor):
history_length = history_length.tolist()
return ModelInputs(input_ids=input_ids,
seq_length=seq_length,
attention_mask=attention_mask,
block_offsets=block_offsets,
position_ids=position_ids,
q_start_loc=q_start_loc,
history_lengths=history_length,
is_decoding=is_decoding)
def make_step_context(
input_ids: torch.Tensor,
seq_length: torch.Tensor = None,
history_length: List[int] = None,
past_key_values: List[Tuple] = None,
world_size: int = 1,
device: str = 'cuda',
block_size: int = 64,
num_key_value_heads: int = 32,
head_size: int = 128,
kv_cache_dtype: torch.dtype = torch.float16,
json_config: Any = None,
):
"""make step context."""
from torch.nn.utils.rnn import pad_sequence
from lmdeploy.pytorch.engine.model_agent import StepContext
batch_size = input_ids.size(0)
max_seq_len = input_ids.size(1)
if seq_length is None:
max_seq_len = input_ids.size(1)
seq_length = torch.full((batch_size, ), max_seq_len)
if history_length is None:
history_length = [0] * batch_size
else:
assert len(history_length) == len(seq_length)
history_length = torch.tensor(history_length)
def __create_kv_caches(past_key_values):
"""create kv caches."""
total_length = seq_length + history_length
num_blocks_per_seq = (total_length + block_size - 1) // block_size
num_blocks = sum(num_blocks_per_seq)
num_caches = 1 if past_key_values is None else len(past_key_values)
cache_shape = [num_blocks, block_size, num_key_value_heads, head_size]
block_offsets_1d = torch.arange(0, num_blocks)
block_end_loc = num_blocks_per_seq.cumsum(0)
block_start_loc = block_end_loc - num_blocks_per_seq
block_offsets = [
block_offsets_1d[sloc:eloc]
for sloc, eloc in zip(block_start_loc, block_end_loc)
]
block_offsets = pad_sequence(block_offsets, batch_first=True)
kv_caches = []
for _ in range(num_caches):
k_cache = torch.empty(cache_shape,
dtype=kv_cache_dtype,
device=device)
v_cache = torch.empty_like(k_cache)
kv_caches.append((k_cache, v_cache))
return kv_caches, block_offsets
def __fill_kv_caches(kv_caches, past_key_values, block_offsets):
"""fill kv caches."""
if past_key_values is None:
return
if all(hlen == 0 for hlen in history_length):
return
num_layers = len(past_key_values)
for layer_idx in range(num_layers):
k_cache, v_cache = kv_caches[layer_idx]
past_k, past_v = past_key_values[layer_idx]
page_cache(k_cache, past_k, history_length, block_offsets)
page_cache(v_cache, past_v, history_length, block_offsets)
kv_caches, block_offsets = __create_kv_caches(past_key_values)
__fill_kv_caches(kv_caches, past_key_values, block_offsets)
history_length = history_length.tolist()
model_inputs = make_model_inputs(input_ids,
block_offsets=block_offsets,
seq_length=seq_length,
history_length=history_length)
model_inputs = model_inputs.to_device(device)
return StepContext.new(
inputs=model_inputs,
world_size=world_size,
device=device,
json_config=json_config,
kv_caches=kv_caches,
)
class ModuleIOExtractor:
"""Extract input and output of target sub module."""
def __init__(self, model: torch.nn.Module, target_module: torch.nn.Module):
def __check_target_exist():
for mod in model.modules():
if mod == target_module:
return True
return False
if not __check_target_exist():
raise RuntimeError(f'{type(target_module)} is not a sub module'
f' of {type(model)}')
self._model = model
self._target_module = target_module
def extract(self, *args, **kwargs):
"""extract."""
target_args = None
target_kwargs = None
target_output = None
def __forward_hook(module, args, kwargs, output):
"""hook."""
nonlocal target_args, target_kwargs, target_output
target_args = args
target_kwargs = kwargs
target_output = output
handle = self._target_module.register_forward_hook(__forward_hook,
with_kwargs=True)
self._model(*args, **kwargs)
handle.remove()
return target_args, target_kwargs, target_output
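# Illustrative sketch: minimal usage of `ModuleIOExtractor` on a toy model,
# capturing the positional inputs and the output of one submodule during a
# normal forward pass. The model below is an assumption made only for the
# example; `register_forward_hook(..., with_kwargs=True)` requires
# PyTorch >= 2.0.
def _example_module_io_extractor():
    model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU())
    extractor = ModuleIOExtractor(model, model[0])
    inputs = torch.randn(2, 4)
    target_args, target_kwargs, target_output = extractor.extract(inputs)
    assert torch.equal(target_args[0], inputs)
    assert target_output.shape == (2, 8)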
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import time
from dataclasses import dataclass, field
from itertools import count
from typing import List, Literal, Optional, Tuple, Union
import gradio as gr
from packaging.version import Version, parse
from PIL import Image
from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig,
TurbomindEngineConfig)
from lmdeploy.model import ChatTemplateConfig
from lmdeploy.pytorch.engine.request import _run_until_complete
from lmdeploy.serve.gradio.constants import CSS, THEME, disable_btn, enable_btn
from lmdeploy.tokenizer import DetokenizeState
from lmdeploy.utils import get_logger
BATCH_SIZE = 32
logger = get_logger('lmdeploy')
if parse(gr.__version__) >= Version('4.0.0'):
que_kwargs = {'default_concurrency_limit': BATCH_SIZE}
else:
que_kwargs = {'concurrency_count': BATCH_SIZE}
@dataclass
class Session:
"""chat session.
Args:
_session_id (int): session_id for internal use.
_message (List[Tuple[Any, str]]): chat history for internal use.
_step (int): the offset of the k/v cache for internal use.
"""
_count = count()
_session_id: int = None
_message: List[Tuple[str, str]] = field(default_factory=list)
_step: int = 0
def __init__(self):
self._session_id = next(self._count)
self._message = []
self._step = 0
@property
def session_id(self):
return self._session_id
@property
def message(self):
return self._message
@property
def step(self):
return self._step
def preprocess(engine, prompt, sequence_start: bool):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
inputs = loop.run_until_complete(
engine._get_prompt_input(prompt, True, sequence_start=sequence_start))
return inputs
def run_local(model_path: str,
model_name: Optional[str] = None,
backend: Literal['turbomind', 'pytorch'] = 'turbomind',
backend_config: Optional[Union[PytorchEngineConfig,
TurbomindEngineConfig]] = None,
chat_template_config: Optional[ChatTemplateConfig] = None,
server_name: str = '0.0.0.0',
server_port: int = 6006,
tp: int = 1,
**kwargs):
from lmdeploy.serve.vl_async_engine import VLAsyncEngine
engine = VLAsyncEngine(model_path=model_path,
model_name=model_name,
backend=backend,
backend_config=backend_config,
chat_template_config=chat_template_config,
tp=tp,
**kwargs)
def add_image(chatbot, session, file):
"""Append image to query."""
chatbot = chatbot + [((file.name, ), None)]
history = session._message
img = Image.open(file.name).convert('RGB')
# [([user, img, img], assistant), ...]
if len(history) == 0 or history[-1][-1] is not None:
history.append([[img], None])
else:
history[-1][0].append(img)
return chatbot, session
def add_text(chatbot, session, text):
"""User query."""
chatbot = chatbot + [(text, None)]
history = session._message
if len(history) == 0 or history[-1][-1] is not None:
history.append([text, None])
else:
history[-1][0].insert(0, text)
return chatbot, session, disable_btn, enable_btn
def chat(chatbot, session, max_new_tokens, top_p, top_k, temperature):
"""Chat with AI assistant."""
generator = engine.engine.create_instance()
history = session._message
sequence_start = len(history) == 1
if isinstance(history[-1][0], str):
prompt = history[-1][0]
else:
prompt = history[-1][0][0]
images = history[-1][0][1:]
prompt = (prompt, images)
logger.info('prompt: ' + str(prompt))
prompt = engine.vl_prompt_template.prompt_to_messages(prompt)
t0 = time.perf_counter()
inputs = _run_until_complete(
engine._get_prompt_input(prompt,
True,
sequence_start=sequence_start))
t1 = time.perf_counter()
logger.info('preprocess cost %.3fs' % (t1 - t0))
input_ids = inputs['input_ids']
logger.info('input_ids: ' + str(input_ids))
if len(input_ids) + session.step + max_new_tokens > engine.session_len:
            gr.Warning('WARNING: the session exceeds the max length.'
                       ' Please restart the session via the Reset button.')
yield chatbot, session, enable_btn, disable_btn, enable_btn
else:
gen_config = GenerationConfig(max_new_tokens=max_new_tokens,
top_p=top_p,
top_k=top_k,
temperature=temperature)
step = session.step
state = DetokenizeState()
for outputs in generator.stream_infer(
session_id=session._session_id,
**inputs,
sequence_start=sequence_start,
step=step,
gen_config=gen_config,
stream_output=True):
_, res, tokens = outputs
response, state = engine.tokenizer.detokenize_incrementally(
res,
state,
skip_special_tokens=gen_config.skip_special_tokens)
if chatbot[-1][1] is None:
chatbot[-1][1] = ''
history[-1][1] = ''
chatbot[-1][1] += response
history[-1][1] += response
session._step = step + len(input_ids) + tokens
yield chatbot, session, disable_btn, enable_btn, disable_btn
yield chatbot, session, enable_btn, disable_btn, enable_btn
def stop(session):
"""Stop the session."""
generator = engine.engine.create_instance()
for _ in generator.stream_infer(session_id=session.session_id,
input_ids=[0],
request_output_len=0,
sequence_start=False,
sequence_end=False,
stop=True):
pass
def cancel(chatbot, session):
"""Stop the session and keey chat history."""
stop(session)
return chatbot, session, disable_btn, enable_btn, enable_btn
def reset(session):
"""Reset a new session."""
stop(session)
session._step = 0
session._message = []
return [], session, enable_btn
with gr.Blocks(css=CSS, theme=THEME) as demo:
with gr.Column(elem_id='container'):
gr.Markdown('## LMDeploy VL Playground')
chatbot = gr.Chatbot(elem_id='chatbot', label=engine.model_name)
query = gr.Textbox(placeholder='Please input the instruction',
label='Instruction')
session = gr.State()
with gr.Row():
addimg_btn = gr.UploadButton('Upload Image',
file_types=['image'])
cancel_btn = gr.Button(value='Cancel', interactive=False)
reset_btn = gr.Button(value='Reset')
with gr.Row():
max_new_tokens = gr.Slider(1,
2048,
value=512,
step=1,
label='Maximum new tokens')
top_p = gr.Slider(0.01, 1, value=0.8, step=0.01, label='Top_p')
top_k = gr.Slider(1, 100, value=50, step=1, label='Top_k')
temperature = gr.Slider(0.01,
1.5,
value=0.7,
step=0.01,
label='Temperature')
addimg_btn.upload(add_image, [chatbot, session, addimg_btn],
[chatbot, session],
show_progress=True,
queue=True)
send_event = query.submit(
            add_text, [chatbot, session, query],
            [chatbot, session, query, cancel_btn]).then(
chat,
[chatbot, session, max_new_tokens, top_p, top_k, temperature],
[chatbot, session, query, cancel_btn, reset_btn])
query.submit(lambda: gr.update(value=''), None, [query])
cancel_btn.click(cancel, [chatbot, session],
[chatbot, session, cancel_btn, reset_btn, query],
cancels=[send_event])
reset_btn.click(reset, [session], [chatbot, session, query],
cancels=[send_event])
demo.load(lambda: Session(), inputs=None, outputs=[session])
demo.queue(api_open=True, **que_kwargs, max_size=100)
demo.launch(
share=True,
server_port=server_port,
server_name=server_name,
)
if __name__ == '__main__':
import fire
fire.Fire(run_local)
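# Usage sketch (the model path is a placeholder): besides the `fire` CLI entry
# point above, the playground can also be launched programmatically, e.g.
#
#   from lmdeploy.messages import TurbomindEngineConfig
#   run_local('/path/to/a/vision-language-model',
#             backend_config=TurbomindEngineConfig(session_len=8192),
#             server_name='0.0.0.0',
#             server_port=6006)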
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import enum
LATENCY_DEEQUE_LEN = 15
API_TIMEOUT_LEN = 100
class Strategy(enum.Enum):
"""Strategy to dispatch requests to nodes."""
RANDOM = enum.auto()
MIN_EXPECTED_LATENCY = enum.auto()
MIN_OBSERVED_LATENCY = enum.auto()
@classmethod
def from_str(cls, name):
"""get strategy from string."""
if name == 'random':
return cls.RANDOM
elif name == 'min_expected_latency':
return cls.MIN_EXPECTED_LATENCY
elif name == 'min_observed_latency':
return cls.MIN_OBSERVED_LATENCY
else:
raise ValueError(f'Invalid strategy: {name}. Supported: random, '
f'min_expected_latency, min_observed_latency.')
class ErrorCodes(enum.Enum):
"""Error codes."""
MODEL_NOT_FOUND = 10400
SERVICE_UNAVAILABLE = 10401
API_TIMEOUT = 10402
err_msg = {
    ErrorCodes.MODEL_NOT_FOUND:
    'The requested model name does not exist in the model list.',
    ErrorCodes.SERVICE_UNAVAILABLE:
    'The service is unavailable now. Please retry later.',
    ErrorCodes.API_TIMEOUT:
    'Failed to get a response from the node within the timeout period.'
}
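# Minimal sketch of how these constants are consumed by the proxy: strategy
# strings from the CLI resolve to enum members, and error payloads pair a code
# with its message, e.g.
#
#   assert Strategy.from_str('min_observed_latency') is Strategy.MIN_OBSERVED_LATENCY
#   payload = {'error_code': ErrorCodes.API_TIMEOUT,
#              'text': err_msg[ErrorCodes.API_TIMEOUT]}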
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import json
import os
import os.path as osp
import random
import time
from collections import deque
from http import HTTPStatus
from typing import Deque, Dict, List, Literal, Optional, Union
import numpy as np
import requests
import uvicorn
import yaml
from fastapi import BackgroundTasks, Depends, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
from lmdeploy.serve.openai.api_server import (check_api_key,
create_error_response)
from lmdeploy.serve.openai.protocol import ( # noqa: E501
ChatCompletionRequest, CompletionRequest, ModelCard, ModelList,
ModelPermission)
from lmdeploy.serve.proxy.constants import (API_TIMEOUT_LEN,
LATENCY_DEEQUE_LEN, ErrorCodes,
Strategy, err_msg)
from lmdeploy.utils import get_logger
logger = get_logger('lmdeploy')
class Status(BaseModel):
"""Status protocol consists of models' information."""
models: Optional[List[str]] = Field(default=[], examples=[[]])
unfinished: int = 0
latency: Deque = Field(default=deque(maxlen=LATENCY_DEEQUE_LEN),
examples=[[]])
speed: Optional[int] = Field(default=None, examples=[None])
class Node(BaseModel):
"""Node protocol consists of url and status."""
url: str
status: Optional[Status] = None
class NodeManager:
"""Manage all the sub nodes.
Args:
config_path (str): the path of the config file.
strategy (str): the strategy to dispatch node to handle the requests.
            - random: not fully random; the probability of picking a node
              is weighted by its reported speed.
            - min_expected_latency: estimate each node's latency from its
              number of unfinished requests and its speed, and dispatch to
              the node with the smallest estimate.
            - min_observed_latency: rank nodes by the average latency of
              their previously finished requests; the node with the lowest
              observed latency gets the next request.
"""
def __init__(self,
config_path: Optional[str] = None,
strategy: str = 'min_expected_latency') -> None:
self.nodes = dict()
self.strategy = Strategy.from_str(strategy)
self.latencies = dict()
self.config_path = osp.join(osp.dirname(osp.realpath(__file__)),
'proxy_config.yml')
if config_path is not None:
self.config_path = config_path
if osp.exists(self.config_path):
with open(self.config_path, 'r') as config_file:
self.nodes = yaml.safe_load(config_file)['nodes']
for url, status in self.nodes.items():
status = Status(**status)
self.nodes[url] = status
def update_config_file(self):
"""Update the config file."""
nodes = copy.deepcopy(self.nodes)
for url, status in nodes.items():
nodes[url] = status.model_dump()
nodes[url]['latency'] = list(status.latency)
with open(self.config_path, 'w') as config_file: # update cfg yml
yaml.dump(dict(nodes=nodes), config_file)
def add(self, node_url: str, status: Optional[Status] = None):
"""Add a node to the manager.
Args:
node_url (str): A http url. Can be the url generated by
`lmdeploy serve api_server`.
            status (Status): The status of the node. An example:
                Status(models=['internlm-chat-7b'], speed=1). The speed can
                be RPM or any other metric, as long as all nodes use the
                same metric.
"""
if status is None:
status = self.nodes.get(node_url, Status())
try:
from lmdeploy.serve.openai.api_client import APIClient
client = APIClient(api_server_url=node_url)
status.models = client.available_models
self.nodes[node_url] = status
except requests.exceptions.RequestException as e: # noqa
return self.handle_api_timeout(node_url)
self.update_config_file()
def remove(self, node_url: str):
"""Remove a node."""
if node_url in self.nodes.keys():
self.nodes.pop(node_url)
self.update_config_file()
@property
def model_list(self):
"""Supported model list."""
model_names = []
for node_url, node_status in self.nodes.items():
model_names.extend(node_status.models)
return model_names
@property
def status(self):
"""Return the status."""
return self.nodes
def get_node_url(self, model_name: str):
"""Add a node to the manager.
Args:
model_name (str): A http url. Can be the url generated by
`lmdeploy serve api_server`.
Return:
A node url or None.
"""
if self.strategy == Strategy.RANDOM:
urls_with_speeds, speeds, urls_without_speeds = [], [], []
for node_url, node_status in self.nodes.items():
if model_name in node_status.models:
if node_status.speed is not None:
urls_with_speeds.append(node_url)
speeds.append(node_status.speed)
else:
urls_without_speeds.append(node_url)
all_matched_urls = urls_with_speeds + urls_without_speeds
if len(all_matched_urls) == 0:
return None
            # some nodes do not report a speed;
            # assign them the average speed of the others
average_speed = sum(speeds) / len(speeds) if len(speeds) else 1
all_the_speeds = speeds + [average_speed
] * len(urls_without_speeds)
speed_sum = sum(all_the_speeds)
weights = [speed / speed_sum for speed in all_the_speeds]
index = random.choices(range(len(all_matched_urls)),
weights=weights)[0]
url = all_matched_urls[index]
return url
elif self.strategy == Strategy.MIN_EXPECTED_LATENCY:
urls_with_speeds, speeds, urls_without_speeds = [], [], []
for node_url, node_status in self.nodes.items():
if model_name in node_status.models:
if node_status.speed is not None:
urls_with_speeds.append(node_url)
speeds.append(node_status.speed)
else:
urls_without_speeds.append(node_url)
all_matched_urls = urls_with_speeds + urls_without_speeds
if len(all_matched_urls) == 0:
return None
            # some nodes do not report a speed;
            # assign them the average speed of the others
average_speed = sum(speeds) / len(speeds) if len(speeds) else 1
all_the_speeds = speeds + [average_speed
] * len(urls_without_speeds)
min_latency = float('inf')
min_index = 0
for index, speed in enumerate(all_the_speeds):
latency = self.nodes[
all_matched_urls[index]].unfinished / speed
                if latency < min_latency:
min_latency = latency
min_index = index
url = all_matched_urls[min_index]
return url
elif self.strategy == Strategy.MIN_OBSERVED_LATENCY:
all_matched_urls, latencies = [], []
for node_url, node_status in self.nodes.items():
if model_name in node_status.models:
if len(node_status.latency):
latencies.append(np.mean(np.array(
node_status.latency)))
else:
latencies.append(float('inf'))
all_matched_urls.append(node_url)
if len(all_matched_urls) == 0:
return None
index = np.argmin(np.array(latencies))
return all_matched_urls[index]
else:
raise ValueError(f'Invalid strategy: {self.strategy}')
async def check_request_model(self, model_name) -> Optional[JSONResponse]:
"""Check if a request is valid."""
if model_name in self.model_list:
return
ret = create_error_response(
HTTPStatus.NOT_FOUND, f'The model `{model_name}` does not exist.')
return ret
def handle_unavailable_model(self, model_name):
"""Handle unavailable model.
Args:
model_name (str): the model in the request.
"""
        logger.info(f'no node serving the requested model: {model_name}')
ret = {
'error_code': ErrorCodes.MODEL_NOT_FOUND,
'text': err_msg[ErrorCodes.MODEL_NOT_FOUND],
}
return json.dumps(ret).encode() + b'\n'
def handle_api_timeout(self, node_url):
"""Handle the api time out."""
logger.info(f'api timeout: {node_url}')
ret = {
'error_code': ErrorCodes.API_TIMEOUT,
'text': err_msg[ErrorCodes.API_TIMEOUT],
}
return json.dumps(ret).encode() + b'\n'
def stream_generate(self, request: Dict, node_url: str, node_path: str):
"""Return a generator to handle the input request.
Args:
request (Dict): the input request.
node_url (str): the node url.
node_path (str): the node path. Such as `/v1/chat/completions`.
"""
try:
response = requests.post(
node_url + node_path,
json=request,
stream=request['stream'],
timeout=API_TIMEOUT_LEN,
)
for chunk in response.iter_lines(decode_unicode=False,
delimiter=b'\n'):
if chunk:
yield chunk + b'\n'
except requests.exceptions.RequestException as e: # noqa
yield self.handle_api_timeout(node_url)
async def generate(self, request: Dict, node_url: str, node_path: str):
"""Return a the response of the input request.
Args:
request (Dict): the input request.
node_url (str): the node url.
node_path (str): the node path. Such as `/v1/chat/completions`.
"""
try:
import httpx
async with httpx.AsyncClient() as client:
response = await client.post(node_url + node_path,
json=request,
timeout=API_TIMEOUT_LEN)
return response.text
except requests.exceptions.RequestException as e: # noqa
return self.handle_api_timeout(node_url)
def pre_call(self, node_url):
"""Preprocess before the request get processed.
Args:
node_url (str): the node url.
"""
self.nodes[node_url].unfinished += 1
return time.time()
    def post_call(self, node_url: str, start: float):
        """Post process after the response finished.
        Args:
            node_url (str): the node url.
            start (float): the start timestamp returned by `pre_call`.
"""
self.nodes[node_url].unfinished -= 1
self.nodes[node_url].latency.append(time.time() - start)
    def create_background_tasks(self, url: str, start: float):
        """Create a background task that records latency after the response.
        Args:
            url (str): the node url.
            start (float): the start timestamp returned by `pre_call`.
"""
background_tasks = BackgroundTasks()
background_tasks.add_task(self.post_call, url, start)
return background_tasks
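# Dispatch sketch for NodeManager.get_node_url: with two nodes serving the same
# model at speeds 10 and 30 (any consistent unit such as RPM), the RANDOM
# strategy picks them with probabilities 10/40 = 0.25 and 30/40 = 0.75, while
# MIN_EXPECTED_LATENCY divides each node's unfinished request count by its
# speed and routes to the smallest quotient. Nodes that report no speed are
# assigned the average speed of the reporting nodes before weighting.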
app = FastAPI(docs_url='/')
app.add_middleware(
CORSMiddleware,
allow_origins=['*'],
allow_credentials=True,
allow_methods=['*'],
allow_headers=['*'],
)
node_manager = NodeManager()
@app.get('/v1/models', dependencies=[Depends(check_api_key)])
def available_models():
"""Show available models."""
model_cards = []
for model_name in node_manager.model_list:
model_cards.append(
ModelCard(id=model_name,
root=model_name,
permission=[ModelPermission()]))
return ModelList(data=model_cards)
@app.get('/nodes/status', dependencies=[Depends(check_api_key)])
def node_status():
"""Show nodes status."""
try:
return node_manager.status
except: # noqa
return False
@app.post('/nodes/add', dependencies=[Depends(check_api_key)])
def add_node(node: Node, raw_request: Request = None):
"""Add a node to the manager.
- url (str): A http url. Can be the url generated by
`lmdeploy serve api_server`.
    - status (Dict): The description of the node. An example:
        {'models': ['internlm-chat-7b'], 'speed': 1}. The speed here can be
        RPM or any other metric, as long as all nodes use the same metric.
"""
try:
node_manager.add(node.url, node.status)
return 'Added successfully'
except: # noqa
return 'Failed to add, please check the input url.'
@app.post('/nodes/remove', dependencies=[Depends(check_api_key)])
def remove_node(node_url: str):
"""Show available models."""
try:
node_manager.remove(node_url)
return 'Deleted successfully'
except: # noqa
return 'Failed to delete, please check the input url.'
@app.post('/v1/chat/completions', dependencies=[Depends(check_api_key)])
async def chat_completions_v1(request: ChatCompletionRequest,
raw_request: Request = None):
"""Completion API similar to OpenAI's API.
Refer to `https://platform.openai.com/docs/api-reference/chat/create`
for the API specification.
The request should be a JSON object with the following fields:
- model: model name. Available from /v1/models.
    - messages: string prompt or chat history in OpenAI format. An example
        for chat history is `[{"role": "user", "content": "knock knock"}]`.
- temperature (float): to modulate the next token probability
- top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
- n (int): How many chat completion choices to generate for each input
message. Only support one here.
- stream: whether to stream the results or not. Default to false.
- max_tokens (int): output token nums
- repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
- stop (str | List[str] | None): To stop generating further
      tokens. Only accept stop words that are encoded to a single token index.
Additional arguments supported by LMDeploy:
- ignore_eos (bool): indicator for ignoring eos
- session_id (int): if not specified, will set random value
Currently we do not support the following features:
- function_call (Users should implement this by themselves)
- logit_bias (not supported yet)
- presence_penalty (replaced with repetition_penalty)
- frequency_penalty (replaced with repetition_penalty)
"""
check_response = await node_manager.check_request_model(request.model)
if check_response is not None:
return check_response
node_url = node_manager.get_node_url(request.model)
if not node_url:
return node_manager.handle_unavailable_model(request.model)
request_dict = request.model_dump()
start = node_manager.pre_call(node_url)
if request.stream is True:
response = node_manager.stream_generate(request_dict, node_url,
'/v1/chat/completions')
background_task = node_manager.create_background_tasks(node_url, start)
return StreamingResponse(response, background=background_task)
else:
response = await node_manager.generate(request_dict, node_url,
'/v1/chat/completions')
node_manager.post_call(node_url, start)
return JSONResponse(json.loads(response))
@app.post('/v1/completions', dependencies=[Depends(check_api_key)])
async def completions_v1(request: CompletionRequest,
raw_request: Request = None):
"""Completion API similar to OpenAI's API.
Go to `https://platform.openai.com/docs/api-reference/completions/create`
for the API specification.
The request should be a JSON object with the following fields:
- model (str): model name. Available from /v1/models.
- prompt (str): the input prompt.
- suffix (str): The suffix that comes after a completion of inserted text.
- max_tokens (int): output token nums
- temperature (float): to modulate the next token probability
- top_p (float): If set to float < 1, only the smallest set of most
probable tokens with probabilities that add up to top_p or higher
are kept for generation.
- n (int): How many chat completion choices to generate for each input
message. Only support one here.
- stream: whether to stream the results or not. Default to false.
- repetition_penalty (float): The parameter for repetition penalty.
1.0 means no penalty
- user (str): A unique identifier representing your end-user.
- stop (str | List[str] | None): To stop generating further
      tokens. Only accept stop words that are encoded to a single token index.
Additional arguments supported by LMDeploy:
- ignore_eos (bool): indicator for ignoring eos
- session_id (int): if not specified, will set random value
- top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
Currently we do not support the following features:
- logprobs (not supported yet)
- presence_penalty (replaced with repetition_penalty)
- frequency_penalty (replaced with repetition_penalty)
"""
check_response = await node_manager.check_request_model(request.model)
if check_response is not None:
return check_response
node_url = node_manager.get_node_url(request.model)
if not node_url:
return node_manager.handle_unavailable_model(request.model)
request_dict = request.model_dump()
start = node_manager.pre_call(node_url)
if request.stream is True:
response = node_manager.stream_generate(request_dict, node_url,
'/v1/completions')
background_task = node_manager.create_background_tasks(node_url, start)
return StreamingResponse(response, background=background_task)
else:
response = await node_manager.generate(request_dict, node_url,
'/v1/completions')
node_manager.post_call(node_url, start)
return JSONResponse(json.loads(response))
def proxy(server_name: str = '0.0.0.0',
server_port: int = 10086,
strategy: Literal['random', 'min_expected_latency',
'min_observed_latency'] = 'min_expected_latency',
api_keys: Optional[Union[List[str], str]] = None,
ssl: bool = False,
**kwargs):
"""To launch the proxy server.
Args:
server_name (str): the server name of the proxy. Default to '0.0.0.0'.
server_port (str): the server port. Default to 10086.
strategy ('random' | 'min_expected_latency' | 'min_observed_latency'):
the strategy to dispatch requests to nodes. Default to
'min_expected_latency'
api_keys (List[str] | str | None): Optional list of API keys. Accepts string type as
a single api_key. Default to None, which means no api key applied.
ssl (bool): Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.
""" # noqa
node_manager.strategy = Strategy.from_str(strategy)
if api_keys is not None:
if isinstance(api_keys, str):
api_keys = api_keys.split(',')
from lmdeploy.serve.openai.api_server import VariableInterface
VariableInterface.api_keys = api_keys
ssl_keyfile, ssl_certfile = None, None
if ssl:
ssl_keyfile = os.environ['SSL_KEYFILE']
ssl_certfile = os.environ['SSL_CERTFILE']
uvicorn.run(app=app,
host=server_name,
port=server_port,
log_level='info',
ssl_keyfile=ssl_keyfile,
ssl_certfile=ssl_certfile)
if __name__ == '__main__':
import fire
fire.Fire(proxy)
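# Usage sketch (urls and ports are illustrative): start one or more
# `lmdeploy serve api_server` nodes, launch this proxy, register the nodes,
# then call the proxy through the usual OpenAI-style endpoints.
#
#   import requests
#   requests.post('http://0.0.0.0:10086/nodes/add',
#                 json={'url': 'http://0.0.0.0:23333'})
#   print(requests.get('http://0.0.0.0:10086/v1/models').json())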