Commit fe851fbc authored by zhouxiang

Add supplementary new files for version 0.2.6

parent e2d98ddc
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import os
from dataclasses import dataclass
from typing import Any, Dict, List
import torch
from lmdeploy.messages import (EngineGenerationConfig, PytorchEngineConfig,
ResponseType)
from lmdeploy.tokenizer import Tokenizer
from lmdeploy.utils import get_logger, get_model, logging_timer
from ..adapter.adapter import ADAPTER_MANAGER, SchedulerAdapter
from ..check_env import check_adapters, check_env, check_model
from ..config import CacheConfig, SchedulerConfig
from ..messages import MessageStatus, SamplingParam, SchedulerSequence
from ..paging import Scheduler
from .logits_process import FusedLogitsProcessor, SamplingInputs
from .model_agent import AutoModelAgent, ModelInputs
from .request import (Request, RequestManager, RequestSender, RequestType,
Response)
logger = get_logger('lmdeploy')
SeqList = List[SchedulerSequence]
AdapterList = List[SchedulerAdapter]
def _div_up(x, n):
"""perform div up."""
return (x + n - 1) // n
@dataclass
class InferOutput:
"""The output of the model inference."""
session_id: int
token_ids: List[int]
sender_id: int
req_id: int
meta: Any = None
finish: bool = False
logits: torch.Tensor = None
def _paging_adapters(adapters: dict, model_agent: AutoModelAgent,
scheduler: Scheduler):
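"""Register adapters with the scheduler and page their weights into cache."""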
adapters = adapters or dict()
weight_maps = []
for name, path in adapters.items():
weight_map = scheduler.add_adapter(path, name)
weight_map.block_table = torch.tensor(weight_map.block_table)
weight_maps.append(weight_map)
model_agent.paging_adapters(weight_maps)
def _tensorlize_block_offsets(block_offsets):
"""tensorlize block_offsets."""
from torch.nn.utils.rnn import pad_sequence
block_offsets = [torch.from_numpy(off) for off in block_offsets]
block_offsets = pad_sequence(block_offsets, batch_first=True)
return block_offsets
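# For example, per-sequence block tables [[0, 1, 2], [3, 4]] become a padded
# LongTensor [[0, 1, 2], [3, 4, 0]] so the whole batch can be indexed at once.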
def _get_adapter_ids(seqs: SeqList, adapters: AdapterList):
"""get adapter ids."""
adapter_names_map = dict(
(ada.name, idx) for idx, ada in enumerate(adapters))
adapter_ids = [adapter_names_map[seq.adapter_name] for seq in seqs]
return adapter_ids
def _check_resp(resp: Response, state: ResponseType, warning_msg: str = None):
"""check if response has state."""
if isinstance(state, ResponseType):
state = [state]
ret = resp.type in state
if not ret and warning_msg is not None:
logger.warning(warning_msg)
return ret
def _check_resp_success(resp: Response, warning_msg: str = None):
"""check if response success."""
return _check_resp(resp, ResponseType.SUCCESS, warning_msg)
async def async_try_add_session(req_sender: RequestSender, session_id: int):
"""Add new session.
Args:
session_id (int): The session id to add.
"""
resp = await req_sender.async_send(RequestType.ADD_SESSION,
dict(session_id=session_id))
_check_resp(resp, [ResponseType.SUCCESS, ResponseType.SESSION_REPEAT],
(f'Can not add session {session_id} '
f'with error: {resp.type}'))
async def async_end(req_sender: RequestSender, session_id: int):
"""End the given session."""
resp = await req_sender.async_send(RequestType.END_SESSION,
dict(session_id=session_id))
_check_resp_success(resp, (f'Failed to end session: {session_id}. '
f'Error: {resp.type}.'))
async def async_cancel(req_sender: RequestSender, session_id: int):
"""Stop current streaming inference."""
resp = await req_sender.async_send(RequestType.STOP_SESSION,
dict(session_id=session_id))
_check_resp_success(resp, (f'Failed to cancel session: {session_id}. '
f'Error: {resp.type}.'))
def try_add_session(req_sender: RequestSender, session_id: int):
"""Add new session.
Args:
session_id (int): The session id to add.
"""
resp = req_sender.send(RequestType.ADD_SESSION,
dict(session_id=session_id))
_check_resp(resp, [ResponseType.SUCCESS, ResponseType.SESSION_REPEAT],
(f'Can not add session {session_id} '
f'with error: {resp.type}'))
def end(req_sender: RequestSender, session_id: int):
"""End the given session."""
resp = req_sender.send(RequestType.END_SESSION,
dict(session_id=session_id))
_check_resp_success(resp, (f'Failed to end session: {session_id}. '
f'Error: {resp.type}.'))
def cancel(req_sender: RequestSender, session_id: int):
"""Stop current streaming inference."""
resp = req_sender.send(RequestType.STOP_SESSION,
dict(session_id=session_id))
_check_resp_success(resp, (f'Failed to cancel session: {session_id}. '
f'Error: {resp.type}.'))
class Engine:
"""The inference engine of lmdeploy pytorch.
Args:
model_path (str): The hugging face model path.
engine_config (PytorchEngineConfig): The config of the Engine.
trust_remote_code (bool): Trust remote code.
"""
def __init__(self,
model_path: str,
engine_config: PytorchEngineConfig = None,
trust_remote_code: bool = True) -> None:
check_env()
check_model(model_path, trust_remote_code)
if engine_config is None:
engine_config = PytorchEngineConfig()
if engine_config.adapters is not None:
check_adapters(list(engine_config.adapters.values()))
self.engine_config = engine_config
model_name = engine_config.model_name
tp = engine_config.tp
self.tp = tp
self.model_name = model_name
scheduler_config = SchedulerConfig(
max_batches=engine_config.max_batch_size,
max_session_len=engine_config.session_len,
eviction_type=engine_config.eviction_type,
prefill_interval=engine_config.prefill_interval)
# block_size = 1 to enable unified paging
adapters = engine_config.adapters
cache_config = CacheConfig(
block_size=engine_config.block_size,
num_cpu_blocks=engine_config.num_cpu_blocks,
num_gpu_blocks=engine_config.num_gpu_blocks,
cache_max_entry_count=engine_config.cache_max_entry_count,
max_prefill_token_num=engine_config.max_prefill_token_num)
if not os.path.exists(model_path):
model_path = get_model(model_path, engine_config.download_dir,
engine_config.revision)
self.model_agent = AutoModelAgent.from_pretrained(
model_path,
cache_config=cache_config,
trust_remote_code=trust_remote_code,
adapters=adapters,
tp=tp)
cache_config = self.model_agent.cache_config
self.scheduler = Scheduler(scheduler_config, cache_config)
if adapters:
_paging_adapters(adapters,
model_agent=self.model_agent,
scheduler=self.scheduler)
self.scheduler_config = scheduler_config
self.cache_config = cache_config
self.stream = torch.cuda.Stream()
self.req_manager = self._bind_request_manager()
# create main thread
self._start_loop()
self.req_sender = self.req_manager.build_sender()
self._create_buffers()
self.tokenizer = Tokenizer(model_path)
@classmethod
def from_pretrained(cls,
pretrained_model_name_or_path: str,
engine_config: PytorchEngineConfig = None,
trust_remote_code: bool = True,
**kwargs):
"""lmdeploy python inference engine.
Args:
pretrained_model_name_or_path (str):
It could be one of the following options:
- i) A local directory path of a turbomind model which is
converted by the `lmdeploy convert` command or downloaded
from ii) and iii)
- ii) The model_id of an lmdeploy-quantized model hosted
inside a model repo on huggingface.co, such as
"InternLM/internlm-chat-20b-4bit",
"lmdeploy/llama2-chat-70b-4bit", etc.
- iii) The model_id of a model hosted inside a model repo
on huggingface.co, such as "InternLM/internlm-chat-7b",
"Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
and so on.
engine_config (PytorchEngineConfig): Pytorch engine config.
trust_remote_code (bool): Trust remote code
"""
logger.debug(f'Get unexpected kwargs: {kwargs}')
return cls(model_path=pretrained_model_name_or_path,
engine_config=engine_config,
trust_remote_code=trust_remote_code)
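# A minimal usage sketch (not part of the original file). The model id and
# generation settings below are placeholders; `Engine.from_pretrained`,
# `create_instance` and `stream_infer` are the entry points defined in this
# module.
#
#   engine = Engine.from_pretrained('internlm/internlm-chat-7b',
#                                   engine_config=PytorchEngineConfig(tp=1))
#   generator = engine.create_instance()
#   for status, token_ids, num_tokens in generator.stream_infer(
#           session_id=0,
#           input_ids=[1],
#           gen_config=EngineGenerationConfig(max_new_tokens=16)):
#       print(status, num_tokens)
#   generator.end(session_id=0)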
def _create_buffers(self):
max_batches = self.scheduler_config.max_batches
# buffers to create inputs
self._q_start_loc_buf = torch.arange(max_batches)
self._attention_mask_buf = torch.ones(max_batches, 1, dtype=torch.long)
self._seq_length_buf = torch.ones(max_batches, dtype=torch.long)
def _bind_request_manager(self):
"""bind request manager."""
req_manager = RequestManager(self.engine_config.thread_safe)
req_manager.bind_func(RequestType.ADD_SESSION, self._on_add_session)
req_manager.bind_func(RequestType.STOP_SESSION, self._on_stop_session)
req_manager.bind_func(RequestType.END_SESSION, self._on_end_session)
req_manager.bind_func(RequestType.ADD_MESSAGE, self._on_add_message)
return req_manager
def _start_loop(self):
"""start loop."""
return self.req_manager.start_loop(self.async_loop)
def _on_add_session(self, reqs: Request, **kwargs):
"""on add session callback."""
for req in reqs:
session_id = req.data['session_id']
resp_type = ResponseType.SESSION_REPEAT
if session_id not in self.scheduler.sessions:
self.scheduler.add_session(session_id)
resp_type = ResponseType.SUCCESS
self.req_manager.response(
Response(type=resp_type,
sender_id=req.sender_id,
req_id=req.req_id))
def _on_stop_session(self, reqs: Request, **kwargs):
"""on stop session callback."""
for req in reqs:
session_id = req.data['session_id']
resp_type = ResponseType.SESSION_NOT_EXIST
if session_id in self.scheduler.sessions:
self.scheduler.stop_session(session_id)
resp_type = ResponseType.SUCCESS
self.req_manager.response(
Response(type=resp_type,
sender_id=req.sender_id,
req_id=req.req_id))
self.scheduler.update()
def _on_end_session(self, reqs: Request, **kwargs):
"""on end session callback."""
for req in reqs:
session_id = req.data['session_id']
resp_type = ResponseType.SESSION_NOT_EXIST
if session_id in self.scheduler.sessions:
self.scheduler.end_session(session_id)
resp_type = ResponseType.SUCCESS
self.req_manager.response(
Response(type=resp_type,
sender_id=req.sender_id,
req_id=req.req_id))
self.scheduler.update()
def _on_add_message(self, reqs: Request, **kwargs):
"""on add message callback."""
def __update_bad_words(msg):
"""update bad words."""
sampling_param = msg.sampling_param
eos_token_id = self.model_config.eos_token_id
if eos_token_id not in sampling_param.stop_words:
sampling_param.stop_words.append(eos_token_id)
if sampling_param.ignore_eos:
sampling_param.bad_words.append(eos_token_id)
for req in reqs:
session_id = req.data['session_id']
if session_id not in self.scheduler.sessions:
self.req_manager.response(
Response(type=ResponseType.SESSION_NOT_EXIST,
sender_id=req.sender_id,
req_id=req.req_id))
continue
session_id = req.data['session_id']
sess = self.scheduler.sessions[session_id]
# TODO: support 1 session n sequence
if len(sess.sequences) == 0:
assert len(
req.data['token_ids']) > 0, ('Empty input is not allowed.')
sess.add_sequence(req.data['token_ids'],
sampling_param=req.data['sampling_param'],
adapter_name=req.data['adapter_name'],
return_logits=req.data.get(
'return_logits', False))
msg = next(iter(sess.sequences.values()))
__update_bad_words(msg)
self.scheduler.add_sequence(msg)
else:
msg = next(iter(sess.sequences.values()))
msg.update_token_ids(req.data['token_ids'])
msg.num_new_tokens = 0
msg.sampling_param = req.data['sampling_param']
msg.return_logits = req.data.get('return_logits', False)
msg.status = MessageStatus.WAITING
__update_bad_words(msg)
msg.sender_id = req.sender_id
msg.req_id = req.req_id
self.scheduler.update()
@property
def model_config(self):
"""model config."""
return self.model_agent.model_config
@property
def gpu_count(self):
return self.tp
@property
def session_len(self):
return self.scheduler_config.max_session_len
def create_instance(self, cuda_stream_id=0):
"""Create a turbomind instance.
Args:
cuda_stream_id(int): identity of a cuda stream
Returns:
EngineInstance: an instance of turbomind
"""
return EngineInstance(self)
async def async_add_session(self, session_id: int):
"""Add new session."""
return await async_try_add_session(self.req_sender, session_id)
def add_session(self, session_id: int):
"""Add new session."""
return try_add_session(self.req_sender, session_id)
async def async_stop_session(self, session_id: int):
"""Stop the given session."""
return await async_cancel(self.req_sender, session_id)
def stop_session(self, session_id: int):
"""Add new session."""
return cancel(self.req_sender, session_id)
async def async_end_session(self, session_id: int):
"""End the given session."""
return await async_end(self.req_sender, session_id)
def end_session(self, session_id: int):
"""Add new session."""
return end(self.req_sender, session_id)
@logging_timer('CreateModelInputs', logger)
@torch.inference_mode()
def create_model_inputs(self, messages: SeqList, adapters: AdapterList):
"""create model inputs from messages.
Args:
messages (SeqList): The input messages.
adapters (AdapterList): Adapters.
"""
def __get_history_length():
"""get history length."""
if self.model_config.sliding_window > 0:
history_lengths = []
for msg in messages:
num_real_blocks = len(msg.logical_blocks)
num_all_blocks = _div_up(msg.num_all_tokens(),
msg.block_size)
num_drop_blocks = num_all_blocks - num_real_blocks
num_drop_tokens = num_drop_blocks * msg.block_size
history_lengths.append(msg.history_len - num_drop_tokens)
return history_lengths
else:
return [msg.history_len for msg in messages]
history_lengths = __get_history_length()
token_ids = [msg.token_ids for msg in messages]
meta = messages[0].meta
if isinstance(token_ids[0], int):
token_ids = [token_ids]
batch_size = len(messages)
input_ids = torch.cat(token_ids)
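# In decoding mode every running sequence contributes exactly one new token,
# so the flattened input length equals the batch size.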
is_decoding = input_ids.size(0) == batch_size
if not is_decoding:
seq_length = [tokens.size(0) for tokens in token_ids]
seq_length = torch.tensor(seq_length, dtype=torch.long)
max_seq_len = max(seq_length)
q_start_loc = seq_length.cumsum(0) - seq_length
mask_range = torch.arange(max_seq_len)[None, :]
attention_mask = (mask_range < seq_length[:, None]).long()
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids += position_ids.new_tensor(history_lengths).unsqueeze(
-1)
else:
q_start_loc = self._q_start_loc_buf[:batch_size]
attention_mask = self._attention_mask_buf[:batch_size]
seq_length = self._seq_length_buf[:batch_size]
position_ids = q_start_loc.new_tensor(history_lengths).unsqueeze(
-1)
# TODO: getting block offsets is slow when block_size = 1
block_offsets = self.scheduler.get_block_tables(messages)
block_offsets = _tensorlize_block_offsets(block_offsets)
local_adapter_ids = None
global_adapter_ids = None
adapter_offsets = None
max_rank = 0
if ADAPTER_MANAGER.num_adapters() > 1:
local_adapter_ids = _get_adapter_ids(messages, adapters)
local_adapter_ids = seq_length.new_tensor(local_adapter_ids)
adapter_offsets = self.scheduler.get_block_tables(adapters)
adapter_offsets = _tensorlize_block_offsets(adapter_offsets)
global_adapter_ids = [ada.idx for ada in adapters]
global_adapter_ids = seq_length.new_tensor(global_adapter_ids)
ranks = [ada.rank for ada in adapters]
max_rank = max(ranks)
# add batch dim [bs=1, seq_len]
if input_ids.ndim == 1:
input_ids = input_ids.unsqueeze(0)
return ModelInputs(input_ids=input_ids,
seq_length=seq_length,
attention_mask=attention_mask,
block_offsets=block_offsets,
position_ids=position_ids,
q_start_loc=q_start_loc,
history_lengths=history_lengths,
is_decoding=is_decoding,
local_adapter_ids=local_adapter_ids,
global_adapter_ids=global_adapter_ids,
adapter_offsets=adapter_offsets,
max_rank=max_rank,
meta=meta)
def _stopping_criteria(self, msg: SchedulerSequence, next_token_id: int):
"""Check if the message should stop.
Args:
msg (SchedulerSequence): The input message.
next_token_id (int): The next token id from inference result.
Returns:
bool: Whether the message should be stopped.
"""
def _check_stop_word(sampling_param, next_token_id):
if sampling_param.ignore_eos:
return False
return (sampling_param.stop_words is not None
and next_token_id in sampling_param.stop_words)
def _check_request_len(msg):
return msg.num_new_tokens >= msg.sampling_param.max_new_tokens
def _check_session_len(msg, max_session_len):
if max_session_len is None:
return False
session_len = msg.num_all_tokens() + 1
return session_len >= max_session_len
sampling_param = msg.sampling_param
if _check_stop_word(sampling_param, next_token_id):
return True
if _check_request_len(msg):
return True
if _check_session_len(msg, self.scheduler_config.max_session_len):
return True
return False
@logging_timer('SamplingLogits', logger)
async def async_sampling_logits(self, logits: torch.Tensor,
running: SeqList, inputs: ModelInputs):
"""sampling logits."""
def _gather_history(seqs: SeqList, device: torch.device):
"""gather history."""
batch = len(seqs)
max_len = max(seq.history_len for seq in seqs)
output = torch.full((batch, max_len),
self.model_config.bos_token_id,
dtype=torch.int64)
for idx, seq in enumerate(seqs):
h_len = seq.history_len
h_ids = output.new_tensor(seq.history_token_ids)
output[idx, :h_len] = h_ids
return output.to(device)
is_decoding = inputs.is_decoding
# TODO: support repetition_penalty
if not is_decoding:
seq_length = inputs.seq_length
last_idx = seq_length.cumsum(-1) - 1
split_logits = logits[last_idx, :]
else:
# most steps share the same sampling parameters
split_logits = logits
split_logits = split_logits.cuda()
sampling_inputs = SamplingInputs.from_sampling_params(running)
sampling_inputs = sampling_inputs.to_device(split_logits.device)
input_ids = None
if sampling_inputs.repetition_penalty is not None:
input_ids = _gather_history(running, split_logits.device)
logits_processor = FusedLogitsProcessor(sampling_inputs)
with torch.inference_mode(), torch.cuda.stream(self.stream):
logits = logits_processor(input_ids, split_logits)
next_token_ids = logits_processor.sampling(logits)
await asyncio.get_event_loop().run_in_executor(None,
self.stream.synchronize)
next_token_ids = next_token_ids.cpu()
return next_token_ids, split_logits
@logging_timer('UpdateRunning', logger)
def update_running(self, running: SeqList, next_token_ids: torch.Tensor,
meta: Any):
"""update scheduler."""
for token, msg in zip(next_token_ids, running):
msg.meta = meta
msg.update_token_ids(token)
msg.num_new_tokens += 1
if msg.num_new_tokens > msg.sampling_param.max_new_tokens:
msg.token_ids = torch.empty((0, ), dtype=torch.long)
if self._stopping_criteria(msg, token):
msg.status = MessageStatus.STOPPED
def _can_output_token(self, token: torch.Tensor, msg: SchedulerSequence):
"""check if output is necessary."""
if isinstance(token, torch.Tensor):
token = token.item()
stop_words = msg.sampling_param.stop_words
if stop_words is not None and token in stop_words:
return False
return True
@logging_timer('ModelForward', logger)
async def _async_model_forward(self, inputs: ModelInputs,
swap_in_map: Dict, swap_out_map: Dict):
"""model forward."""
max_prefill_token_num = self.cache_config.max_prefill_token_num
swap_done = False
class _LogitsGather:
"""logits gather."""
def __init__(self, max_seq_len):
self._max_seq_len = max_seq_len
self._start = 0
self._out_logits = None
def gather(self, output):
"""gather."""
logits = output['logits']
out_logits = self._out_logits
start = self._start
seq_len = logits.size(-2)
if out_logits is None:
out_logits = logits.new_empty(1,
self._max_seq_len,
logits.size(-1),
device='cpu')
out_logits[:, start:start + seq_len].copy_(logits,
non_blocking=True)
self._start = start + seq_len
self._out_logits = out_logits
def get_logits(self):
"""get logits."""
torch.cuda.synchronize()
return self._out_logits
async def __forward(inputs):
"""forward."""
nonlocal swap_done, swap_in_map, swap_out_map
if swap_done:
return await self.model_agent.async_forward(
inputs, swap_in_map=dict(), swap_out_map=dict())
else:
swap_done = True
return await self.model_agent.async_forward(
inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map)
async def __long_context_single_forward(inputs, index):
"""one large sequence."""
new_input = inputs.slice(index, index + 1)
max_seq_len = new_input.seq_length[0]
new_inputs = new_input.split(max_prefill_token_num,
self.cache_config.block_size)
logits_gather = _LogitsGather(max_seq_len)
for inp in new_inputs:
tmp_out = await __forward(inp)
logits_gather.gather(tmp_out)
tmp_out['logits'] = logits_gather.get_logits()
return tmp_out
async def __long_context_batched_forward(inputs, start, end):
"""batched."""
new_inputs = inputs.slice(start, end)
return await __forward(new_inputs)
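# Long prefill inputs are processed in chunks so that a single forward pass
# never sees more than `max_prefill_token_num` tokens: oversized sequences
# are split internally, and shorter sequences are grouped into batches whose
# total token count stays under the limit. The gathered logits are stitched
# back together afterwards.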
async def __long_context_forward(inputs):
"""forward for long context."""
seq_len = inputs.seq_length
max_seq_len = inputs.input_ids.size(1)
batch_size = seq_len.size(0)
indices = []
token_count = 0
idx = 0
logits_gather = _LogitsGather(max_seq_len)
while idx < batch_size:
slen = seq_len[idx]
if token_count == 0 and slen > max_prefill_token_num:
tmp_out = await __long_context_single_forward(inputs, idx)
logits_gather.gather(tmp_out)
tmp_out.pop('logits', None)
idx += 1
elif token_count + slen > max_prefill_token_num:
tmp_out = await __long_context_batched_forward(
inputs, indices[0], idx)
logits_gather.gather(tmp_out)
tmp_out.pop('logits', None)
indices = []
token_count = 0
else:
indices.append(idx)
token_count += slen
idx += 1
if token_count > 0:
tmp_out = await __long_context_batched_forward(
inputs, indices[0], idx)
logits_gather.gather(tmp_out)
tmp_out['logits'] = logits_gather.get_logits()
return tmp_out
if inputs.input_ids.numel() < max_prefill_token_num:
return await __forward(inputs)
else:
return await __long_context_forward(inputs)
@logging_timer('AsyncStep', logger)
async def async_step(self, is_prefill: bool, return_logits: bool = False):
"""one step inference. Used to perform streaming chat.
Returns:
Dict[int, InferOutput]: The output of each session.
"""
# schedule
schedule_output = self.scheduler.schedule(is_prefill=is_prefill)
running: SeqList = schedule_output.running
swap_in_map = schedule_output.swap_in_map
swap_out_map = schedule_output.swap_out_map
adapters = schedule_output.adapters
if len(running) == 0:
return dict()
inputs = self.create_model_inputs(running, adapters)
logger.debug(f'<AsyncStep>: batch_size={len(running)} '
f'num_tokens={inputs.input_ids.size(-1)}')
# inference
output = await self._async_model_forward(inputs,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
custom_outputs = output['custom_outputs']
logits = output['logits']
logits = logits[0] # [bs, seq, prob] -> [seq, prob]
next_token_ids, _ = await self.async_sampling_logits(
logits, running, inputs)
self.update_running(running, next_token_ids, custom_outputs)
self.scheduler.update()
# generate output
outputs: Dict[int, InferOutput] = dict()
for idx, msg in enumerate(running):
next_id = next_token_ids[idx]
session_id = msg.session_id
if self._can_output_token(next_id, msg):
out_token_ids = [next_id.item()]
else:
out_token_ids = []
out = InferOutput(
session_id=session_id,
sender_id=msg.sender_id,
req_id=msg.req_id,
finish=(msg.status == MessageStatus.STOPPED),
token_ids=out_token_ids,
)
outputs[session_id] = out
if msg.return_logits:
start = inputs.q_start_loc[idx]
seqlen = inputs.seq_length[idx]
outputs[msg.session_id].logits = logits[start:start + seqlen]
return outputs
async def async_batched_infer(self,
session_ids: List[int],
token_ids: List[List[int]] = None,
gen_config: EngineGenerationConfig = None,
adapter_names: List[str] = None,
keep_cache: bool = False):
"""Send inference request.
Args:
session_ids (List[int]): The session ids.
token_ids (List[List[int]]): The input token ids.
gen_config (EngineGenerationConfig): The sampling parameters.
adapter_names (List[str]): The names of the adapters.
keep_cache (bool): Keep the kv cache after inference.
Returns:
int: Error flag. 0 if success.
List[List[int]]: The output token ids of each session.
List[int]: The number of output tokens of each session.
"""
batch_size = len(token_ids)
assert len(session_ids) == batch_size
if adapter_names is not None:
assert len(adapter_names) == batch_size
else:
adapter_names = [None for _ in range(batch_size)]
async def _add_sessions(session_ids):
for session_id in session_ids:
await self.async_add_session(session_id)
async def _add_messages(session_ids, token_ids):
add_msgs = []
sampling_param = SamplingParam.from_gen_config(gen_config)
for session_id, token_id, adapter_name in zip(
session_ids, token_ids, adapter_names):
msg = dict(token_ids=token_id,
session_id=session_id,
sampling_param=sampling_param,
adapter_name=adapter_name)
add_msgs.append(msg)
req_types = [RequestType.ADD_MESSAGE] * batch_size
req_ids = await self.req_sender.async_batched_send_async(
req_types, data=add_msgs)
return req_ids
await _add_sessions(session_ids)
req_ids = await _add_messages(session_ids, token_ids)
# receive messages
req_idx_map = dict(zip(req_ids, range(len(req_ids))))
output_token_ids = [list() for _ in req_ids]
status = 0
finish_count = batch_size
while finish_count:
if not self.req_manager.is_loop_alive():
logger.error('Engine loop is not alive.')
status = 1
break
resp = await self.req_sender.async_recv_any()
if resp.req_id not in req_ids:
continue
idx = req_idx_map[resp.req_id]
token_ids = output_token_ids[idx]
if resp.type == ResponseType.SUCCESS:
token_ids += resp.data['token_ids']
elif resp.type == ResponseType.FINISH:
token_ids += resp.data['token_ids']
if not keep_cache:
session_id = session_ids[idx]
await self.async_end_session(session_id=session_id)
finish_count -= 1
else:
logger.error(f'Unexpected response: {resp.type}')
status = 1
break
output_token_len = [len(token_ids) for token_ids in output_token_ids]
return (status, output_token_ids, output_token_len)
def batched_infer(self,
session_ids: List[int],
token_ids: List[List[int]] = None,
gen_config: EngineGenerationConfig = None,
adapter_names: List[str] = None,
keep_cache: bool = False):
"""batched infer."""
coro = self.async_batched_infer(session_ids,
token_ids,
gen_config=gen_config,
adapter_names=adapter_names,
keep_cache=keep_cache)
return self.req_sender.run_until_complete(coro)
def decode(self,
input_ids,
steps: List[int] = None,
sequence_start: bool = True,
sequence_end: bool = True,
adapter_names: List[str] = None):
"""Perform context decode on input tokens.
Args:
input_ids (numpy.ndarray): the batch of input token ids
steps (List[int]): the offset of the k/v cache
sequence_start (bool): indicator for starting a sequence
sequence_end (bool): indicator for ending a sequence
adapter_names (List[str]): The name of the adapters.
"""
from torch.nn.utils.rnn import pad_sequence
logger.debug('Decoding logits.')
batch_size = len(input_ids)
def __add_messages(session_ids, input_ids, adapter_names):
add_msgs = []
sampling_param = SamplingParam(max_new_tokens=0)
for session_id, token_id, adapter_name in zip(
session_ids, input_ids, adapter_names):
msg = dict(token_ids=token_id,
session_id=session_id,
sampling_param=sampling_param,
adapter_name=adapter_name,
return_logits=True)
add_msgs.append(msg)
req_types = [RequestType.ADD_MESSAGE] * batch_size
req_ids = self.req_sender.batched_send_async(req_types,
data=add_msgs)
return req_ids
if steps is not None:
assert batch_size == len(steps)
if adapter_names is None:
adapter_names = [None] * batch_size
assert batch_size == len(adapter_names)
session_ids = tuple(range(batch_size))
if sequence_start:
for sid in session_ids:
self.req_sender.send(RequestType.END_SESSION,
dict(session_id=sid))
self.add_session(sid)
req_ids = __add_messages(session_ids, input_ids, adapter_names)
req_idx_map = dict(zip(req_ids, range(len(req_ids))))
finish_count = batch_size
ret = [None] * batch_size
while finish_count > 0:
resp = self.req_sender.recv_any()
if resp.req_id not in req_ids:
continue
assert resp.type == ResponseType.FINISH
idx = req_idx_map[resp.req_id]
ret[idx] = resp.data['logits']
finish_count -= 1
ret = pad_sequence(ret, True)
if sequence_end:
for sid in session_ids:
self.end_session(sid)
return ret
async def async_loop(self):
"""Main loop of the engine.
Each engine instance communicates with the engine through request queues.
"""
def _send_resp(step_tokens):
"""send response callback."""
for _, out in step_tokens.items():
if out.finish:
resp_type = ResponseType.FINISH
else:
resp_type = ResponseType.SUCCESS
self.req_manager.response(
Response(
type=resp_type,
sender_id=out.sender_id,
req_id=out.req_id,
data=dict(token_ids=out.token_ids, logits=out.logits),
))
prefill_interval = self.scheduler_config.prefill_interval
prefill_counter = prefill_interval
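# A prefill step is scheduled every `prefill_interval` iterations, or
# immediately when no sequence is currently in the running (decoding) state;
# every other iteration performs a decoding step.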
while True:
if not self.req_manager.has_requests(
) and not self.scheduler.has_unfinished():
await asyncio.sleep(0.01)
continue
self.req_manager.step()
# forward
if self.scheduler.has_unfinished():
has_running = self.scheduler.has_running()
is_prefill = not prefill_counter or not has_running
if is_prefill:
prefill_counter = prefill_interval
with torch.inference_mode():
step_tokens: Dict[int,
InferOutput] = await self.async_step(
is_prefill=is_prefill)
prefill_counter -= 1
# send response
_send_resp(step_tokens)
class EngineInstance:
"""Instance of TurboMind.
Args:
engine (Engine): engine
"""
def __init__(self, engine: Engine):
self.engine = engine
self.req_sender = engine.req_manager.build_sender()
def __del__(self):
"""Destructor."""
self.engine.req_manager.senders.pop(self.req_sender.sender_id)
async def _async_try_add_session(self, session_id: int):
"""Add new session.
Args:
session_id (int): The session id to add.
"""
return await async_try_add_session(self.req_sender, session_id)
def _try_add_session(self, session_id: int):
"""Add new session.
Args:
session_id (int): The session id to add.
"""
return try_add_session(self.req_sender, session_id)
async def async_stream_infer(self,
session_id: int,
input_ids: List[int],
gen_config: EngineGenerationConfig = None,
adapter_name: str = None,
**kwargs):
"""Send stream inference request.
Args:
session_id (int): The session id.
input_ids (List[int]): The input token ids.
gen_config (EngineGenerationConfig): The sampling parameters.
adapter_name (str): The lora adapter name.
Yields:
int: Error flags. 0 if success.
List[int]: The streaming output tokens.
int: The number of the output tokens.
"""
gen_config = gen_config or EngineGenerationConfig()
sampling_param = SamplingParam.from_gen_config(gen_config=gen_config)
await async_try_add_session(self.req_sender, session_id)
msg = dict(
token_ids=input_ids,
session_id=session_id,
sampling_param=sampling_param,
adapter_name=adapter_name,
)
req_id = await self.req_sender.async_send_async(
RequestType.ADD_MESSAGE, msg)
token_ids = []
while True:
if not self.req_sender.is_loop_alive():
yield (ResponseType.ENGINE_STOP_ERROR, [], 0)
break
resp = await self.req_sender.async_recv(req_id)
if resp.req_id != req_id:
continue
if resp.type == ResponseType.SUCCESS:
token_ids += resp.data['token_ids']
yield (resp.type, token_ids, len(token_ids))
elif resp.type == ResponseType.FINISH:
token_ids += resp.data['token_ids']
yield (resp.type, token_ids, len(token_ids))
break
else:
yield (resp.type, [], 0)
break
async def async_infer(self,
session_id: int,
input_ids: List[int] = None,
gen_config: EngineGenerationConfig = None,
**kwargs):
"""Send inference request.
Args:
session_id (int): The session id.
input_ids (List[int]): The input token ids.
gen_config (EngineGenerationConfig): The sampling parameters.
Returns:
int: Error flags. 0 if success.
List[int]: The streaming output tokens.
int: The number of the output tokens.
"""
token_ids = []
async for outputs in self.async_stream_infer(session_id,
input_ids,
gen_config=gen_config,
**kwargs):
status, tmp_ids, _ = outputs
if status not in [ResponseType.SUCCESS, ResponseType.FINISH]:
return (status, token_ids, len(token_ids))
token_ids = tmp_ids
return (0, token_ids, len(token_ids))
def stream_infer(self,
session_id: int,
input_ids: List[int],
gen_config: EngineGenerationConfig = None,
adapter_name: str = None,
**kwargs):
"""Send stream inference request.
Args:
session_id (int): The session id.
input_ids (List[int]): The input token ids.
gen_config (EngineGenerationConfig): The sampling parameters.
adapter_name (str): The lora adapter name.
Yields:
int: Error flags. 0 if success.
List[int]: The streaming output tokens.
int: The number of the output tokens.
"""
def __call_async():
"""call async."""
coro_gen = self.async_stream_infer(session_id, input_ids,
gen_config, adapter_name,
**kwargs)
while True:
try:
yield self.req_sender.run_until_complete(
coro_gen.__anext__())
except StopAsyncIteration:
break
if not self.req_sender.is_thread_safe():
yield from __call_async()
return
gen_config = gen_config or EngineGenerationConfig()
sampling_param = SamplingParam.from_gen_config(gen_config=gen_config)
try_add_session(self.req_sender, session_id)
msg = dict(
token_ids=input_ids,
session_id=session_id,
sampling_param=sampling_param,
adapter_name=adapter_name,
)
req_id = self.req_sender.send_async(RequestType.ADD_MESSAGE, msg)
token_ids = []
while True:
if not self.req_sender.is_loop_alive():
yield (ResponseType.ENGINE_STOP_ERROR, [], 0)
break
resp = self.req_sender.recv(req_id)
if resp.req_id != req_id:
continue
if resp.type == ResponseType.SUCCESS:
token_ids += resp.data['token_ids']
yield (resp.type, token_ids, len(token_ids))
elif resp.type == ResponseType.FINISH:
token_ids += resp.data['token_ids']
yield (resp.type, token_ids, len(token_ids))
break
else:
yield (resp.type, [], 0)
break
def infer(self,
session_id: int,
input_ids: List[int] = None,
gen_config: EngineGenerationConfig = None,
**kwargs):
"""Send inference request.
Args:
session_id (int): The session id.
input_ids (List[int]): The input token ids.
gen_config (EngineGenerationConfig): The sampling parameters.
Returns:
int: Error flags. 0 if success.
List[int]: The streaming output tokens.
int: The number of the output tokens.
"""
token_ids = []
for outputs in self.stream_infer(session_id,
input_ids,
gen_config=gen_config,
**kwargs):
status, tmp_ids, _ = outputs
if status not in [ResponseType.SUCCESS, ResponseType.FINISH]:
return (status, token_ids, len(token_ids))
token_ids = tmp_ids
return (0, token_ids, len(token_ids))
async def async_end(self, session_id: int):
"""End the given session."""
return await async_end(self.req_sender, session_id)
def end(self, session_id: int):
"""End the given session."""
return end(self.req_sender, session_id)
async def async_cancel(self, session_id: int):
"""Stop current streaming inference."""
return await async_cancel(self.req_sender, session_id)
def cancel(self, session_id: int):
"""Stop current streaming inference."""
return cancel(self.req_sender, session_id)
def decode(self,
input_ids,
steps: List[int] = None,
sequence_start: bool = True,
sequence_end: bool = True,
adapter_names: List[str] = None):
"""Perform context decode on input tokens.
Args:
input_ids (numpy.ndarray): the batch of input token ids
steps (List[int]): the offset of the k/v cache
sequence_start (bool): indicator for starting a sequence
sequence_end (bool): indicator for ending a sequence
adapter_names (List[str]): The name of the adapters.
"""
return self.engine.decode(input_ids,
steps=steps,
sequence_start=sequence_start,
sequence_end=sequence_end,
adapter_names=adapter_names)
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import asdict, dataclass
from typing import Dict, List
import torch
from transformers.generation.logits_process import LogitsWarper
from ..messages import SchedulerSequence
def _process_temperature(scores: torch.Tensor,
temperature: torch.Tensor,
inplace: bool = True):
"""process temperature."""
temperature = temperature.to(scores.dtype)
if not inplace:
scores = scores / temperature[:, None]
else:
scores /= temperature[:, None]
return scores
def _process_bad_words(scores: torch.Tensor,
bad_words: torch.LongTensor,
filter_value: float = -float('inf'),
inplace: bool = True):
"""process bad words."""
batch_size = scores.size(0)
batch_idx = torch.arange(batch_size, device=scores.device)
filtered_scores = scores[batch_idx[:, None], bad_words]
filtered_scores[bad_words >= 0] = filter_value
if not inplace:
scores = scores.clone()
scores[batch_idx[:, None], bad_words] = filtered_scores
return scores
def _process_repetition_penalty(scores: torch.Tensor,
input_ids: torch.LongTensor,
penalty: torch.Tensor,
inplace: bool = True):
"""process repetition penalty."""
score = torch.gather(scores, 1, input_ids)
penalty = penalty.to(score.dtype)
score = torch.where(score < 0, score * penalty[:, None],
score / penalty[:, None])
if not inplace:
scores = scores.clone()
scores.scatter_(1, input_ids, score)
return scores
def _filter_topk_sorted(scores: torch.Tensor,
topk: torch.LongTensor,
filter_value: float = -float('inf'),
inplace: bool = True):
"""filter topk on sorted scores."""
num_tokens = scores.size(1)
token_idx = torch.arange(num_tokens, device=scores.device)
mask = token_idx[None, :] >= topk[:, None]
if inplace:
scores.masked_fill_(mask, filter_value)
else:
scores = scores.masked_fill(mask, filter_value)
return scores
def _filter_topp_sorted(scores: torch.Tensor,
topp: torch.Tensor,
filter_value: float = -float('inf'),
inplace: bool = True):
"""filter topp on sorted scores."""
softmax_scores = scores.softmax(-1)
cum_scores = softmax_scores.cumsum(1) - softmax_scores
mask = cum_scores > topp[:, None]
mask[:, 0] = False # keep at least one
if inplace:
scores.masked_fill_(mask, filter_value)
else:
scores = scores.masked_fill(mask, filter_value)
return scores
def _multinomial_sampling(scores: torch.Tensor,
seeds: torch.LongTensor,
offsets: torch.LongTensor,
indices: torch.LongTensor = None):
"""sampling."""
from lmdeploy.pytorch.kernels import multinomial_sampling
return multinomial_sampling(scores, seeds, offsets, indices)
@dataclass
class SamplingInputs:
temperature: torch.Tensor = None
bad_words: torch.LongTensor = None
repetition_penalty: torch.Tensor = None
top_k: torch.LongTensor = None
top_p: torch.Tensor = None
random_seeds: int = None
random_offsets: int = None
max_top_k: int = 1
min_top_p: float = 1.0
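# Per-sequence sampling parameters are packed into batched tensors once per
# step. Fields are left as None when every sequence uses the default, e.g.
# `top_k` stays None when all sequences are greedy (max_top_k == 1) and
# `repetition_penalty` stays None when all penalties equal 1.0.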
@classmethod
def from_sampling_params(cls, seqs: List[SchedulerSequence]):
"""from samplingg params."""
batch_size = len(seqs)
temperature = [None] * batch_size
repetition_penalty = [None] * batch_size
top_k = [None] * batch_size
top_p = [None] * batch_size
bad_words = [None] * batch_size
random_seeds = [torch.seed() & 0xffffffff] * batch_size
random_offsets = [None] * batch_size
def __gather_params():
"""gather params."""
for idx, seq in enumerate(seqs):
param = seq.sampling_param
temperature[idx] = param.temperature
repetition_penalty[idx] = param.repetition_penalty
top_k[idx] = param.top_k
top_p[idx] = param.top_p
random_offsets[idx] = seq.random_offsets
if param.random_seed is not None:
random_seeds[idx] = param.random_seed & 0xffffffff
bw = param.bad_words
if (not param.ignore_eos
and seq.num_new_tokens < param.min_new_tokens):
bw = bw + param.stop_words
bad_words[idx] = bw
def __get_topp(top_p):
"""get topp."""
min_top_p = min(top_p)
if min_top_p == 1.0:
top_p = None
else:
top_p = torch.tensor(top_p)
return top_p, min_top_p
def __get_bad_words(bad_words, max_bw_len):
"""get bad words."""
ret = torch.full((batch_size, max_bw_len), -1, dtype=torch.int64)
for idx, bw in enumerate(bad_words):
bw_len = len(bw)
if bw_len == 0:
continue
bw = ret.new_tensor(bw)
ret[idx, :bw_len] = bw
return ret
__gather_params()
if all(rp == 1.0 for rp in repetition_penalty):
repetition_penalty = None
else:
repetition_penalty = torch.tensor(repetition_penalty)
temperature = torch.tensor(temperature)
max_bw_len = max(len(bw) for bw in bad_words)
if max_bw_len == 0:
bad_words = None
else:
if all(len(bw) == max_bw_len for bw in bad_words):
bad_words = torch.tensor(bad_words)
else:
bad_words = __get_bad_words(bad_words, max_bw_len)
max_top_k = max(top_k)
if max_top_k == 1:
top_k = None
top_p, min_top_p = None, 1.0
random_seeds = None
random_offsets = None
else:
top_k = torch.tensor(top_k)
top_p, min_top_p = __get_topp(top_p)
random_seeds = torch.tensor(random_seeds)
random_offsets = torch.tensor(random_offsets)
sampling_input = cls(
temperature=temperature,
bad_words=bad_words,
repetition_penalty=repetition_penalty,
top_k=top_k,
top_p=top_p,
random_seeds=random_seeds,
random_offsets=random_offsets,
max_top_k=max_top_k,
min_top_p=min_top_p,
)
return sampling_input
def to_device(self, device: str):
"""to device."""
input_dict = asdict(self)
out_dict = dict()
for k, v in input_dict.items():
if isinstance(v, torch.Tensor):
v = v.to(device)
out_dict[k] = v
return SamplingInputs(**out_dict)
class SeedManager:
"""random seed manager."""
def __init__(self):
self._generators: Dict[int, torch.Generator] = dict()
def new_generator(self, seed: int, device: str = 'cuda'):
"""new generator."""
return torch.Generator(device=device).manual_seed(seed)
def get(self, seed: int, device: str = 'cuda'):
"""get generator."""
if seed not in self._generators:
generator = self.new_generator(seed, device)
self._generators[seed] = generator
return self._generators[seed]
SEED_MANAGER = SeedManager()
class FusedLogitsProcessor(LogitsWarper):
"""Custom logits processor."""
def __init__(self, sampling_inputs: SamplingInputs):
self.sampling_inputs: SamplingInputs = sampling_inputs
def __call__(self, input_ids: torch.LongTensor,
scores: torch.FloatTensor) -> torch.FloatTensor:
r"""
Args:
input_ids (torch.LongTensor):
Indices of input sequence tokens in the vocabulary.
scores (torch.FloatTensor):
Prediction scores of a language modeling head.
These can be logits for each vocabulary when not using
beam search or log softmax for each vocabulary token
when using beam search
Return:
torch.FloatTensor: The processed prediction scores.
"""
sampling_inputs = self.sampling_inputs
scores = scores.clone()
repetition_penalty = sampling_inputs.repetition_penalty
if repetition_penalty is not None:
scores = _process_repetition_penalty(scores, input_ids,
repetition_penalty)
temperature = sampling_inputs.temperature
if temperature is not None:
scores = _process_temperature(scores, temperature)
bad_words = sampling_inputs.bad_words
if bad_words is not None:
scores = _process_bad_words(scores, bad_words)
return scores
def sampling(self, logits: torch.Tensor):
"""sampling."""
sampling_inputs = self.sampling_inputs
def __random_sampling(scores: torch.Tensor, indices: torch.LongTensor):
"""random sampling."""
top_k = sampling_inputs.top_k
if top_k is not None:
scores = _filter_topk_sorted(scores, top_k)
top_p = sampling_inputs.top_p
if top_p is not None:
scores = _filter_topp_sorted(scores, top_p)
softmax_scores = scores.softmax(1)
seeds = sampling_inputs.random_seeds
offsets = sampling_inputs.random_offsets
return _multinomial_sampling(softmax_scores, seeds, offsets,
indices)
if sampling_inputs.max_top_k == 1:
return logits.argmax(-1)
else:
scores, indices = logits.sort(1, descending=True)
return __random_sampling(scores, indices)
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import os
from dataclasses import asdict, dataclass, field
from typing import Any, Callable, Dict, List, Union
import torch
import torch.distributed as dist
from torch import multiprocessing as mp
from torch.distributed._tensor import DeviceMesh, Replicate, distribute_tensor
from transformers import AutoModelForCausalLM
from lmdeploy.pytorch.accel import LoadNoInit
from lmdeploy.utils import get_logger
from ..adapter.adapter import (AdapterWeightMap, get_indexed_lora_linears,
get_max_lora_weight_size, update_lora_linears)
from ..config import CacheConfig, ModelConfig
from ..models import patch
from ..utils import get_gpu_memory
from .cache_engine import CacheEngine
logger = get_logger('lmdeploy')
_PATCH_ARG_NAMES = ['context', 'use_origin']
def _infer_block_size(model: torch.nn.Module,
model_config: ModelConfig,
cache_config: CacheConfig,
world_size: int = 1):
"""infer block size."""
max_weight_dim = get_max_lora_weight_size(model)
if max_weight_dim == 0:
return cache_config.block_size
per_token_size = model_config.get_head_size(
) * model_config.num_key_value_heads // world_size
block_size = 1
while block_size * per_token_size < max_weight_dim:
block_size *= 2
return block_size * world_size
def _update_cache_config(model_config: ModelConfig,
cache_config: CacheConfig,
gpu_id: int = 0,
host_mem_size: int = 4 * (1 << 30),
world_size: int = 1):
"""Update the gpu mem and cpu mem according to model info.
Args:
model_config (ModelConfig): The config of the model.
cache_config (CacheConfig): The config of the cache info.
gpu_id (int): The GPU id to use.
"""
def __get_free_gpu_mem_size():
"""get free gpu memory size."""
torch.cuda.empty_cache()
gpu_mem_physical_free, _ = get_gpu_memory(gpu_id)
logger.debug(f'device<{gpu_id}> free gpu memory:'
f' {gpu_mem_physical_free>>20} mb')
vocab_size = model_config.vocab_size
max_prefill_token_num = cache_config.max_prefill_token_num
# lm_head output(2) + to float(4) + estimated misc(1) = 7
intermediate_cache_size = int(max_prefill_token_num * vocab_size * 7)
logger.debug('estimated max runtime memory:'
f' {intermediate_cache_size>>20} mb')
gpu_mem_physical_free -= intermediate_cache_size
return gpu_mem_physical_free * cache_config.cache_max_entry_count
gpu_mem = __get_free_gpu_mem_size()
cpu_mem = host_mem_size
cache_block_size = CacheEngine.get_cache_block_size(
cache_config.block_size, model_config, world_size)
if cache_config.num_cpu_blocks == 0:
cache_config.num_cpu_blocks = int(cpu_mem / cache_block_size)
if cache_config.num_gpu_blocks == 0:
cache_config.num_gpu_blocks = int(gpu_mem / cache_block_size)
cache_config.window_size = model_config.sliding_window
logger.debug('block num: {}'.format(cache_config.num_gpu_blocks))
@dataclass
class ModelInputs:
"""Input of the model."""
input_ids: torch.LongTensor
seq_length: torch.LongTensor
attention_mask: torch.Tensor
block_offsets: torch.LongTensor
position_ids: torch.LongTensor
q_start_loc: torch.LongTensor
history_lengths: List[int]
is_decoding: bool
local_adapter_ids: torch.LongTensor = None
global_adapter_ids: torch.LongTensor = None
adapter_offsets: torch.LongTensor = None
max_rank: int = 0
meta: Any = None
def slice(self, start: int, end: int):
"""select by indices."""
sli = slice(start, end)
start_loc = self.q_start_loc[sli]
seq_length = self.seq_length[sli]
end_loc = start_loc[-1] + seq_length[-1]
input_ids = self.input_ids[:, start_loc[0]:end_loc]
start_loc = start_loc - start_loc[0]
history_lengths = self.history_lengths[sli]
local_adapter_ids = self.local_adapter_ids
if local_adapter_ids is not None:
local_adapter_ids = local_adapter_ids[sli]
return ModelInputs(input_ids=input_ids,
seq_length=seq_length,
attention_mask=self.attention_mask[sli],
block_offsets=self.block_offsets[sli],
position_ids=self.position_ids[sli],
q_start_loc=start_loc,
history_lengths=history_lengths,
is_decoding=self.is_decoding,
local_adapter_ids=local_adapter_ids,
global_adapter_ids=self.global_adapter_ids,
adapter_offsets=self.adapter_offsets,
max_rank=self.max_rank,
meta=self.meta)
def split(self, split_size: int, block_size: int):
"""split inputs."""
assert len(
self.seq_length) == 1, ('Cannot perform split on batched input.')
assert split_size % block_size == 0, (
'split_size should be a multiple of block_size.')
input_ids = self.input_ids
if input_ids.numel() < split_size:
return self
num_blocks = split_size // block_size
overlap = (self.history_lengths[0] % block_size != 0)
max_seq_len = self.seq_length[0].item()
ret = []
block_start = 0
history_len = self.history_lengths[0]
for i in range(0, max_seq_len, split_size):
start = i
end = min(max_seq_len, i + split_size)
block_end = block_start + num_blocks
if overlap:
block_end += 1
local_adapter_ids = self.local_adapter_ids
if local_adapter_ids is not None:
local_adapter_ids = local_adapter_ids[:, start:end]
inp = ModelInputs(
input_ids=self.input_ids[:, start:end],
seq_length=input_ids.new_tensor([end - start]),
attention_mask=self.attention_mask[:, start:end],
block_offsets=self.block_offsets[:, :block_end],
position_ids=self.position_ids[:, start:end],
q_start_loc=input_ids.new_zeros(1),
history_lengths=[history_len + start],
is_decoding=self.is_decoding,
local_adapter_ids=local_adapter_ids,
global_adapter_ids=self.global_adapter_ids,
adapter_offsets=self.adapter_offsets,
max_rank=self.max_rank,
meta=self.meta,
)
ret.append(inp)
block_start += num_blocks
return ret
def to_device(self, device: str):
"""to device."""
input_dict = asdict(self)
out_dict = dict()
for k, v in input_dict.items():
if isinstance(v, torch.Tensor):
v = v.to(device)
out_dict[k] = v
return ModelInputs(**out_dict)
@dataclass
class StepContext:
"""context of Model.
patched model might need extra information to perform inference. This
dataclass provide these infos and tools.
"""
inputs: ModelInputs
block_offsets: torch.LongTensor
position_ids: torch.LongTensor
position_ids_1d: torch.LongTensor
q_start_loc: torch.LongTensor
history_lengths: torch.LongTensor
q_seq_length: torch.LongTensor
kv_seq_length: torch.LongTensor
max_q_seq_length: int
max_kv_seq_length: int
kv_caches: List
is_decoding: bool
world_size: int = 1
json_config: Dict = None
local_adapter_ids: torch.LongTensor = None
global_adapter_ids: torch.LongTensor = None
adapter_offsets: torch.LongTensor = None
max_rank: int = 0
_outputs: Dict = field(default_factory=dict)
@classmethod
def new(
cls,
inputs: ModelInputs,
world_size: int = 1,
device: str = 'cuda',
json_config: dict = None,
kv_caches: List = None,
):
"""build step context.
Args:
inputs (ModelInputs): packaged model inputs.
world_size (int): The distribution world size.
device (str): The device of the tensors.
"""
position_ids = inputs.position_ids
max_q_seq_length = position_ids.size(-1)
# seq_len + history_length
kv_seq_length = position_ids[..., -1] + 1
# position ids 1d
q_seq_length = inputs.seq_length
position_ids_1d = cls.get_position_ids_1d(position_ids, q_seq_length,
device)
max_kv_seq_length = max_q_seq_length + max(inputs.history_lengths)
ret = StepContext(inputs=inputs,
block_offsets=inputs.block_offsets,
position_ids=inputs.position_ids,
position_ids_1d=position_ids_1d,
q_start_loc=inputs.q_start_loc,
history_lengths=inputs.history_lengths,
q_seq_length=inputs.seq_length,
kv_seq_length=kv_seq_length,
max_q_seq_length=max_q_seq_length,
max_kv_seq_length=max_kv_seq_length,
kv_caches=kv_caches,
is_decoding=inputs.is_decoding,
world_size=world_size,
json_config=json_config,
local_adapter_ids=inputs.local_adapter_ids,
global_adapter_ids=inputs.global_adapter_ids,
adapter_offsets=inputs.adapter_offsets,
max_rank=inputs.max_rank)
return ret
@classmethod
def tensorlize_block_offsets(cls, block_offsets, device):
"""tensorlize block_offsets."""
import numpy as np
offset_len = [len(offset) for offset in block_offsets]
max_offsets_len = max(offset_len)
batch_size = len(offset_len)
pad_block_offsets = np.zeros((batch_size, max_offsets_len),
dtype=np.int64)
for pad_offset, offset, off_len in zip(pad_block_offsets,
block_offsets, offset_len):
pad_offset[:off_len] = offset
block_offsets = torch.from_numpy(pad_block_offsets).to(device)
return block_offsets
@classmethod
def get_position_ids_1d(cls,
position_ids: torch.LongTensor,
seq_length: torch.LongTensor,
device: str = 'cuda'):
"""get 1d position_ids."""
if position_ids.size(1) == 1:
position_ids_1d = position_ids.flatten()
else:
position_ids_1d = [
ids[:l] for ids, l in zip(position_ids.cpu(), seq_length.cpu())
]
position_ids_1d = torch.cat(position_ids_1d).to(device)
return position_ids_1d
def get_block_offsets(self):
"""return block offsets."""
return self.block_offsets
def set_output(self, key, value):
"""set output."""
self._outputs[key] = value
def get_output(self, key):
"""get output."""
if key in self._outputs:
return self._outputs[key]
return None
def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict,
swap_out_map: dict):
"""perform cache swapping."""
issued_cache_op = False
if len(swap_in_map) > 0:
cache_engine.swap_in(swap_in_map)
issued_cache_op = True
if len(swap_out_map) > 0:
cache_engine.swap_out(swap_out_map)
issued_cache_op = True
if issued_cache_op:
cache_events = cache_engine.events
for event in cache_events:
event.wait()
def model_forward(
patched_model: torch.nn.Module,
inputs: ModelInputs,
cache_engine: CacheEngine,
json_config: dict = None,
world_size: int = 1,
stream: torch.cuda.Stream = None,
):
"""perform model forward."""
stream = stream or torch.cuda.current_stream()
with torch.inference_mode(), torch.cuda.stream(stream):
# forward
inputs = inputs.to_device('cuda')
context = StepContext.new(
inputs=inputs,
world_size=world_size,
json_config=json_config,
kv_caches=cache_engine.gpu_cache,
)
output = patched_model.patched_forward(
input_ids=inputs.input_ids,
position_ids=inputs.position_ids,
attention_mask=inputs.attention_mask,
past_key_values=cache_engine.gpu_cache,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
use_origin=False,
context=context,
)
return dict(logits=output['logits'], custom_outputs=context._outputs)
def _load_adapters(hf_model: torch.nn.Module,
adapters: Dict[str, str],
device_map: str = 'cpu'):
"""load adapters."""
if not adapters:
return
for name, path in adapters.items():
logger.info(f'load adapter <{name}> from "{path}".')
hf_model.load_adapter(path, name, device_map=device_map)
def _add_adapters(hf_model: torch.nn.Module, adapters: Dict[str, str]):
"""add adapters."""
if not adapters:
return
from peft import PeftConfig, inject_adapter_in_model
for name, path in adapters.items():
config = PeftConfig.from_pretrained(path)
inject_adapter_in_model(config, model=hf_model, adapter_name=name)
def _unparam_lora_weight(model: torch.nn.Module):
"""unparam lora weight.
We don't want to move weight of lora to gpu.
"""
from peft.tuners.lora import Linear as LoRALinear
def _tensorize_weight(linear):
"""tensorize weight."""
w = linear.weight
del linear.weight
linear.weight = w.data
for _, mod in model.named_modules():
if isinstance(mod, LoRALinear):
lora_A = mod.lora_A
lora_B = mod.lora_B
for linear in lora_A.values():
_tensorize_weight(linear)
for linear in lora_B.values():
_tensorize_weight(linear)
SwapMap = Dict[int, int]
class AutoModelAgent:
"""Base model agent."""
def __init__(self, model_config: ModelConfig, cache_config: CacheConfig):
self.model_config = model_config
self.cache_config = cache_config
def paging_adapters(self, weight_maps: List[AdapterWeightMap]):
"""paging adapter."""
raise NotImplementedError('Not implemented.')
async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
swap_out_map: SwapMap):
"""model forward.
Args:
inputs (Dict): The input data comes from _make_inputs.
swap_in_map (SwapMap): Cache maps to swap in.
swap_out_map (SwapMap): Cache maps to swap out.
"""
raise NotImplementedError('Not implemented.')
def forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
swap_out_map: SwapMap):
"""model forward.
Args:
inputs (Dict): The input data comes from _make_inputs.
swap_in_map (SwapMap): Cache maps to swap in.
swap_out_map (SwapMap): Cache maps to swap out.
"""
raise NotImplementedError('Not implemented.')
@classmethod
def from_pretrained(cls,
pretrained_model_name_or_path: str,
cache_config: CacheConfig,
trust_remote_code: bool,
adapters: Dict[str, str] = None,
tp: int = 1):
"""from pretrained."""
return build_model_agent(pretrained_model_name_or_path,
cache_config=cache_config,
trust_remote_code=trust_remote_code,
adapters=adapters,
tp=tp)
class BaseModelAgent(AutoModelAgent):
"""Base model agent.
Load the model on the local GPU.
Args:
model_path (str): The hugging face model path.
model_config (ModelConfig): The config of the model.
cache_config (CacheConfig): The config of the cache info.
trust_remote_code (bool): Trust remote code
"""
def __init__(self,
model_path: str,
model_config: ModelConfig,
cache_config: CacheConfig,
adapters: Dict[str, str] = None,
trust_remote_code: bool = True):
super().__init__(model_config=model_config, cache_config=cache_config)
torch_dtype = model_config.dtype
self.patched_model = self._build_model(
model_path,
torch_dtype=torch_dtype,
adapters=adapters,
trust_remote_code=trust_remote_code)
block_size = _infer_block_size(self.patched_model, model_config,
cache_config)
if block_size != cache_config.block_size:
cache_config.block_size = block_size
logger.warning(f'inferred block size: {block_size}')
_update_cache_config(model_config, cache_config)
self.cache_engine = CacheEngine(cache_config, model_config)
self.stream = torch.cuda.Stream()
def _build_model(self,
model_path: str,
torch_dtype: torch.dtype,
adapters: Dict[str, str] = None,
trust_remote_code: bool = True):
"""build patched model."""
with LoadNoInit():
hf_model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
**self.model_config.init_kwargs)
hf_model.eval()
hf_model.config.use_cache = True
if adapters:
_load_adapters(hf_model, adapters)
patched_model = patch(hf_model, _PATCH_ARG_NAMES)
if adapters:
_unparam_lora_weight(patched_model)
patched_model = patched_model.cuda()
return patched_model
def paging_adapters(self, weight_maps: List[AdapterWeightMap]):
"""paging adapter."""
logger.info('paging adapters.')
lora_linears = get_indexed_lora_linears(self.patched_model)
cpu_caches = self.cache_engine.cpu_cache
num_blocks = self.cache_engine.num_cpu_blocks
cpu_caches = [(kcache.view(num_blocks,
-1), vcache.view(num_blocks, -1))
for kcache, vcache in cpu_caches]
for weight_map in weight_maps:
weight_map.cache_adapter(lora_linears, cpu_caches)
update_lora_linears(lora_linears, weight_maps, device='cuda')
def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap,
swap_out_map: SwapMap):
cache_swapping(self.cache_engine,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
output = model_forward(
self.patched_model,
inputs,
self.cache_engine,
self.model_config.json_config,
world_size=1,
stream=self.stream,
)
return output
def forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
swap_out_map: SwapMap):
"""model forward.
Args:
inputs (Dict): The input data comes from _make_inputs.
swap_in_map (SwapMap): Cache maps to swap in.
swap_out_map (SwapMap): Cache maps to swap out.
"""
output = self._forward_impl(inputs,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
self.stream.synchronize()
return output
async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
swap_out_map: SwapMap):
"""model forward.
Args:
inputs (Dict): The input data comes from _make_inputs.
swap_in_map (SwapMap): Cache maps to swap in.
swap_out_map (SwapMap): Cache maps to swap out.
"""
output = self._forward_impl(inputs,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
await asyncio.get_event_loop().run_in_executor(None,
self.stream.synchronize)
return output
@dataclass
class TPResponse:
ret_code: int
error: Union[Exception, List[Exception]] = None
data: Any = None
def gather_error(self):
"""gather error."""
rank = dist.get_rank()
world_size = dist.get_world_size()
# gather errors
error_count = torch.tensor(self.ret_code).cuda(rank)
dist.all_reduce(error_count)
if error_count.item() > 0:
all_errors = [None] * world_size
dist.all_gather_object(all_errors, self.error)
self.ret_code = 1
self.error = all_errors
def raise_error(self, default_error: Exception):
"""raise error."""
if self.error is None:
raise default_error
elif isinstance(self.error, Exception):
raise self.error
else:
assert isinstance(self.error, List), ('expect error type list, '
f'got {type(self.error)}')
rank = dist.get_rank()
err = self.error[rank]
if err is None:
raise default_error
else:
raise err
def _get_model_memory_usage(model: torch.nn.Module) -> int:
"""get model memory usage."""
size = 0
for _, param in model.named_parameters():
size += param.element_size() * param.numel()
for _, buf in model.named_buffers():
        size += buf.element_size() * buf.numel()
return size
def _create_device_map(model: torch.nn.Module,
world_size: int,
device_map: dict = None):
"""Distribute params to each devices."""
if device_map is None:
device_map = dict()
device_id = 0
for name, _ in model.named_parameters():
device_map[name] = device_id
device_id = (device_id + 1) % world_size
for name, _ in model.named_buffers():
device_map[name] = device_id
device_id = (device_id + 1) % world_size
return device_map
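# Illustrative sketch (not part of the original module): how `_create_device_map`
# round-robins parameters and buffers across ranks. The toy module below is a
# hypothetical stand-in; rank 0 feeds the resulting map to
# `AutoModelForCausalLM.from_pretrained(device_map=...)` in `_tp_build_model`.
def _demo_create_device_map(world_size: int = 2):
    """Show the round-robin layout produced by `_create_device_map` (sketch)."""
    toy = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 4))
    # For world_size=2 the entries alternate between device 0 and device 1,
    # e.g. {'0.weight': 0, '0.bias': 1, '1.weight': 0, '1.bias': 1}.
    return _create_device_map(toy, world_size)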
def _tp_build_model(
rank: int,
model_path: str,
model_config: ModelConfig,
cache_config: CacheConfig,
adapters: Dict[str, str],
out_que: mp.Queue,
world_size: int,
trust_remote_code=True,
):
"""build tensor parallel model."""
from accelerate import init_empty_weights
error_code = 0
error_type = None
patched_model = None
cache_engine = None
def __get_device_map(model, device_map=None):
"""get device map of model."""
import psutil
model_size = _get_model_memory_usage(model)
if psutil.virtual_memory().available < model_size:
logger.debug('Preload model on GPU.')
return device_map
else:
logger.debug('Preload model on CPU.')
return 'cpu'
def __load_params_and_buffers(param_mod, mod):
"""load param and buffer."""
for name, param in param_mod.named_parameters(recurse=False):
mod.register_parameter(name, param)
for name, buffer in param_mod.named_buffers(recurse=False):
mod.register_buffer(name, buffer)
def __load_state_dict_assign(param_model, model):
"""load state dict assign."""
try:
model.load_state_dict(param_model.state_dict(), assign=True)
except Exception:
__load_params_and_buffers(param_model, model)
mods = dict(model.named_modules())
for mod_name, param_mod in param_model.named_modules():
mod = mods[mod_name]
__load_params_and_buffers(param_mod, mod)
def _broadcast_config(cache_config):
"""broadcast cache config, use minimum cache."""
if rank == 0:
gathered_configs = [None] * world_size
dist.gather_object(cache_config, gathered_configs)
num_gpu_blocks_list = [
config.num_gpu_blocks for config in gathered_configs
]
num_cpu_blocks_list = [
config.num_cpu_blocks for config in gathered_configs
]
min_num_gpu_blocks = min(num_gpu_blocks_list)
min_num_cpu_blocks = min(num_cpu_blocks_list)
cache_config.num_cpu_blocks = min_num_cpu_blocks
cache_config.num_gpu_blocks = min_num_gpu_blocks
config_list = [cache_config]
else:
gathered_configs = None
dist.gather_object(cache_config, gathered_configs)
config_list = [None]
dist.broadcast_object_list(config_list)
return config_list[0]
try:
config = model_config.hf_config
torch_dtype = model_config.dtype
device_map = None
with init_empty_weights():
model = AutoModelForCausalLM.from_config(
config,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
**model_config.init_kwargs)
if rank == 0:
device_map = _create_device_map(model, world_size)
_add_adapters(model, adapters)
if rank == 0:
                # adding adapters removes linear weights; rebuild the device map.
device_map = _create_device_map(model, world_size, device_map)
model.eval()
model.config.use_cache = True
if rank == 0:
with LoadNoInit():
device_map = __get_device_map(model, device_map)
param_model = AutoModelForCausalLM.from_pretrained(
model_path,
torch_dtype=torch_dtype,
device_map=device_map,
trust_remote_code=trust_remote_code,
**model_config.init_kwargs)
_load_adapters(param_model, adapters, device_map=device_map)
__load_state_dict_assign(param_model, model)
param_model = param_model.to('meta')
del param_model
patched_model = patch(
model,
extra_args=_PATCH_ARG_NAMES,
rank=rank,
world_size=world_size,
)
block_size = _infer_block_size(patched_model, model_config,
cache_config, world_size)
if block_size != cache_config.block_size:
cache_config.block_size = block_size
if rank == 0:
                logger.warning(f'inferred block size: {block_size}')
_update_cache_config(model_config,
cache_config,
gpu_id=rank,
world_size=world_size)
cache_config = _broadcast_config(cache_config)
cache_engine = CacheEngine(cache_config,
model_config,
rank=rank,
world_size=world_size)
except Exception as e:
logger.error(f'rank[{rank}] failed with error: {e}')
error_code = 1
error_type = e
# response
resp = TPResponse(error_code, error_type, cache_config)
resp.gather_error()
if rank == 0:
out_que.put(resp)
if resp.ret_code != 0:
resp.raise_error(RuntimeError('failed to init model.'))
return patched_model, cache_engine
def _tp_get_input(rank: int, in_que: mp.Queue, world_size: int):
"""get input tensor parallel."""
device_mesh = DeviceMesh('cuda', list(range(world_size)))
# broadcast meta info
if rank == 0:
inputs, swap_in_map, swap_out_map = in_que.get()
inputs = asdict(inputs)
input_tensors = dict(
(k, v) for k, v in inputs.items() if isinstance(v, torch.Tensor))
tensor_metas = dict(
(name, (t.shape, t.dtype)) for name, t in input_tensors.items())
other_metas = dict((k, v) for k, v in inputs.items()
if not isinstance(v, torch.Tensor))
input_metas = (tensor_metas, other_metas)
objs = [input_metas, swap_in_map, swap_out_map]
else:
objs = [None, None, None]
dist.broadcast_object_list(objs)
if rank != 0:
input_metas = objs[0]
tensor_metas, other_metas = input_metas
input_tensors = dict((name, torch.empty(meta[0], dtype=meta[1]))
for name, meta in tensor_metas.items())
updated_inputs = dict()
for name, t in input_tensors.items():
updated_inputs[name] = distribute_tensor(t,
device_mesh=device_mesh,
placements=[Replicate()
]).to_local()
torch.cuda.synchronize()
inputs = updated_inputs
inputs.update(other_metas)
inputs = ModelInputs(**inputs)
swap_in_map = objs[1]
swap_out_map = objs[2]
return inputs, swap_in_map, swap_out_map
def _tp_paging_adapters(
rank: int,
patched_model: torch.nn.Module,
cache_engine: CacheEngine,
in_que: mp.Queue,
out_que: mp.Queue,
):
"""tp paging adapters."""
def __get_weight_map():
"""get weight map."""
if rank == 0:
weight_maps = in_que.get()
dist_obj = [weight_maps]
else:
dist_obj = [None]
dist.broadcast_object_list(dist_obj)
return dist_obj[0]
def __paging(weight_maps):
"""paging."""
lora_linears = get_indexed_lora_linears(patched_model)
cpu_caches = cache_engine.cpu_cache
num_blocks = cache_engine.num_cpu_blocks
cpu_caches = [(kcache.view(num_blocks,
-1), vcache.view(num_blocks, -1))
for kcache, vcache in cpu_caches]
for weight_map in weight_maps:
weight_map.cache_adapter(lora_linears, cpu_caches)
update_lora_linears(lora_linears, weight_maps, device='cuda')
weight_maps = __get_weight_map()
resp = TPResponse(0)
try:
if rank == 0:
logger.info('tp paging adapters.')
if len(weight_maps) > 0:
__paging(weight_maps)
except Exception as e:
resp.ret_code = 1
resp.error = e
resp.gather_error()
if rank == 0:
out_que.put(resp)
if resp.ret_code != 0:
resp.raise_error(RuntimeError('tp paging adapters failed.'))
def _tp_model_loop(
rank: int,
model_path: str,
model_config: ModelConfig,
cache_config: CacheConfig,
adapters: Dict[str, str],
in_que: mp.Queue,
out_que: mp.Queue,
world_size: int,
trust_remote_code=True,
):
"""Start model loops for tensor parallel model inference.
Args:
rank (int): Distribution rank.
        model_path (str): Path of the hugging face model. Could be
            local or online.
model_config (ModelConfig): The config of the model.
cache_config (CacheConfig): The config of the cache.
in_que (mp.Queue): Input queue. Used to receive model input.
out_que (mp.Queue): Output queue. Used to send the model output.
world_size (int): The distribution world size.
"""
stream = torch.cuda.Stream()
patched_model, cache_engine = _tp_build_model(
rank,
model_path,
model_config,
cache_config,
adapters,
out_que=out_que,
world_size=world_size,
trust_remote_code=trust_remote_code)
if adapters:
_tp_paging_adapters(rank,
patched_model,
cache_engine=cache_engine,
in_que=in_que,
out_que=out_que)
while True:
inputs, swap_in_map, swap_out_map = _tp_get_input(
rank, in_que, world_size)
cache_swapping(cache_engine,
swap_in_map=swap_in_map,
swap_out_map=swap_out_map)
output = model_forward(
patched_model,
inputs,
cache_engine,
model_config.json_config,
world_size=world_size,
stream=stream,
)
stream.synchronize()
if rank == 0:
resp_output = output
out_que.put(TPResponse(0, None, resp_output))
def _start_tp_process(rank: int,
world_size: int,
func: Callable,
args: List = None,
kwargs: Dict = None,
port: int = 29500):
"""Start the tensor parallel process.
Args:
rank (int): The distribution rank.
world_size (int): The distribution world size.
func (Callable): The function to be called in the process.
args (List): The arguments of the func.
kwargs (Dict): The keyword arguments of the func.
"""
try:
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(port)
dist.init_process_group('nccl', rank=rank, world_size=world_size)
with torch.cuda.device(rank), torch.no_grad():
args = args or tuple()
kwargs = kwargs or dict()
func(rank, *args, **kwargs)
except Exception as e:
from traceback import print_exc
logger.error(f'Rank[{rank}] failed.')
print_exc()
raise e
def _check_context_alive(mp_context: mp.ProcessContext):
"""check context alive."""
procs = mp_context.processes
for idx, p in enumerate(procs):
if not p.is_alive():
raise RuntimeError(f'Rank[{idx}] failed.')
def _queue_get_response(que: mp.Queue,
mp_context: mp.ProcessContext,
interval: float = 1.0):
"""get response."""
from multiprocessing.queues import Empty
while True:
try:
return que.get(timeout=interval)
except Empty:
_check_context_alive(mp_context)
async def _async_queue_get_response(que: mp.Queue,
mp_context: mp.ProcessContext,
interval: float = 1.0):
"""get response."""
from multiprocessing.queues import Empty
def __try_que_get():
"""try que get."""
try:
return que.get(timeout=interval)
except Empty:
return None
while True:
ret = await asyncio.get_event_loop().run_in_executor(
None, __try_que_get)
if ret is not None:
return ret
_check_context_alive(mp_context)
class TPModelAgent(AutoModelAgent):
"""Tensor Parallelism model agent.
load model on multiple GPUs
Args:
model_path (str): The hugging face model path.
model_config (ModelConfig): The config of the model.
cache_config (CacheConfig): The config of the cache info.
trust_remote_code (bool): Trust remote code
"""
def __init__(self,
model_path: str,
model_config: ModelConfig,
cache_config: CacheConfig,
world_size: int,
adapters: Dict[str, str] = None,
trust_remote_code: bool = True) -> None:
self.mp_ctx = mp.get_context('spawn')
super().__init__(model_config=model_config, cache_config=cache_config)
self.world_size = world_size
self.tp_model_in_que = self.mp_ctx.Queue(10)
self.tp_model_out_que = self.mp_ctx.Queue(10)
self.patch_model_tp(model_path,
model_config=model_config,
cache_config=cache_config,
adapters=adapters,
in_que=self.tp_model_in_que,
out_que=self.tp_model_out_que,
world_size=world_size,
trust_remote_code=trust_remote_code)
def patch_model_tp(self, model_path: str, model_config: ModelConfig,
cache_config: CacheConfig, adapters: Dict[str, str],
in_que: mp.Queue, out_que: mp.Queue, world_size: int,
trust_remote_code: bool):
"""Start tensor parallel sub process.
Args:
            model_path (str): Path of the hugging face model.
                Could be local or online.
model_config (ModelConfig): The config of the model.
cache_config (CacheConfig): The config of the cache.
in_que (mp.Queue): Input queue. Used to receive model input.
out_que (mp.Queue): Output queue. Used to send the model output.
world_size (int): The distribution world size.
"""
        def __find_available_port() -> int:
"""find available port."""
import socket
port = 29500
while True:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
if s.connect_ex(('localhost', port)) != 0:
return port
port += 1
self.mp_context = mp.spawn(
_start_tp_process,
args=(
world_size,
_tp_model_loop,
(model_path, ),
dict(model_config=model_config,
cache_config=cache_config,
adapters=adapters,
in_que=in_que,
out_que=out_que,
world_size=world_size,
trust_remote_code=trust_remote_code),
__find_available_port(),
),
nprocs=world_size,
join=False,
daemon=True,
)
resp: TPResponse = _queue_get_response(out_que, self.mp_context)
if resp.ret_code != 0:
logger.error(f'Init tp model failed with error: {resp.error}')
raise next(err for err in resp.error if err is not None)
self.cache_config = resp.data
def paging_adapters(self, weight_maps: List[AdapterWeightMap]):
"""load adapter."""
if not weight_maps:
return
self.tp_model_in_que.put(weight_maps)
resp: TPResponse = self.tp_model_out_que.get()
if resp.ret_code != 0:
logger.error(f'paging adapters failed with error: {resp.error}')
raise next(err for err in resp.error if err is not None)
def forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
swap_out_map: SwapMap):
"""model forward.
Args:
inputs (Dict): The input data comes from _make_inputs.
swap_in_map (Dict[int, int]): Cache maps to swap in.
swap_out_map (Dict[int, int]): Cache maps to swap out.
"""
with torch.no_grad():
self.tp_model_in_que.put((inputs, swap_in_map, swap_out_map))
resp: TPResponse = _queue_get_response(self.tp_model_out_que,
self.mp_context)
if resp.ret_code != 0:
raise RuntimeError('tp forward failed.')
return resp.data
async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
swap_out_map: SwapMap):
"""model forward.
Args:
inputs (Dict): The input data comes from _make_inputs.
swap_in_map (Dict[int, int]): Cache maps to swap in.
swap_out_map (Dict[int, int]): Cache maps to swap out.
"""
with torch.no_grad():
self.tp_model_in_que.put((inputs, swap_in_map, swap_out_map))
resp: TPResponse = await _async_queue_get_response(
self.tp_model_out_que, self.mp_context)
if resp.ret_code != 0:
raise RuntimeError('tp forward failed.')
return resp.data
def build_model_agent(model_path: str,
cache_config: CacheConfig,
trust_remote_code: bool,
adapters: Dict[str, str] = None,
tp: int = 1):
"""create model agent."""
model_config = ModelConfig.from_pretrained(
model_path, trust_remote_code=trust_remote_code)
if tp == 1:
model_agent = BaseModelAgent(model_path,
model_config=model_config,
cache_config=cache_config,
adapters=adapters,
trust_remote_code=trust_remote_code)
else:
model_agent = TPModelAgent(model_path,
model_config=model_config,
cache_config=cache_config,
world_size=tp,
adapters=adapters,
trust_remote_code=trust_remote_code)
return model_agent
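# Illustrative usage sketch (not part of the original module). The real caller
# is the engine; here we only show the call shape. `cache_config` is assumed
# to be a fully populated CacheConfig and `inputs` a ModelInputs produced by
# the engine's input builder.
def _demo_build_model_agent(model_path: str,
                            cache_config: CacheConfig,
                            tp: int = 1):
    """Build an agent and return a helper that runs one forward step."""
    agent = build_model_agent(model_path,
                              cache_config=cache_config,
                              trust_remote_code=True,
                              tp=tp)

    def _run_step(inputs: ModelInputs):
        # Empty swap maps: no cache blocks are moved between CPU and GPU.
        return agent.forward(inputs, swap_in_map=dict(), swap_out_map=dict())

    return agent, _run_step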
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import enum
from dataclasses import dataclass, field
from queue import Empty, Queue
from threading import Lock, Thread
from typing import Any, Awaitable, Callable, Dict, List
from lmdeploy.messages import ResponseType
from lmdeploy.utils import get_logger
logger = get_logger('lmdeploy')
def _raise_exception_on_finish(task: asyncio.Task) -> None:
try:
task.result()
except asyncio.CancelledError:
return
except Exception as e:
logger.exception(f'Engine loop failed with error: {e}')
def _ignore_exception_on_finish(task: asyncio.Task) -> None:
try:
task.result()
except asyncio.CancelledError:
return
except Exception as exc:
logger.info(f'task: {task.get_name()} ended.')
logger.debug(f'task: {task.get_name()} exception: {exc}')
class RequestType(enum.Enum):
"""Request type."""
ADD_SESSION = enum.auto()
ADD_MESSAGE = enum.auto()
STOP_SESSION = enum.auto()
END_SESSION = enum.auto()
STOP_ENGINE = enum.auto()
RESUME_ENGINE = enum.auto()
@dataclass
class Request:
"""Request."""
type: RequestType
sender_id: int
req_id: int
data: Any = None
@dataclass
class Response:
"""Response."""
type: ResponseType
sender_id: int
req_id: int
data: Any = None
err_msg: str = ''
ReqList = List[Request]
def _run_until_complete(future: Awaitable):
"""run untile complete."""
try:
event_loop = asyncio.get_event_loop()
except Exception:
        logger.warning('Cannot find an event loop in the current thread.'
                       ' Creating a new event loop.')
event_loop = asyncio.new_event_loop()
asyncio.set_event_loop(event_loop)
return event_loop.run_until_complete(future)
@dataclass
class RequestSender:
"""Request sender.
Args:
sender_id (int): The id of the sender
"""
sender_id: int
manager: 'RequestManager'
resp_dict: Dict[int, List[Response]] = field(default_factory=dict)
_next_req_id: int = 0
_resp_que: asyncio.Queue = None
_resp_thread_que: Queue = None
@classmethod
def new(cls, sender_id: int, manager: 'RequestManager'):
"""new."""
return cls(sender_id=sender_id, manager=manager)
@property
def resp_que(self):
"""response queue."""
if self.is_thread_safe():
return self.manager.responses
if self.manager._loop_task is None and not self.is_thread_safe():
self.manager.create_loop_task()
if self._resp_que is None:
self._resp_que = asyncio.Queue()
return self._resp_que
@property
def req_que(self):
"""request queue."""
return self.manager.requests
@property
def resp_thread_que(self):
"""response threadsafe queue."""
if self._resp_thread_que is None:
self._resp_thread_que = Queue()
return self._resp_thread_que
@property
def req_thread_que(self):
"""request threadsafe queue."""
return self.manager.thread_requests
@property
def event_loop(self):
"""get event loop."""
return self.manager.event_loop
def is_thread_safe(self):
"""is thread safe."""
return self.manager.is_thread_safe()
def is_loop_alive(self):
"""is loop alive."""
return self.manager.is_loop_alive()
def run_until_complete(self, future: Awaitable):
"""run untile complete."""
return self.manager.run_until_complete(future)
def _resp_get(self):
"""resp_que.get."""
timeout = 1
while True:
if not self.manager.is_loop_alive():
logger.debug('Engine loop is not alive.')
exit(1)
try:
ret = self.resp_thread_que.get(timeout=timeout)
return ret
except Empty:
continue
except Exception as e:
logger.exception(
f'sender[{self.sender_id}] get response failed: {e}')
raise e
async def _async_resp_get(self):
"""get resp.
Different behavior in threadsafe mode.
"""
timeout = 1
async def __no_threadsafe_get():
while True:
if not self.manager.is_loop_alive():
logger.debug('Engine loop is not alive.')
exit(1)
try:
return await asyncio.wait_for(self.resp_que.get(), timeout)
except asyncio.TimeoutError:
continue
except Exception as e:
logger.exception(
f'sender[{self.sender_id}] get response failed: {e}')
raise e
if self.is_thread_safe():
ret = self._resp_get()
await asyncio.sleep(0)
return ret
else:
return await __no_threadsafe_get()
def _req_put(self, reqs: Any):
"""req put."""
self.req_thread_que.put(reqs)
async def _async_req_put(self, reqs: Any):
"""async rq_que put.
Different behavior in threadsafe mode.
"""
if self.is_thread_safe():
self._req_put(reqs)
await asyncio.sleep(0)
else:
await self.req_que.put(reqs)
def _prefetch_resps(self):
"""prefetch from resp que.
Different behavior in threadsafe mode.
"""
if self.is_thread_safe():
resp_que = self.resp_thread_que
else:
resp_que = self.resp_que
num_resps = resp_que.qsize()
for _ in range(num_resps):
resp: Response = resp_que.get_nowait()
req_id = resp.req_id
self._push_resp(req_id, resp)
def _push_resp(self, req_id: int, resp: Response):
"""push response."""
self.resp_dict.setdefault(req_id, [])
self.resp_dict[req_id].append(resp)
def _pop_resp(self, req_id: int, default: Any = None):
"""pop response."""
if req_id not in self.resp_dict:
return default
resps = self.resp_dict[req_id]
ret = resps.pop(0)
if len(resps) == 0:
self.resp_dict.pop(req_id)
return ret
def _gather_request(self, req_types: List[RequestType], data: List[Any]):
"""gather requests."""
if self.manager._loop_task is None and not self.is_thread_safe():
self.manager.create_loop_task()
if not self.is_loop_alive():
logger.error('Engine main loop stopped.')
exit(1)
assert len(req_types) == len(data)
batch_size = len(req_types)
req_ids = list(range(self._next_req_id,
self._next_req_id + batch_size))
self._next_req_id += batch_size
reqs = [
Request(type=rtype,
sender_id=self.sender_id,
req_id=req_id,
data=rdata)
for req_id, rtype, rdata in zip(req_ids, req_types, data)
]
return req_ids, reqs
async def async_batched_send_async(self, req_types: List[RequestType],
data: List[Any]):
"""Batched send request asynchronize."""
req_ids, reqs = self._gather_request(req_types, data)
await self._async_req_put(reqs)
return req_ids
async def async_send_async(self, req_type: RequestType, data: Any):
"""send request asynchronize."""
return (await self.async_batched_send_async(req_types=[req_type],
data=[data]))[0]
def batched_send_async(self, req_types: List[RequestType],
data: List[Any]) -> List[int]:
"""Batched send request asynchronize.
Different behavior in threadsafe mode.
"""
if not self.is_thread_safe():
coro = self.async_batched_send_async(req_types, data)
return self.run_until_complete(coro)
req_ids, reqs = self._gather_request(req_types, data)
self._req_put(reqs)
return req_ids
def send_async(self, req_type: RequestType, data: Any) -> int:
"""send request asynchronize."""
return self.batched_send_async(req_types=[req_type], data=[data])[0]
async def async_recv_any(self, que_timeout: float = None) -> Response:
"""receive any response."""
self._prefetch_resps()
for req_id in self.resp_dict:
ret = self._pop_resp(req_id, default=None)
if ret is not None:
return ret
return await self._async_resp_get()
def recv_any(self, que_timeout: float = None) -> Response:
"""receive any response."""
coro = self.async_recv_any(que_timeout)
return self.run_until_complete(coro)
def recv_all(self, req_id: int, block: bool = True):
"""revceive all response with req_id."""
self._prefetch_resps()
resps = self.resp_dict.pop(req_id, [])
return resps
async def async_recv(self,
req_id: int,
que_timeout: float = None) -> Response:
"""receive response of given request id async."""
ret = self._pop_resp(req_id, default=None)
if ret is not None:
return ret
# check resp que
while True:
resp: Response = await self._async_resp_get()
if resp.req_id != req_id:
self._push_resp(req_id, resp)
else:
return resp
def recv(self, req_id: int, que_timeout: float = None) -> Response:
"""receive response of given request id.
Different behavior in threadsafe mode.
"""
if not self.is_thread_safe():
coro = self.async_recv(req_id, que_timeout)
return self.run_until_complete(coro)
ret = self._pop_resp(req_id, default=None)
if ret is not None:
return ret
# check resp que
while True:
resp: Response = self._resp_get()
if resp.req_id != req_id:
self._push_resp(req_id, resp)
else:
return resp
async def async_send(self,
req_type: RequestType,
data: Any,
que_timeout: float = None):
"""send and receive synchronize."""
req_id = await self.async_send_async(req_type, data)
return await self.async_recv(req_id, que_timeout=que_timeout)
def send(self,
req_type: RequestType,
data: Any,
que_timeout: float = None) -> Response:
"""send and receive synchronize."""
req_id = self.send_async(req_type, data)
return self.recv(req_id, que_timeout=que_timeout)
def response_callback(self, resp: Response):
"""response callback."""
self.resp_que.put_nowait(resp)
class RequestManager:
"""Request manager."""
def __init__(self, thread_safe: bool = False):
self.senders: Dict[int, RequestSender] = dict()
self.callbacks: Dict[RequestType, Callable] = dict()
self.request_priority: List[RequestType] = [
RequestType.STOP_ENGINE, RequestType.STOP_SESSION,
RequestType.END_SESSION, RequestType.ADD_SESSION,
RequestType.ADD_MESSAGE
]
self.requests: asyncio.Queue = None
self._loop_task: asyncio.Future = None
self._loop_coro: Callable = None
self._thread_safe = thread_safe
self._next_sender_id = 0
self._mutex = Lock()
self._loop_thread: Thread = None
self.thread_requests: Queue = None
        # every sender has its own responses; these responses are
        # only used in thread-safe mode.
self.responses: asyncio.Queue = None
if thread_safe:
self.thread_requests = Queue()
def create_loop_task(self):
"""create coro task."""
logger.debug('creating engine loop task.')
event_loop = asyncio.get_event_loop()
assert self._loop_coro is not None, (
'Please set loop task with manager.start_loop')
loop_unshielded = event_loop.create_task(self._loop_coro(),
name='EngineMainLoop')
loop_unshielded.add_done_callback(_raise_exception_on_finish)
self._loop_task = asyncio.shield(loop_unshielded)
self.requests = asyncio.Queue()
return self._loop_task
@property
def event_loop(self):
"""get event loop."""
if self._loop_task is None:
return None
else:
return self._loop_task.get_loop()
def is_thread_safe(self):
"""is thread safe."""
return self._thread_safe
def start_loop(self, loop: asyncio.Task):
"""start main loop."""
self._loop_coro = loop
def __get_thread_reqs():
"""get thread reqs."""
num_reqs = self.thread_requests.qsize()
reqs = []
for _ in range(num_reqs):
tmp_reqs = self.thread_requests.get_nowait()
if isinstance(tmp_reqs, Request):
tmp_reqs = [tmp_reqs]
reqs += tmp_reqs
return reqs
async def __req_loop():
"""req loop."""
while True:
# get reqs
reqs = __get_thread_reqs()
if len(reqs) > 0:
await self.requests.put(reqs)
else:
await asyncio.sleep(0.02)
def __put_thread_resps(resps: List[Response]):
"""put thread resps."""
for resp in resps:
sender = self.senders.get(resp.sender_id, None)
if sender is None:
continue
sender.resp_thread_que.put_nowait(resp)
async def __resp_loop():
"""resp loop."""
while True:
num_resps = self.responses.qsize()
resps = []
for _ in range(num_resps):
resps.append(self.responses.get_nowait())
if len(resps) > 0:
__put_thread_resps(resps)
else:
await asyncio.sleep(0.02)
def __run_forever(event_loop: asyncio.BaseEventLoop):
"""run forever."""
logger.debug('start thread run forever.')
asyncio.set_event_loop(event_loop)
self.create_loop_task()
req_loop = event_loop.create_task(__req_loop(),
name='RunForeverReqLoop')
req_loop.add_done_callback(_ignore_exception_on_finish)
resp_loop = event_loop.create_task(__resp_loop(),
name='RunForeverRespLoop')
resp_loop.add_done_callback(_ignore_exception_on_finish)
self.event_loop.run_forever()
if self.is_thread_safe():
event_loop = asyncio.new_event_loop()
self.responses = asyncio.Queue()
self._loop_thread = Thread(target=__run_forever,
args=(event_loop, ),
daemon=True)
self._loop_thread.start()
def is_loop_alive(self):
"""check if main loop is alive."""
def __check_threadsafe():
if self._loop_thread is None:
return False
if not self._loop_thread.is_alive():
return False
if self._loop_task is None:
return False
return not self._loop_task.done()
if self.is_thread_safe():
return __check_threadsafe()
if self._loop_task is None:
logger.debug('loop task has not been created.')
return False
if self._loop_task.get_loop() != asyncio.get_event_loop():
logger.warning('Current event loop is different from'
' the one bound to loop task!')
return False
return not self._loop_task.done()
def build_sender(self):
"""create a new sender."""
with self._mutex:
sender_id = self._next_sender_id
self._next_sender_id += 1
new_sender = RequestSender.new(sender_id, self)
self.senders[sender_id] = new_sender
return new_sender
def has_requests(self):
"""has unprocessed request."""
if self.requests is None:
return False
return not self.requests.empty()
def get_all_requests(self) -> Dict[RequestType, Request]:
"""get all requests in current queue."""
num_reqs = self.requests.qsize()
reqs: ReqList = []
for _ in range(num_reqs):
elem = self.requests.get_nowait()
if isinstance(elem, Request):
elem = [elem]
reqs += elem
# gather requests
reqs_by_type: Dict[RequestType, Request] = dict(
(t, []) for t in RequestType)
for req in reqs:
reqs_by_type[req.type].append(req)
return reqs_by_type
def bind_func(self, req_type: RequestType, callback: Callable):
"""bind handler for given request type."""
self.callbacks[req_type] = callback
def set_request_priority(self, priority: List[RequestType]):
"""set the priority of request type."""
self.request_priority = priority
def response(self, resp: Response):
"""send response."""
if resp.sender_id not in self.senders:
logger.warning(f'sender {resp.sender_id} not exist. '
f'Send {resp} failed.')
return
self.senders[resp.sender_id].response_callback(resp)
def process_request(self, req_type: RequestType, reqs: ReqList, **kwargs):
"""process reqs with given req type."""
# get callback
func = self.callbacks.get(req_type, None)
if func is not None:
func(reqs, **kwargs)
else:
# TODO: send error message
for req in reqs:
resp = Response(ResponseType.HANDLER_NOT_EXIST,
sender_id=req.sender_id,
req_id=req.req_id,
                                err_msg=(f'callback for {req_type}'
                                         ' does not exist.'))
self.response(resp)
def step(self, **kwargs):
"""handle requests.
Should only be called in loop task.
"""
reqs_by_type = self.get_all_requests()
# handle requests
for req_type in self.request_priority:
            # skip when no request of this type exists
            if len(reqs_by_type.get(req_type, [])) == 0:
continue
reqs: ReqList = reqs_by_type[req_type]
self.process_request(req_type, reqs, **kwargs)
def run_until_complete(self, future: Awaitable):
"""run untile complete."""
return _run_until_complete(future)
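# Illustrative sketch (not part of the original module): minimal wiring of
# RequestManager/RequestSender in thread-safe mode. The handler and main loop
# below are hypothetical stand-ins for the engine's real callbacks.
def _demo_request_manager():
    """Send one ADD_SESSION request and wait for its response (sketch)."""
    import time

    manager = RequestManager(thread_safe=True)

    def _on_add_session(reqs: ReqList, **kwargs):
        # Echo a SUCCESS response back to each sender.
        for req in reqs:
            manager.response(
                Response(ResponseType.SUCCESS,
                         sender_id=req.sender_id,
                         req_id=req.req_id))

    manager.bind_func(RequestType.ADD_SESSION, _on_add_session)

    async def _main_loop():
        # The engine's main loop repeatedly drains and dispatches requests.
        while True:
            manager.step()
            await asyncio.sleep(0.01)

    manager.start_loop(_main_loop)
    time.sleep(0.1)  # demo only: give the background loop thread time to start
    sender = manager.build_sender()
    resp = sender.send(RequestType.ADD_SESSION, dict(session_id=0))
    return resp.type == ResponseType.SUCCESS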
# Copyright (c) OpenMMLab. All rights reserved.
from .alibi_pagedattention import alibi_paged_attention_fwd
from .apply_rotary_pos_emb import apply_rotary_pos_emb
from .fill_kv_cache import fill_kv_cache
from .fused_rotary_emb import fused_rotary_emb
from .multinomial_sampling import multinomial_sampling
from .pagedattention import paged_attention_fwd
from .rerope_attention import rerope_attention_fwd
from .rms_norm import rms_norm
__all__ = [
'apply_rotary_pos_emb', 'fused_rotary_emb', 'paged_attention_fwd',
'alibi_paged_attention_fwd', 'fill_kv_cache', 'multinomial_sampling',
'rms_norm', 'rerope_attention_fwd'
]
# Copyright (c) OpenMMLab. All rights reserved.
# modify from: https://github.com/ModelTC/lightllm
import math
import torch
import triton
import triton.language as tl
from torch import Tensor
from triton.runtime.jit import get_cuda_stream
assert triton.__version__ >= '2.1.0'
LOG2 = math.log(2)
@triton.jit
def tl_pow(a, b):
"""triton pow."""
return tl.exp(b * tl.log(a))
@triton.jit
def tl_2pow(b):
"""triton pow2."""
return tl.exp(b * LOG2)
@triton.jit
def tl_log2(a):
"""triton log2."""
return tl.log(a) / LOG2
@triton.jit
def _get_interleave_power_of_2(i, n):
"""get interleave power of 2."""
start = -tl_2pow(3 - tl_log2(n))
start = tl_2pow(start)
ratio = start
return start * tl_pow(ratio, i)
@triton.jit
def get_slope(i, n):
"""get slope."""
closest_power_of_2 = tl_2pow(tl_log2(n).to(tl.int32))
if i < closest_power_of_2:
return _get_interleave_power_of_2(i, closest_power_of_2)
else:
return _get_interleave_power_of_2((i - closest_power_of_2) * 2,
2 * closest_power_of_2)
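# Reference sketch (not used by the kernels): the same ALiBi slope rule as
# `get_slope` above, written in plain Python to document the math. Assumes
# num_heads > 0.
def _alibi_slopes_ref(num_heads: int):
    """Compute the per-head ALiBi slopes on the host (reference only)."""

    def _pow2_slopes(n):
        start = 2**(-2**-(math.log2(n) - 3))
        return [start * start**i for i in range(n)]

    closest = 2**int(math.log2(num_heads))
    slopes = _pow2_slopes(closest)
    if closest < num_heads:
        extra = _pow2_slopes(2 * closest)
        slopes += extra[0::2][:num_heads - closest]
    return slopes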
@triton.jit
def _load_block_offsets(offset_ptr, block_id, num_sub_blocks: tl.constexpr,
BLOCK: tl.constexpr):
if num_sub_blocks > 1:
offs_sub = tl.arange(0, num_sub_blocks)
offs_n = tl.arange(0, BLOCK // num_sub_blocks)
ret = tl.load(offset_ptr + block_id * num_sub_blocks + offs_sub)[
None, :] * BLOCK // num_sub_blocks + offs_n[:, None]
return tl.ravel(ret)
else:
offs_n = tl.arange(0, BLOCK)
return tl.load(offset_ptr + block_id) * BLOCK + offs_n
@triton.jit
def _fwd_split_kernel(
Q,
K,
V,
sm_scale,
alibi_scale,
B_kvlen,
Block_offsets,
Acc_out,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_ok,
stride_obs,
stride_oh,
stride_od,
stride_boffb,
head_offset,
num_heads,
kv_group_num,
block_per_cta,
num_sub_blocks: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
"""first step kernel of split k attention."""
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
split_k_id = tl.program_id(2)
cur_kv_head = cur_head // kv_group_num
cur_batch_seq_len = 1
cur_batch_kv_len = tl.load(B_kvlen + cur_batch)
history_len = cur_batch_kv_len - cur_batch_seq_len
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
off_q = (cur_batch * stride_qbs + cur_head * stride_qh +
offs_d * stride_qd)
off_k = (cur_kv_head * stride_kh + offs_d[None, :] * stride_kd)
off_v = (cur_kv_head * stride_vh + offs_d[None, :] * stride_vd)
q = tl.load(Q + off_q).to(tl.float32)
k_ptrs = K + off_k
v_ptrs = V + off_v
block_offset_ptrs = Block_offsets + cur_batch * stride_boffb
head_slope = get_slope(
cur_head.to(tl.float32) + head_offset, num_heads.to(tl.float32))
# initialize pointer to m and l
m_i = -float('inf')
l_i = float(0)
acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)
kv_len_per_prog = block_per_cta * BLOCK_N
loop_start = kv_len_per_prog * split_k_id
loop_end = tl.minimum(loop_start + kv_len_per_prog, cur_batch_kv_len)
# load block offset
start_block_id = loop_start // BLOCK_N
b_offset = _load_block_offsets(block_offset_ptrs, start_block_id,
num_sub_blocks, BLOCK_N)
for start_n in range(loop_start, loop_end, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
mask = (start_n + offs_n[:, None]) < cur_batch_kv_len
# -- compute qk ----
k = tl.load(
k_ptrs + b_offset[:, None] * stride_kbs,
mask=mask,
other=0.0,
)
v = tl.load(
v_ptrs + b_offset[:, None] * stride_vbs,
mask=mask,
other=0.0,
)
# prefetch b_offset
if start_n + BLOCK_N < loop_end:
start_block_id += 1
b_offset = _load_block_offsets(block_offset_ptrs, start_block_id,
num_sub_blocks, BLOCK_N)
qk = tl.sum(q[None, :] * k, 1)
qk *= sm_scale
mask = start_n + offs_n
bias = mask.to(tl.float32) * (head_slope * alibi_scale)
qk += bias
        # NOTE: inf - inf = nan, and nan will lead to errors
qk = tl.where(
history_len >= (start_n + offs_n),
qk,
-float('inf'),
)
# -- compute p, m_i and l_i
m_i_new = tl.maximum(m_i, tl.max(qk, 0))
p = tl.exp(qk - m_i_new)
alpha = tl.exp(m_i - m_i_new)
l_i_new = alpha * l_i + tl.sum(p, 0)
# -- update output accumulator --
# scale acc
acc = acc * alpha
# update acc
p_new = p.to(v.dtype)
acc += tl.sum(p_new[:, None] * v, 0)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# initialize pointers to output
off_acc = (cur_batch * stride_obs + split_k_id * stride_ok +
cur_head * stride_oh + offs_d * stride_od)
tl.store(Acc_out + off_acc, acc)
off_meta = (cur_batch * stride_obs + split_k_id * stride_ok +
cur_head * stride_oh + BLOCK_DMODEL)
tl.store(Acc_out + off_meta + tl.arange(0, 1), m_i)
tl.store(Acc_out + off_meta + 1 + tl.arange(0, 1), l_i)
@triton.jit
def _reduce_split_kernel(
Acc,
Out,
stride_ak,
stride_abs,
stride_ah,
stride_ad,
stride_obs,
stride_oh,
stride_od,
SPLIT_K: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
):
"""second step kernel of split k attention."""
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
# initialize offsets
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_k = tl.arange(0, SPLIT_K)
offs_acc = (cur_batch * stride_abs + cur_head * stride_ah +
offs_k[:, None] * stride_ak + offs_d[None, :] * stride_ad)
offs_mi = (cur_batch * stride_abs + cur_head * stride_ah +
stride_ak * offs_k + BLOCK_DMODEL)
acc_k = tl.load(Acc + offs_acc)
m_k = tl.load(Acc + offs_mi)
l_k = tl.load(Acc + offs_mi + 1)
m_max = tl.max(m_k, 0)
alpha = tl.exp(m_k - m_max)
acc_k = acc_k * alpha[:, None]
l_k = l_k * alpha
acc = tl.sum(acc_k, 0)
l_sum = tl.sum(l_k, 0)
acc = acc / l_sum
out_offs = (cur_batch * stride_obs + cur_head * stride_oh +
offs_d * stride_od)
tl.store(Out + out_offs, acc)
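# Reference sketch (not used by the kernel): the split-K merge implemented by
# `_reduce_split_kernel`, written in plain PyTorch. Each split carries an
# unnormalized accumulator plus its running max/sum statistics, and the splits
# are combined with the usual log-sum-exp rescaling.
def _reduce_split_ref(acc_k, m_k, l_k):
    """Merge split-K partial attention results (host-side reference).

    acc_k: (SPLIT_K, head_dim) unnormalized weighted value sums.
    m_k:   (SPLIT_K,) running maxima of the attention logits.
    l_k:   (SPLIT_K,) running sums of exp(logit - m_k).
    """
    m_max = m_k.max()
    alpha = torch.exp(m_k - m_max)
    acc = (acc_k * alpha[:, None]).sum(0)
    l_sum = (l_k * alpha).sum()
    return acc / l_sum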
@triton.jit
def _fwd_kernel(
Q,
K,
V,
sm_scale,
alibi_scale,
B_Start_Loc,
B_Seqlen,
B_kvlen,
Block_offsets,
Out,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_boffb,
head_offset,
num_heads,
kv_group_num,
num_sub_blocks: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
"""forward kernel."""
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_kv_head = cur_head // kv_group_num
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_kv_len = tl.load(B_kvlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
history_len = cur_batch_kv_len - cur_batch_seq_len
block_start_loc = BLOCK_M * start_m
head_slope = get_slope(
cur_head.to(tl.float32) + head_offset, num_heads.to(tl.float32))
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
off_k = (cur_kv_head * stride_kh + offs_d[:, None] * stride_kd)
off_v = (cur_kv_head * stride_vh + offs_d[None, :] * stride_vd)
q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
k_ptrs = K + off_k
v_ptrs = V + off_v
block_offset_ptrs = Block_offsets + cur_batch * stride_boffb
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)
b_offset = _load_block_offsets(block_offset_ptrs, 0, num_sub_blocks,
BLOCK_N)
for start_n in range(0, block_mask * cur_batch_kv_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(
k_ptrs + b_offset[None, :] * stride_kbs,
mask=(start_n + offs_n[None, :]) < cur_batch_kv_len,
other=0.0,
)
v = tl.load(
v_ptrs + b_offset[:, None] * stride_vbs,
mask=(start_n + offs_n[:, None]) < cur_batch_kv_len,
other=0.0,
)
if start_n + BLOCK_N < cur_batch_kv_len:
start_block_id = start_n // BLOCK_N + 1
b_offset = _load_block_offsets(block_offset_ptrs, start_block_id,
num_sub_blocks, BLOCK_N)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
mask = start_n + offs_n[None, :]
bias = mask.to(tl.float32) * (head_slope * alibi_scale)
qk += bias
        # NOTE: inf - inf = nan, and nan will lead to errors
qk = tl.where(
(history_len + offs_m[:, None]) >= mask,
qk,
float(-1e30),
)
# -- compute p, m_i and l_i
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
p = tl.exp(qk - m_i_new[:, None])
alpha = tl.exp(m_i - m_i_new)
l_i_new = alpha * l_i + tl.sum(p, 1)
# -- update output accumulator --
# scale acc
acc = acc * alpha[:, None]
# update acc
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
acc = acc / l_i[:, None]
# initialize pointers to output
off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
@torch.no_grad()
def alibi_paged_attention_fwd(q: Tensor,
k: Tensor,
v: Tensor,
o: Tensor,
block_offsets: Tensor,
b_start_loc: Tensor,
b_seq_len: Tensor,
b_kv_seq_len: Tensor,
max_input_len: int,
head_offset: int = 0,
num_heads: int = -1,
alibi_scale: float = 1.0):
"""Paged attention forward with alibi bias.
Args:
q (Tensor): Query state.
k (Tensor): Key state caches.
v (Tensor): Value state caches.
o (Tensor): Output state.
block_offsets (Tensor): The block offset of key and value.
b_start_loc (Tensor): Start token location of each data in batch.
b_seq_len (Tensor): Query length for each data in batch.
b_kv_seq_len (Tensor): Key/Value length for each data in batch.
max_input_len (int): The max input length.
        head_offset (int): The offset of the start head. Heads might be
            partitioned when using tensor parallel inference.
        num_heads (int): The number of heads. Heads might be partitioned when
            using tensor parallel inference.
        alibi_scale (float): The scale of the alibi bias.
"""
def _kernel_meta():
device = q.device
device_idx = device.index
device_type = device.type
stream = get_cuda_stream(device_idx)
return dict(device=device, device_type=device_type, stream=stream)
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
    sm_scale = 1.0 / (Lq**0.5)  # compute the softmax scale factor
batch, head = b_seq_len.shape[0], q.shape[-2]
kv_group_num = q.shape[-2] // k[0].shape[-2]
if num_heads <= 0:
num_heads = head
BLOCK = 64 if k.size(1) < 16 else k.size(1)
num_sub_blocks = BLOCK // k.size(1)
    grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # (batch, head, seq_blocks)
num_warps = 4 if Lk <= 64 else 8
kernel_meta = _kernel_meta()
is_decoding = q.shape[-3] == b_seq_len.size(0)
if not is_decoding:
_fwd_kernel[grid](q,
k,
v,
sm_scale,
alibi_scale,
b_start_loc,
b_seq_len,
b_kv_seq_len,
block_offsets,
o,
q.stride(-3),
q.stride(-2),
q.stride(-1),
k.stride(-3),
k.stride(-2),
k.stride(-1),
v.stride(-3),
v.stride(-2),
v.stride(-1),
o.stride(-3),
o.stride(-2),
o.stride(-1),
block_offsets.stride(0),
head_offset=head_offset,
num_heads=num_heads,
kv_group_num=kv_group_num,
num_sub_blocks=num_sub_blocks,
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
**kernel_meta)
else:
SPLIT_K = 4
grid = (batch, head, SPLIT_K)
block_per_cta = triton.cdiv(block_offsets.size(-1), SPLIT_K)
acc = q.new_empty(batch, head, SPLIT_K, Lq + 2, dtype=torch.float32)
_fwd_split_kernel[grid](q,
k,
v,
sm_scale,
alibi_scale,
b_kv_seq_len,
block_offsets,
acc,
stride_qbs=q.stride(-3),
stride_qh=q.stride(-2),
stride_qd=q.stride(-1),
stride_kbs=k.stride(-3),
stride_kh=k.stride(-2),
stride_kd=k.stride(-1),
stride_vbs=v.stride(-3),
stride_vh=v.stride(-2),
stride_vd=v.stride(-1),
stride_ok=acc.stride(-2),
stride_obs=acc.stride(-4),
stride_oh=acc.stride(-3),
stride_od=acc.stride(-1),
stride_boffb=block_offsets.stride(0),
head_offset=head_offset,
num_heads=num_heads,
kv_group_num=kv_group_num,
block_per_cta=block_per_cta,
num_sub_blocks=num_sub_blocks,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=4,
num_stages=1,
**kernel_meta)
grid = (batch, head)
_reduce_split_kernel[grid](acc,
o,
stride_ak=acc.stride(-2),
stride_abs=acc.stride(-4),
stride_ah=acc.stride(-3),
stride_ad=acc.stride(-1),
stride_obs=o.stride(-3),
stride_oh=o.stride(-2),
stride_od=o.stride(-1),
SPLIT_K=SPLIT_K,
BLOCK_DMODEL=Lk,
num_warps=num_warps,
num_stages=1,
**kernel_meta)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl
from torch import Tensor
from triton.runtime.jit import get_cuda_stream
@triton.jit
def apply_rotary_pos_emb_kernel(
Q,
COS,
SIN,
POS,
Q_EMB,
seq_len,
stride_qh: tl.constexpr,
BLOCK: tl.constexpr,
BLOCK_N: tl.constexpr,
):
"""apply rotary on key OR query kernel."""
seq_block_id = tl.program_id(0)
head_id = tl.program_id(1)
pos_offset = seq_block_id * BLOCK + tl.arange(0, BLOCK)
pos_ids = tl.load(POS + pos_offset, pos_offset < seq_len, other=-1)
feat_size = BLOCK_N * 2
feat_offset_l = tl.arange(0, BLOCK_N)
feat_offset_h = BLOCK_N + feat_offset_l
cs_offset_l = pos_ids[:, None] * feat_size + feat_offset_l[None, :]
cs_offset_h = pos_ids[:, None] * feat_size + feat_offset_h[None, :]
pos_ids_mask = pos_ids[:, None] >= 0
cos_l = tl.load(COS + cs_offset_l, mask=pos_ids_mask)
cos_h = tl.load(COS + cs_offset_h, mask=pos_ids_mask)
sin_l = tl.load(SIN + cs_offset_l, mask=pos_ids_mask)
sin_h = tl.load(SIN + cs_offset_h, mask=pos_ids_mask)
q_offset_seq = pos_offset[:, None] * stride_qh + head_id * feat_size
q_offset_l = q_offset_seq + feat_offset_l[None, :]
q_offset_h = q_offset_seq + feat_offset_h[None, :]
pos_mask = pos_offset[:, None] < seq_len
q_l = tl.load(Q + q_offset_l, mask=pos_mask)
q_h = tl.load(Q + q_offset_h, mask=pos_mask)
q_emb_l = q_l * cos_l - q_h * sin_l
q_emb_h = q_h * cos_h + q_l * sin_h
tl.store(Q_EMB + q_offset_l, q_emb_l, mask=pos_mask)
tl.store(Q_EMB + q_offset_h, q_emb_h, mask=pos_mask)
@triton.jit
def apply_rotary_pos_emb_qk_kernel(
Q,
K,
COS,
SIN,
POS,
Q_EMB,
K_EMB,
seq_len,
stride_qh: tl.constexpr,
stride_kh: tl.constexpr,
BLOCK: tl.constexpr,
BLOCK_N: tl.constexpr,
):
"""apply rotary on key AND query kernel."""
seq_block_id = tl.program_id(0)
head_id = tl.program_id(1)
pos_offset = seq_block_id * BLOCK + tl.arange(0, BLOCK)
pos_ids = tl.load(POS + pos_offset, pos_offset < seq_len, other=-1)
feat_size = BLOCK_N * 2
feat_offset_l = tl.arange(0, BLOCK_N)
feat_offset_h = BLOCK_N + feat_offset_l
cs_offset_l = pos_ids[:, None] * feat_size + feat_offset_l[None, :]
cs_offset_h = pos_ids[:, None] * feat_size + feat_offset_h[None, :]
pos_ids_mask = pos_ids[:, None] >= 0
cos_l = tl.load(COS + cs_offset_l, mask=pos_ids_mask)
cos_h = tl.load(COS + cs_offset_h, mask=pos_ids_mask)
sin_l = tl.load(SIN + cs_offset_l, mask=pos_ids_mask)
sin_h = tl.load(SIN + cs_offset_h, mask=pos_ids_mask)
q_offset_seq = pos_offset[:, None] * stride_qh + head_id * feat_size
q_offset_l = q_offset_seq + feat_offset_l[None, :]
q_offset_h = q_offset_seq + feat_offset_h[None, :]
k_offset_seq = pos_offset[:, None] * stride_kh + head_id * feat_size
k_offset_l = k_offset_seq + feat_offset_l[None, :]
k_offset_h = k_offset_seq + feat_offset_h[None, :]
pos_mask = pos_offset[:, None] < seq_len
q_l = tl.load(Q + q_offset_l, mask=pos_mask)
q_h = tl.load(Q + q_offset_h, mask=pos_mask)
k_l = tl.load(K + k_offset_l, mask=pos_mask)
k_h = tl.load(K + k_offset_h, mask=pos_mask)
q_emb_l = q_l * cos_l - q_h * sin_l
q_emb_h = q_h * cos_h + q_l * sin_h
k_emb_l = k_l * cos_l - k_h * sin_l
k_emb_h = k_h * cos_h + k_l * sin_h
tl.store(Q_EMB + q_offset_l, q_emb_l, mask=pos_mask)
tl.store(Q_EMB + q_offset_h, q_emb_h, mask=pos_mask)
tl.store(K_EMB + k_offset_l, k_emb_l, mask=pos_mask)
tl.store(K_EMB + k_offset_h, k_emb_h, mask=pos_mask)
@torch.inference_mode()
def apply_rotary_pos_emb(q: Tensor,
k: Tensor,
cos: Tensor,
sin: Tensor,
position_ids: Tensor,
position_ids_1d: Tensor = None,
q_embed: Tensor = None,
k_embed: Tensor = None):
"""Apply rotary positional embedding on query and key.
Args:
q (Tensor): Query state.
k (Tensor): Key state.
cos (Tensor): cosine matrix (seq_len, dim).
sin (Tensor): sine matrix (seq_len, dim).
position_ids (Tensor): Position ids of q and k.
position_ids_1d (Tensor): 1d Position ids.
q_embed (Tensor): output q, can be same as q
k_embed (Tensor): output k, can be same as k
Returns:
Tuple[Tensor, Tensor]: Embedded query and key.
"""
if not q.is_contiguous():
q = q.contiguous()
if not k.is_contiguous():
k = k.contiguous()
if cos.device != q.device or cos.dtype != q.dtype:
cos = cos.to(device=q.device, dtype=q.dtype)
if sin.device != q.device or sin.dtype != q.dtype:
sin = sin.to(device=q.device, dtype=q.dtype)
if position_ids_1d is None:
seq_length = position_ids[..., -1] + 1
position_ids_1d = [ids[:l] for ids, l in zip(position_ids, seq_length)]
position_ids_1d = torch.cat(position_ids_1d)
if q_embed is None:
q_embed = torch.empty_like(q)
if k_embed is None:
k_embed = torch.empty_like(k)
seq_len = position_ids_1d.size(-1)
BLOCK = 32
num_heads_q = q.size(-2)
num_heads_k = k.size(-2)
num_warps = 4
num_stages = 2
device = q.device
device_idx = device.index
device_type = device.type
stream = get_cuda_stream(device_idx)
if num_heads_k == num_heads_q:
grid = [triton.cdiv(seq_len, BLOCK), num_heads_q]
apply_rotary_pos_emb_qk_kernel[grid](q,
k,
cos,
sin,
position_ids_1d,
q_embed,
k_embed,
seq_len=seq_len,
stride_qh=q.stride(-3),
stride_kh=k.stride(-3),
BLOCK=BLOCK,
BLOCK_N=q.size(-1) // 2,
num_warps=num_warps,
num_stages=num_stages,
stream=stream,
device=device_idx,
device_type=device_type)
else:
grid_q = [triton.cdiv(seq_len, BLOCK), num_heads_q]
grid_k = [triton.cdiv(seq_len, BLOCK), num_heads_k]
apply_rotary_pos_emb_kernel[grid_q](q,
cos,
sin,
position_ids_1d,
q_embed,
seq_len=seq_len,
stride_qh=q.stride(-3),
BLOCK=BLOCK,
BLOCK_N=q.size(-1) // 2,
num_warps=num_warps,
num_stages=num_stages,
stream=stream,
device=device_idx,
device_type=device_type)
apply_rotary_pos_emb_kernel[grid_k](k,
cos,
sin,
position_ids_1d,
k_embed,
seq_len=seq_len,
stride_qh=k.stride(-3),
BLOCK=BLOCK,
BLOCK_N=k.size(-1) // 2,
num_warps=num_warps,
num_stages=num_stages,
stream=stream,
device=device_idx,
device_type=device_type)
return q_embed, k_embed
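# Reference sketch (not used by the kernels): the same rotary update in plain
# PyTorch using the rotate-half convention of the kernels above. Shapes are
# assumed to be q/k: (seq_len, num_heads, head_dim) and cos/sin:
# (max_pos, head_dim), indexed by 1-d position ids.
def _apply_rotary_ref(q, k, cos, sin, position_ids_1d):
    """CPU reference of apply_rotary_pos_emb for small-shape verification."""

    def _rotate_half(x):
        x_l, x_h = x.chunk(2, dim=-1)
        return torch.cat((-x_h, x_l), dim=-1)

    cos = cos[position_ids_1d][:, None, :]  # (seq_len, 1, head_dim)
    sin = sin[position_ids_1d][:, None, :]
    q_embed = q * cos + _rotate_half(q) * sin
    k_embed = k * cos + _rotate_half(k) * sin
    return q_embed, k_embed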
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl
from torch import Tensor
from triton.runtime.jit import get_cuda_stream
@triton.jit
def _div_up(val, other):
return (val + other - 1) // other
@triton.jit
def _fill_kv_cache_kernel(
KStates,
VStates,
KCaches,
VCaches,
QStartLoc,
QSeqLens,
KVSeqLens,
BlockOffsets,
num_heads: tl.constexpr,
head_dim: tl.constexpr,
stride_kss,
stride_ksh,
stride_ksd,
stride_vss,
stride_vsh,
stride_vsd,
stride_kcn: tl.constexpr,
stride_kcb: tl.constexpr,
stride_kch: tl.constexpr,
stride_kcd: tl.constexpr,
stride_vcn: tl.constexpr,
stride_vcb: tl.constexpr,
stride_vch: tl.constexpr,
stride_vcd: tl.constexpr,
stride_boff,
BLOCK: tl.constexpr,
BLOCK_D: tl.constexpr,
BLOCK_H: tl.constexpr,
):
"""fill kv cache kernel."""
batch_id = tl.program_id(0)
block_id = tl.program_id(1)
# initialize
h_off = tl.arange(0, BLOCK_H)
d_off = tl.arange(0, BLOCK_D)
q_startloc = tl.load(QStartLoc + batch_id)
q_seqlen = tl.load(QSeqLens + batch_id)
kv_seqlen = tl.load(KVSeqLens + batch_id)
history_seqlen = kv_seqlen - q_seqlen
block0_first_tokenloc = history_seqlen % BLOCK
state_token_offset = tl.maximum(block_id * BLOCK - block0_first_tokenloc,
0)
kv_block_id = _div_up(history_seqlen + 1, BLOCK) - 1 + block_id
kv_block_id = min(kv_block_id, stride_boff - 1)
block_off = tl.load(BlockOffsets + batch_id * stride_boff + kv_block_id)
cur_startloc = q_startloc + state_token_offset
ks_ptr = KStates + cur_startloc * stride_kss
vs_ptr = VStates + cur_startloc * stride_vss
kc_ptr = KCaches + block_off * stride_kcn
vc_ptr = VCaches + block_off * stride_vcn
c_first_tokenloc = block0_first_tokenloc
if block_id != 0:
c_first_tokenloc *= 0
c_last_tokenloc = tl.minimum(
BLOCK, q_seqlen + block0_first_tokenloc - block_id * BLOCK)
for bidx in range(c_first_tokenloc, c_last_tokenloc):
sidx = bidx - c_first_tokenloc
mask = (h_off[:, None] < num_heads) & (d_off[None, :] < head_dim)
k = tl.load(ks_ptr + sidx * stride_kss + h_off[:, None] * stride_ksh +
d_off[None, :] * stride_ksd,
mask=mask)
tl.store(kc_ptr + bidx * stride_kcb + h_off[:, None] * stride_kch +
d_off[None, :] * stride_kcd,
k,
mask=mask)
v = tl.load(vs_ptr + sidx * stride_vss + h_off[:, None] * stride_vsh +
d_off[None, :] * stride_vsd,
mask=mask)
tl.store(vc_ptr + bidx * stride_vcb + h_off[:, None] * stride_vch +
d_off[None, :] * stride_vcd,
v,
mask=mask)
@torch.inference_mode()
def fill_kv_cache(k_states: Tensor, v_states: Tensor, k_caches: Tensor,
v_caches: Tensor, q_start_loc: Tensor, q_seq_length: Tensor,
kv_seq_length: Tensor, max_q_seq_length: int,
block_offsets: Tensor):
"""fill key/value state to cache for paged attention."""
def _kernel_meta():
device = k_states.device
device_idx = device.index
device_type = device.type
stream = get_cuda_stream(device_idx)
return dict(device=device, device_type=device_type, stream=stream)
block_offsets = block_offsets.contiguous()
batch_size = block_offsets.size(0)
block_size, num_heads, head_dim = k_caches.size()[1:]
max_num_blocks = triton.cdiv(max_q_seq_length, block_size) + 1
BLOCK = block_size
BLOCK_H = triton.next_power_of_2(num_heads)
BLOCK_D = triton.next_power_of_2(head_dim)
grid = [batch_size, max_num_blocks]
kernel_meta = _kernel_meta()
_fill_kv_cache_kernel[grid](
k_states,
v_states,
k_caches,
v_caches,
q_start_loc,
q_seq_length,
kv_seq_length,
block_offsets,
num_heads=num_heads,
head_dim=head_dim,
stride_kss=k_states.stride(-3),
stride_ksh=k_states.stride(-2),
stride_ksd=k_states.stride(-1),
stride_vss=v_states.stride(-3),
stride_vsh=v_states.stride(-2),
stride_vsd=v_states.stride(-1),
stride_kcn=k_caches.stride(0),
stride_kcb=k_caches.stride(1),
stride_kch=k_caches.stride(2),
stride_kcd=k_caches.stride(3),
stride_vcn=v_caches.stride(0),
stride_vcb=v_caches.stride(1),
stride_vch=v_caches.stride(2),
stride_vcd=v_caches.stride(3),
stride_boff=block_offsets.stride(0),
BLOCK=BLOCK,
BLOCK_D=BLOCK_D,
BLOCK_H=BLOCK_H,
**kernel_meta,
)
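# Reference sketch (not used by the kernel): a slow pure-PyTorch version with
# the same semantics as `fill_kv_cache`, useful for documenting the paged
# layout: kv position `pos` of sequence `b` is written to cache block
# `block_offsets[b, pos // block_size]` at slot `pos % block_size`.
def _fill_kv_cache_ref(k_states, v_states, k_caches, v_caches, q_start_loc,
                       q_seq_length, kv_seq_length, block_offsets):
    """Scatter new key/value states into the paged caches (host reference)."""
    block_size = k_caches.size(1)
    for b in range(block_offsets.size(0)):
        q_start = int(q_start_loc[b])
        q_len = int(q_seq_length[b])
        history = int(kv_seq_length[b]) - q_len
        for i in range(q_len):
            pos = history + i
            block = int(block_offsets[b, pos // block_size])
            slot = pos % block_size
            k_caches[block, slot] = k_states[q_start + i]
            v_caches[block, slot] = v_states[q_start + i]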
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl
from torch import Tensor
from triton.runtime.jit import get_cuda_stream
@triton.jit
def _fused_rotary_emb_kernel(
Q, K, PostionIds, InvFreq, scaling_factor, OutQ, OutK, stride_bq,
stride_sq, stride_hq: tl.constexpr, stride_dq: tl.constexpr, stride_bk,
stride_sk, stride_hk: tl.constexpr, stride_dk: tl.constexpr, stride_bp,
stride_sp, max_seq_len, BLOCK: tl.constexpr, BLOCK_HQ: tl.constexpr,
BLOCK_HK: tl.constexpr, BLOCK_F: tl.constexpr):
"""fused rotary emb kernel."""
batch_id = tl.program_id(0)
seq_block_id = tl.program_id(1)
s_off = seq_block_id * BLOCK + tl.arange(0, BLOCK)[:, None]
f_off = tl.arange(0, BLOCK_F)[None, :]
s_mask = s_off < max_seq_len
bp_off = stride_bp * batch_id
p_off = bp_off + stride_sp * s_off
sq_off = batch_id * stride_bq + s_off * stride_sq
q0_off = sq_off + f_off * stride_dq
q1_off = q0_off + BLOCK_F * stride_dq
sk_off = batch_id * stride_bk + s_off * stride_sk
k0_off = sk_off + f_off * stride_dk
k1_off = k0_off + BLOCK_F * stride_dk
inv_freq = tl.load(InvFreq + f_off).to(tl.float32)
position_ids = tl.load(PostionIds + p_off, mask=s_mask).to(tl.float32)
position_ids = position_ids / scaling_factor
# pos_freq = tl.dot(position_ids, inv_freq)
pos_freq = position_ids * inv_freq
cos = tl.cos(pos_freq).to(Q.dtype.element_ty)
sin = tl.sin(pos_freq).to(Q.dtype.element_ty)
for h in range(BLOCK_HQ):
q0 = tl.load(Q + q0_off + h * stride_hq, mask=s_mask)
q1 = tl.load(Q + q1_off + h * stride_hq, mask=s_mask)
q0_out = q0 * cos - q1 * sin
tl.store(OutQ + q0_off + h * stride_hq, q0_out, mask=s_mask)
q1_out = q1 * cos + q0 * sin
tl.store(OutQ + q1_off + h * stride_hq, q1_out, mask=s_mask)
for h in range(BLOCK_HK):
k0 = tl.load(K + k0_off + h * stride_hk, mask=s_mask)
k1 = tl.load(K + k1_off + h * stride_hk, mask=s_mask)
k0_out = k0 * cos - k1 * sin
tl.store(OutK + k0_off + h * stride_hk, k0_out, mask=s_mask)
k1_out = k1 * cos + k0 * sin
tl.store(OutK + k1_off + h * stride_hk, k1_out, mask=s_mask)
def fused_rotary_emb(q: Tensor,
k: Tensor,
position_ids: torch.LongTensor,
inv_freq: Tensor,
scaling_factor: float,
out_q: Tensor = None,
out_k: Tensor = None):
"""Fuse `rotary_embedding` and `apply_rotary_pos_emb`."""
def _kernel_meta():
device = q.device
device_idx = device.index
device_type = device.type
stream = get_cuda_stream(device_idx)
return dict(device=device, device_type=device_type, stream=stream)
if out_q is None:
out_q = torch.empty_like(q)
else:
assert q.stride() == out_q.stride()
if out_k is None:
out_k = torch.empty_like(k)
else:
assert k.stride() == out_k.stride()
assert q.dim() == 4
assert k.dim() == 4
assert q.size(0) == position_ids.size(0)
BLOCK = 32
BLOCK_HQ = q.size(-2)
BLOCK_HK = k.size(-2)
BLOCK_F = q.size(-1) // 2
batch_size = q.size(0)
max_seq_len = q.size(1)
kernel_meta = _kernel_meta()
num_warps = 4
grid = (batch_size, triton.cdiv(max_seq_len, BLOCK))
_fused_rotary_emb_kernel[grid](q,
k,
position_ids,
inv_freq,
scaling_factor,
out_q,
out_k,
stride_bq=q.stride(0),
stride_sq=q.stride(1),
stride_hq=q.stride(2),
stride_dq=q.stride(3),
stride_bk=k.stride(0),
stride_sk=k.stride(1),
stride_hk=k.stride(2),
stride_dk=k.stride(3),
stride_bp=position_ids.stride(0),
stride_sp=position_ids.stride(1),
max_seq_len=max_seq_len,
BLOCK=BLOCK,
BLOCK_HQ=BLOCK_HQ,
BLOCK_HK=BLOCK_HK,
BLOCK_F=BLOCK_F,
num_warps=num_warps,
num_stages=1,
**kernel_meta)
return out_q, out_k
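# Reference sketch (not used by the kernel): how the fused kernel derives the
# cos/sin values on the fly from `inv_freq`, `position_ids` and
# `scaling_factor`, instead of reading precomputed tables like
# `apply_rotary_pos_emb` does.
def _rotary_cos_sin_ref(position_ids, inv_freq, scaling_factor):
    """Compute per-position cos/sin tables on the host (reference only).

    Returns two tensors of shape (batch, seq_len, head_dim // 2).
    """
    pos = position_ids.float() / scaling_factor
    pos_freq = pos[..., None] * inv_freq.float()
    return pos_freq.cos(), pos_freq.sin()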
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl
from torch import Tensor
from triton.runtime.jit import get_cuda_stream
def _next_pow_of_2(x):
"""get next power of 2."""
return 1 << (x - 1).bit_length()
@triton.jit
def _x_a_mm_kernel(
X,
LoRA_A,
XA,
B_start_loc,
B_seq_lens,
B_adapter_id,
Rank_page_table,
Rank_page_start,
Ranks,
stride_xs,
stride_xh,
stride_las,
stride_lah,
stride_xas,
stride_xar,
stride_ptb,
rank_step,
BLOCK_M: tl.constexpr,
BLOCK_R: tl.constexpr,
BLOCK_H: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
):
"""xa mm kernel."""
cur_batch = tl.program_id(0)
start_m = tl.program_id(1)
r_off = tl.arange(0, BLOCK_R)
seq_len = tl.load(B_seq_lens + cur_batch)
if start_m * BLOCK_M >= seq_len:
return
start_loc = tl.load(B_start_loc + cur_batch)
adapter_id = tl.load(B_adapter_id + cur_batch)
rank = tl.load(Ranks + adapter_id) // rank_step
page_start = tl.load(Rank_page_start + adapter_id)
page_table_off = adapter_id * stride_ptb + r_off + page_start
rank_mask = r_off < rank
page_table = tl.load(Rank_page_table + page_table_off, mask=rank_mask)
m_off = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
dm_off = tl.arange(0, BLOCK_DMODEL)
x_off = (start_loc + m_off) * stride_xs
xs_mask = m_off < seq_len
la_page_off = page_table * stride_las
acc = tl.zeros((BLOCK_M, BLOCK_R), dtype=tl.float32)
# compute acc
for start_h in range(0, BLOCK_H, BLOCK_DMODEL):
cur_dm_off = start_h + dm_off
h_mask = cur_dm_off < BLOCK_H
# load x
xh_off = cur_dm_off * stride_xh
x_mask = xs_mask[:, None] and h_mask[None, :]
x = tl.load(X + x_off[:, None] + xh_off[None, :],
mask=x_mask,
other=0.0)
# load lora a
lah_off = cur_dm_off * stride_lah
la_mask = rank_mask[None, :] and h_mask[:, None]
la = tl.load(LoRA_A + la_page_off[None, :] + lah_off[:, None],
mask=la_mask,
other=0.0)
# compute
acc += tl.dot(x, la)
acc = acc.to(X.dtype.element_ty)
xa_off = (start_loc + m_off) * stride_xas
xas_mask = xs_mask
xa_mask = xas_mask[:, None] and rank_mask[None, :]
tl.store(XA + xa_off[:, None] + r_off[None, :] * stride_xar,
acc,
mask=xa_mask)
@triton.jit
def _acc_b_mm_kernel(
XA,
LoRA_B,
Out,
B_start_loc,
B_seq_lens,
B_adapter_id,
B_scaling,
Rank_page_table,
Rank_page_start,
Ranks,
stride_xas,
stride_xar,
stride_os,
stride_oh,
stride_lbs,
stride_lbh,
stride_ptb,
BLOCK_M: tl.constexpr,
BLOCK_R: tl.constexpr,
BLOCK_HO: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
):
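    """acc b mm kernel."""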
cur_batch = tl.program_id(0)
start_m = tl.program_id(1)
r_off = tl.arange(0, BLOCK_R)
seq_len = tl.load(B_seq_lens + cur_batch)
if start_m * BLOCK_M >= seq_len:
return
start_loc = tl.load(B_start_loc + cur_batch)
adapter_id = tl.load(B_adapter_id + cur_batch)
scaling = tl.load(B_scaling + cur_batch)
rank = tl.load(Ranks + adapter_id)
page_start = tl.load(Rank_page_start + adapter_id)
page_table_off = adapter_id * stride_ptb + r_off + page_start
rank_mask = r_off < rank
page_table = tl.load(Rank_page_table + page_table_off, mask=rank_mask)
m_off = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
dm_off = tl.arange(0, BLOCK_DMODEL)
lb_page_off = page_table * stride_lbs
xs_mask = m_off < seq_len
o_off = (start_loc + m_off) * stride_os
os_mask = xs_mask
xa_off = (start_loc + m_off) * stride_xas
xa_mask = xs_mask[:, None] and rank_mask[None, :]
acc = tl.load(XA + xa_off[:, None] + r_off[None, :] * stride_xar,
mask=xa_mask,
other=0.0)
acc = acc.to(LoRA_B.dtype.element_ty)
# compute output
for start_h in range(0, BLOCK_HO, BLOCK_DMODEL):
cur_dm_off = start_h + dm_off
h_mask = cur_dm_off < BLOCK_HO
# load lora b
lbh_off = cur_dm_off * stride_lbh
lb_mask = rank_mask[:, None] and h_mask[None, :]
lb = tl.load(LoRA_B + lb_page_off[:, None] + lbh_off[None, :],
mask=lb_mask,
other=0)
# compute
out = tl.dot(acc, lb)
out = out.to(lb.dtype)
out = out * scaling
# store o
oh_off = cur_dm_off * stride_oh
o_mask = os_mask[:, None] and h_mask[None, :]
tl.store(Out + o_off[:, None] + oh_off[None, :], out, mask=o_mask)
@torch.inference_mode()
def mbgmm_a(x: Tensor,
lora_a: Tensor,
q_start_loc: Tensor,
q_seqlens: Tensor,
adapter_ids: Tensor,
rank_page_table: Tensor,
ranks: Tensor,
rank_page_start: Tensor,
max_seq_len: int,
max_rank: int,
rank_step: int = 1):
"""mbgmm_a."""
def _kernel_meta():
device = x.device
device_idx = device.index
device_type = device.type
stream = get_cuda_stream(device_idx)
return dict(device=device, device_type=device_type, stream=stream)
assert x.dim() == 2
assert lora_a.dim() == 2
assert rank_page_table.dim() == 2
head_size = x.size(-1)
batch_size = len(q_seqlens)
max_rank = max_rank // rank_step
BLOCK_M = 32
BLOCK_R = _next_pow_of_2(max_rank)
if BLOCK_R < 16:
BLOCK_R = 16
BLOCK_H = head_size
BLOCK_DMODEL = 64
num_warps = 4
grid = [batch_size, triton.cdiv(max_seq_len, BLOCK_M)]
xa = x.new_empty((x.size(0), max_rank))
kernel_meta = _kernel_meta()
_x_a_mm_kernel[grid](x,
lora_a,
xa,
q_start_loc,
q_seqlens,
adapter_ids,
Rank_page_table=rank_page_table,
Rank_page_start=rank_page_start,
Ranks=ranks,
stride_xs=x.stride(0),
stride_xh=x.stride(1),
stride_las=lora_a.stride(0),
stride_lah=lora_a.stride(1),
stride_xas=xa.stride(0),
stride_xar=xa.stride(1),
stride_ptb=rank_page_table.stride(0),
rank_step=rank_step,
BLOCK_M=BLOCK_M,
BLOCK_R=BLOCK_R,
BLOCK_H=BLOCK_H,
BLOCK_DMODEL=BLOCK_DMODEL,
num_warps=num_warps,
num_stages=1,
**kernel_meta)
return xa
@torch.inference_mode()
def mbgmm_b(xa: Tensor,
lora_b: Tensor,
q_start_loc: Tensor,
q_seqlens: Tensor,
adapter_ids: Tensor,
scaling: Tensor,
rank_page_table: Tensor,
ranks: Tensor,
rank_page_start: Tensor,
max_seq_len: int,
max_rank: int,
out_size: int = None):
"""mbgmm_b."""
def _kernel_meta():
device = xa.device
device_idx = device.index
device_type = device.type
stream = get_cuda_stream(device_idx)
return dict(device=device, device_type=device_type, stream=stream)
assert xa.dim() == 2
assert lora_b.dim() == 2
assert rank_page_table.dim() == 2
if out_size is None:
out_size = lora_b.size(-1)
batch_size = len(q_seqlens)
BLOCK_M = 32
BLOCK_R = _next_pow_of_2(max_rank)
if BLOCK_R < 16:
BLOCK_R = 16
BLOCK_HO = out_size
BLOCK_DMODEL = 64
num_warps = 4
grid = [batch_size, triton.cdiv(max_seq_len, BLOCK_M)]
output = xa.new_empty((xa.size(0), BLOCK_HO))
kernel_meta = _kernel_meta()
_acc_b_mm_kernel[grid](xa,
lora_b,
output,
q_start_loc,
q_seqlens,
adapter_ids,
scaling,
Rank_page_table=rank_page_table,
Rank_page_start=rank_page_start,
Ranks=ranks,
stride_xas=xa.stride(0),
stride_xar=xa.stride(1),
stride_os=output.stride(0),
stride_oh=output.stride(1),
stride_lbs=lora_b.stride(0),
stride_lbh=lora_b.stride(1),
stride_ptb=rank_page_table.stride(0),
BLOCK_M=BLOCK_M,
BLOCK_R=BLOCK_R,
BLOCK_HO=BLOCK_HO,
BLOCK_DMODEL=BLOCK_DMODEL,
num_warps=num_warps,
num_stages=1,
**kernel_meta)
return output
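# The sketch below is illustrative only and not part of the original module.
# It assumes a single LoRA adapter of rank 8 whose A/B weights are stored
# row-paged (page i of `lora_a`/`lora_b` holds the i-th rank row) with an
# identity page table; all shapes are made up for illustration.
def _example_mbgmm():
    """Hypothetical usage sketch of `mbgmm_a`/`mbgmm_b` (never called)."""
    import torch
    device = 'cuda'
    head_size, out_size, rank = 128, 128, 8
    q_seqlens = torch.tensor([6, 10], device=device)
    q_start_loc = q_seqlens.cumsum(0) - q_seqlens
    x = torch.randn(int(q_seqlens.sum()), head_size,
                    dtype=torch.float16, device=device)
    lora_a = torch.randn(rank, head_size, dtype=torch.float16, device=device)
    lora_b = torch.randn(rank, out_size, dtype=torch.float16, device=device)
    adapter_ids = torch.zeros(2, dtype=torch.long, device=device)
    scaling = torch.ones(2, dtype=torch.float16, device=device)
    rank_page_table = torch.arange(rank, device=device)[None]
    ranks = torch.tensor([rank], device=device)
    rank_page_start = torch.zeros(1, dtype=torch.long, device=device)
    xa = mbgmm_a(x, lora_a, q_start_loc, q_seqlens, adapter_ids,
                 rank_page_table, ranks, rank_page_start,
                 max_seq_len=int(q_seqlens.max()), max_rank=rank)
    out = mbgmm_b(xa, lora_b, q_start_loc, q_seqlens, adapter_ids, scaling,
                  rank_page_table, ranks, rank_page_start,
                  max_seq_len=int(q_seqlens.max()), max_rank=rank)
    return out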
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl
from torch import Tensor
from triton.runtime.jit import get_cuda_stream
def _next_pow_of_2(x):
"""get next power of 2."""
return 1 << (x - 1).bit_length()
@triton.jit
def _x_a_mv_kernel(
X,
LoRA_A,
XA,
B_adapter_id,
Rank_page_table,
Rank_page_start,
Ranks,
stride_xs,
stride_xh,
stride_las,
stride_lah,
stride_xas,
stride_xar,
stride_ptb,
rank_step,
BLOCK_R: tl.constexpr,
BLOCK_H: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
):
"""xa mv kernel."""
cur_batch = tl.program_id(0)
r_off = tl.arange(0, BLOCK_R)
adapter_id = tl.load(B_adapter_id + cur_batch)
rank = tl.load(Ranks + adapter_id) // rank_step
page_start = tl.load(Rank_page_start + adapter_id)
page_table_off = adapter_id * stride_ptb + r_off + page_start
rank_mask = r_off < rank
page_table = tl.load(Rank_page_table + page_table_off, mask=rank_mask)
dm_off = tl.arange(0, BLOCK_DMODEL)
x_off = cur_batch * stride_xs
la_page_off = page_table * stride_las
acc = tl.zeros((BLOCK_R, ), dtype=tl.float32)
# compute acc
for start_h in range(0, BLOCK_H, BLOCK_DMODEL):
cur_dm_off = start_h + dm_off
h_mask = cur_dm_off < BLOCK_H
# load x
xh_off = cur_dm_off * stride_xh
x_mask = h_mask
x = tl.load(X + x_off + xh_off, mask=x_mask, other=0.0)
# load lora a
lah_off = cur_dm_off * stride_lah
la_mask = rank_mask[:, None] and h_mask[None, :]
la = tl.load(LoRA_A + la_page_off[:, None] + lah_off[None, :],
mask=la_mask,
other=0.0)
# compute
acc += tl.sum(x[None, :] * la, 1)
acc = acc.to(X.dtype.element_ty)
xa_off = cur_batch * stride_xas
tl.store(XA + xa_off + r_off * stride_xar, acc, mask=rank_mask)
@triton.jit
def _acc_b_mv_kernel(
XA,
LoRA_B,
Out,
B_adapter_id,
B_scaling,
Rank_page_table,
Rank_page_start,
Ranks,
stride_xas,
stride_xar,
stride_os,
stride_oh,
stride_lbs,
stride_lbh,
stride_ptb,
BLOCK_R: tl.constexpr,
BLOCK_HO: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
):
"""acc b mv kernel."""
cur_batch = tl.program_id(0)
r_off = tl.arange(0, BLOCK_R)
adapter_id = tl.load(B_adapter_id + cur_batch)
scaling = tl.load(B_scaling + cur_batch)
rank = tl.load(Ranks + adapter_id)
page_start = tl.load(Rank_page_start + adapter_id)
page_table_off = adapter_id * stride_ptb + r_off + page_start
rank_mask = r_off < rank
page_table = tl.load(Rank_page_table + page_table_off, mask=rank_mask)
dm_off = tl.arange(0, BLOCK_DMODEL)
lb_page_off = page_table * stride_lbs
o_off = cur_batch * stride_os
xa_off = cur_batch * stride_xas
acc = tl.load(XA + xa_off + r_off * stride_xar, mask=rank_mask, other=0.0)
# compute output
for start_h in range(0, BLOCK_HO, BLOCK_DMODEL):
cur_dm_off = start_h + dm_off
h_mask = cur_dm_off < BLOCK_HO
# load lora b
lbh_off = cur_dm_off * stride_lbh
lb_mask = rank_mask[:, None] and h_mask[None, :]
lb = tl.load(LoRA_B + lb_page_off[:, None] + lbh_off[None, :],
mask=lb_mask,
other=0)
# compute
out = tl.sum(acc[:, None] * lb, 0)
out = out.to(lb.dtype)
out = out * scaling
# store o
oh_off = cur_dm_off * stride_oh
tl.store(Out + o_off + oh_off, out, mask=h_mask)
@torch.inference_mode()
def mbgmv_a(x: Tensor,
lora_a: Tensor,
adapter_ids: Tensor,
rank_page_table: Tensor,
ranks: Tensor,
rank_page_start: Tensor,
max_rank: int,
rank_step: int = 1):
"""mbgmv_a."""
def _kernel_meta():
device = x.device
device_idx = device.index
device_type = device.type
stream = get_cuda_stream(device_idx)
return dict(device=device, device_type=device_type, stream=stream)
assert x.dim() == 2
assert lora_a.dim() == 2
assert rank_page_table.dim() == 2
head_size = x.size(-1)
batch_size = x.size(0)
max_rank = max_rank // rank_step
BLOCK_R = _next_pow_of_2(max_rank)
BLOCK_H = head_size
BLOCK_DMODEL = 512
num_warps = 4
grid = [batch_size]
xa = x.new_empty((x.size(0), BLOCK_R))
kernel_meta = _kernel_meta()
_x_a_mv_kernel[grid](x,
lora_a,
xa,
adapter_ids,
Rank_page_table=rank_page_table,
Rank_page_start=rank_page_start,
Ranks=ranks,
stride_xs=x.stride(0),
stride_xh=x.stride(1),
stride_las=lora_a.stride(0),
stride_lah=lora_a.stride(1),
stride_xas=xa.stride(0),
stride_xar=xa.stride(1),
stride_ptb=rank_page_table.stride(0),
rank_step=rank_step,
BLOCK_R=BLOCK_R,
BLOCK_H=BLOCK_H,
BLOCK_DMODEL=BLOCK_DMODEL,
num_warps=num_warps,
num_stages=1,
**kernel_meta)
return xa
@torch.inference_mode()
def mbgmv_b(xa: Tensor,
lora_b: Tensor,
adapter_ids: Tensor,
scaling: Tensor,
rank_page_table: Tensor,
ranks: Tensor,
rank_page_start: Tensor,
max_rank: int,
out_size: int = None):
"""mbgmv_b."""
def _kernel_meta():
device = xa.device
device_idx = device.index
device_type = device.type
stream = get_cuda_stream(device_idx)
return dict(device=device, device_type=device_type, stream=stream)
assert xa.dim() == 2
assert lora_b.dim() == 2
assert rank_page_table.dim() == 2
if out_size is None:
out_size = lora_b.size(-1)
batch_size = xa.size(0)
BLOCK_R = _next_pow_of_2(max_rank)
BLOCK_HO = out_size
BLOCK_DMODEL = 512
num_warps = 4
grid = [batch_size]
output = xa.new_empty((xa.size(0), BLOCK_HO))
kernel_meta = _kernel_meta()
_acc_b_mv_kernel[grid](xa,
lora_b,
output,
adapter_ids,
scaling,
Rank_page_table=rank_page_table,
Rank_page_start=rank_page_start,
Ranks=ranks,
stride_xas=xa.stride(0),
stride_xar=xa.stride(1),
stride_lbs=lora_b.stride(0),
stride_lbh=lora_b.stride(1),
stride_os=output.stride(0),
stride_oh=output.stride(1),
stride_ptb=rank_page_table.stride(0),
BLOCK_R=BLOCK_R,
BLOCK_HO=BLOCK_HO,
BLOCK_DMODEL=BLOCK_DMODEL,
num_warps=num_warps,
num_stages=1,
**kernel_meta)
return output
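# The sketch below is illustrative only and not part of the original module.
# It assumes a single rank-8 LoRA adapter stored row-paged with an identity
# page table; the decoding path takes exactly one token per sequence.
def _example_mbgmv():
    """Hypothetical usage sketch of `mbgmv_a`/`mbgmv_b` (never called)."""
    import torch
    device = 'cuda'
    head_size, out_size, rank, batch = 128, 128, 8, 4
    x = torch.randn(batch, head_size, dtype=torch.float16, device=device)
    lora_a = torch.randn(rank, head_size, dtype=torch.float16, device=device)
    lora_b = torch.randn(rank, out_size, dtype=torch.float16, device=device)
    adapter_ids = torch.zeros(batch, dtype=torch.long, device=device)
    scaling = torch.ones(batch, dtype=torch.float16, device=device)
    rank_page_table = torch.arange(rank, device=device)[None]
    ranks = torch.tensor([rank], device=device)
    rank_page_start = torch.zeros(1, dtype=torch.long, device=device)
    xa = mbgmv_a(x, lora_a, adapter_ids, rank_page_table, ranks,
                 rank_page_start, max_rank=rank)
    out = mbgmv_b(xa, lora_b, adapter_ids, scaling, rank_page_table, ranks,
                  rank_page_start, max_rank=rank)
    return out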
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl
from triton.runtime.jit import get_cuda_stream
@triton.jit
def _multinomial_sampling_kernel(Scores, Seeds, Offsets, Indices, Outputs,
stride_sb, stride_st, stride_ib, stride_it,
num_batchs, num_tokens, BLOCK: tl.constexpr,
BLOCK_N: tl.constexpr):
"""Kernel."""
batch_block_id = tl.program_id(0)
off = batch_block_id * BLOCK + tl.arange(0, BLOCK)
n_off = tl.arange(0, BLOCK_N)
off_mask = off < num_batchs
seed = tl.load(Seeds + off, mask=off_mask)
offset = tl.load(Offsets + off, mask=off_mask).to(tl.int32)
samp = tl.rand(seed, offset)[:, None]
acc = tl.zeros((BLOCK, ), dtype=tl.float32)
output = tl.load(Indices + off * stride_ib, mask=off_mask)
for b_idx in range(0, num_tokens, BLOCK_N):
s_off = b_idx + n_off
s_mask = off_mask[:, None] & (s_off[None, :] < num_tokens)
scores = tl.load(Scores + off[:, None] * stride_sb +
s_off[None, :] * stride_st,
mask=s_mask,
other=0.0).to(acc.dtype)
cum_scores = acc[:, None] + tl.cumsum(scores, 1)
acc += tl.sum(scores, 1)
pre_cum_scores = cum_scores - scores
valid_mask = (samp > pre_cum_scores) & (samp <= cum_scores)
found_mask = tl.sum(valid_mask, 1) > 0
valid_pos = b_idx + tl.argmax(valid_mask.to(tl.int32), 1)
indices = tl.load(Indices + off * stride_ib + valid_pos * stride_it,
mask=found_mask & off_mask,
other=-1)
output = tl.where(found_mask, indices, output)
tl.store(Outputs + off, output, mask=off_mask)
def multinomial_sampling(scores: torch.Tensor,
seeds: torch.LongTensor,
offsets: torch.LongTensor,
indices: torch.Tensor = None):
"""multinomial sampling."""
def __kernel_meta():
"""kernel meta."""
device = scores.device
device_idx = device.index
device_type = device.type
stream = get_cuda_stream(device_idx)
return dict(device=device, device_type=device_type, stream=stream)
assert scores.dim() == 2
batch_size, num_tokens = scores.size()
device = scores.device
if num_tokens == 1:
return torch.zeros_like(scores, dtype=torch.long)
if indices is None:
indices = torch.arange(num_tokens, device=device)
indices = indices.expand_as(scores)
assert indices.dim() == 2
assert indices.size() == scores.size()
outputs = indices[:, 0].clone()
BLOCK = 32
BLOCK_N = 64
grid = [triton.cdiv(batch_size, BLOCK)]
kernel_meta = __kernel_meta()
_multinomial_sampling_kernel[grid](scores,
seeds,
offsets,
indices,
outputs,
stride_sb=scores.stride(0),
stride_st=scores.stride(1),
stride_ib=indices.stride(0),
stride_it=indices.stride(1),
num_batchs=batch_size,
num_tokens=num_tokens,
BLOCK=BLOCK,
BLOCK_N=BLOCK_N,
**kernel_meta)
return outputs
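# The sketch below is illustrative only and not part of the original module.
# `scores` is assumed to hold per-token probabilities (e.g. a softmax output)
# so that the cumulative sum the kernel walks over reaches ~1.0; `seeds` and
# `offsets` feed Triton's counter-based RNG, one pair per batch element.
def _example_multinomial_sampling():
    """Hypothetical usage sketch of `multinomial_sampling` (never called)."""
    import torch
    batch, vocab_size = 4, 32000
    logits = torch.randn(batch, vocab_size, device='cuda')
    probs = logits.softmax(dim=-1)
    seeds = torch.randint(0, 2**31 - 1, (batch, ), device='cuda')
    offsets = torch.zeros(batch, dtype=torch.long, device='cuda')
    next_token_ids = multinomial_sampling(probs, seeds, offsets)
    return next_token_ids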
# Copyright (c) OpenMMLab. All rights reserved.
# modify from: https://github.com/ModelTC/lightllm
import torch
import triton
import triton.language as tl
from packaging import version
from torch import Tensor
from triton.runtime.jit import get_cuda_stream
TRITON_VERSION = version.parse(triton.__version__)
assert TRITON_VERSION >= version.parse('2.1.0')
if TRITON_VERSION >= version.parse('2.2.0'):
@triton.jit
def _load_block_offsets(offset_ptr, block_id, num_sub_blocks: tl.constexpr,
BLOCK: tl.constexpr):
"""load block offsets."""
if num_sub_blocks > 1:
offs_sub = tl.arange(0, num_sub_blocks)
offs_n = tl.arange(0, BLOCK // num_sub_blocks)
ret = tl.load(
offset_ptr + block_id * num_sub_blocks +
offs_sub)[:, None] * BLOCK // num_sub_blocks + offs_n[None, :]
return tl.ravel(ret)
else:
offs_n = tl.arange(0, BLOCK)
return tl.load(offset_ptr + block_id) * BLOCK + offs_n
else:
@triton.jit
def _load_block_offsets(offset_ptr, block_id, num_sub_blocks: tl.constexpr,
BLOCK: tl.constexpr):
"""load block offsets triton<2.2.0."""
if num_sub_blocks > 1:
offs_sub = tl.arange(0, num_sub_blocks)
offs_n = tl.arange(0, BLOCK // num_sub_blocks)
ret = tl.load(offset_ptr + block_id * num_sub_blocks + offs_sub)[
None, :] * BLOCK // num_sub_blocks + offs_n[:, None]
return tl.ravel(ret)
else:
offs_n = tl.arange(0, BLOCK)
return tl.load(offset_ptr + block_id) * BLOCK + offs_n
@triton.jit
def _fwd_split_kernel(
Q,
K,
V,
sm_scale,
KV_seqlens,
Block_offsets,
Acc_out,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_ok,
stride_obs,
stride_oh,
stride_od,
stride_boffb,
kv_group_num,
block_per_cta,
window_size: tl.constexpr,
num_sub_blocks: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
"""first step kernel of split k attention."""
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
split_k_id = tl.program_id(2)
cur_kv_head = cur_head // kv_group_num
q_seqlen = 1
kv_seqlen = tl.load(KV_seqlens + cur_batch)
history_len = kv_seqlen - q_seqlen
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
off_q = (cur_batch * stride_qbs + cur_head * stride_qh +
offs_d * stride_qd)
off_k = (cur_kv_head * stride_kh + offs_d[None, :] * stride_kd)
off_v = (cur_kv_head * stride_vh + offs_d[None, :] * stride_vd)
q = tl.load(Q + off_q).to(tl.float32)
k_ptrs = K + off_k
v_ptrs = V + off_v
block_offset_ptrs = Block_offsets + cur_batch * stride_boffb
# initialize pointer to m and l
m_i = -float('inf')
l_i = float(0)
acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)
kv_len_per_prog = block_per_cta * BLOCK_N
loop_start = kv_len_per_prog * split_k_id
loop_end = tl.minimum(loop_start + kv_len_per_prog, kv_seqlen)
    # load the first block offset
    # (adjusted when sliding window attention is enabled)
start_block_id = loop_start // BLOCK_N
if window_size > 0:
start_block_id = tl.maximum(history_len - window_size,
loop_start) // BLOCK_N
kv_min_loc = tl.maximum(history_len - window_size, 0)
b_offset = _load_block_offsets(block_offset_ptrs, start_block_id,
num_sub_blocks, BLOCK_N)
loop_start = start_block_id * BLOCK_N
for start_n in range(loop_start, loop_end, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
mask = (start_n + offs_n[:, None]) < kv_seqlen
# -- compute qk ----
k = tl.load(
k_ptrs + b_offset[:, None] * stride_kbs,
mask=mask,
other=0.0,
)
v = tl.load(
v_ptrs + b_offset[:, None] * stride_vbs,
mask=mask,
other=0.0,
)
# prefetch b_offset
if start_n + BLOCK_N < loop_end:
start_block_id += 1
b_offset = _load_block_offsets(block_offset_ptrs, start_block_id,
num_sub_blocks, BLOCK_N)
qk = tl.sum(q[None, :] * k, 1)
qk *= sm_scale
        # NOTE: inf - inf = nan, and nan will lead to errors
qk_mask = history_len >= (start_n + offs_n)
if window_size > 0:
qk_mask = qk_mask and ((start_n + offs_n) >= kv_min_loc)
qk = tl.where(
qk_mask,
qk,
-float('inf'),
)
# -- compute p, m_i and l_i
m_i_new = tl.maximum(m_i, tl.max(qk, 0))
p = tl.exp(qk - m_i_new)
alpha = tl.exp(m_i - m_i_new)
l_i_new = alpha * l_i + tl.sum(p, 0)
# -- update output accumulator --
# scale acc
acc = acc * alpha
# update acc
p_new = p.to(v.dtype)
acc += tl.sum(p_new[:, None] * v, 0)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# initialize pointers to output
off_acc = (cur_batch * stride_obs + split_k_id * stride_ok +
cur_head * stride_oh + offs_d * stride_od)
tl.store(Acc_out + off_acc, acc)
off_meta = (cur_batch * stride_obs + split_k_id * stride_ok +
cur_head * stride_oh + BLOCK_DMODEL)
tl.store(Acc_out + off_meta + tl.arange(0, 1), m_i)
tl.store(Acc_out + off_meta + 1 + tl.arange(0, 1), l_i)
@triton.jit
def _reduce_split_kernel(
Acc,
Out,
stride_ak,
stride_abs,
stride_ah,
stride_ad,
stride_obs,
stride_oh,
stride_od,
SPLIT_K: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
):
"""second step kernel of split k attention."""
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
# initialize offsets
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_k = tl.arange(0, SPLIT_K)
offs_acc = (cur_batch * stride_abs + cur_head * stride_ah +
offs_k[:, None] * stride_ak + offs_d[None, :] * stride_ad)
offs_mi = (cur_batch * stride_abs + cur_head * stride_ah +
stride_ak * offs_k + BLOCK_DMODEL)
acc_k = tl.load(Acc + offs_acc)
m_k = tl.load(Acc + offs_mi)
l_k = tl.load(Acc + offs_mi + 1)
m_max = tl.max(m_k, 0)
alpha = tl.exp(m_k - m_max)
acc_k = acc_k * alpha[:, None]
l_k = l_k * alpha
acc = tl.sum(acc_k, 0)
l_sum = tl.sum(l_k, 0)
acc = acc / l_sum
out_offs = (cur_batch * stride_obs + cur_head * stride_oh +
offs_d * stride_od)
tl.store(Out + out_offs, acc)
def _get_convert_pv(nv_capability):
"""lazy load convert_pv."""
if nv_capability[0] >= 8:
@triton.jit
def convert_pv(p, v):
"""convert pv."""
p = p.to(v.dtype)
return p, v
else:
@triton.jit
def convert_pv(p, v):
"""convert pv."""
v = v.to(p.dtype)
return p, v
return convert_pv
_convert_pv = None
@triton.jit
def _fwd_kernel(
Q,
K,
V,
sm_scale,
Q_start_loc,
Q_seqlens,
KV_seqlens,
Block_offsets,
Out,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_boffb,
kv_group_num,
window_size: tl.constexpr,
num_sub_blocks: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
"""paged attention kernel."""
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_kv_head = cur_head // kv_group_num
q_seqlen = tl.load(Q_seqlens + cur_batch)
kv_seqlen = tl.load(KV_seqlens + cur_batch)
q_start_loc = tl.load(Q_start_loc + cur_batch)
history_len = kv_seqlen - q_seqlen
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = ((q_start_loc + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
off_k = (cur_kv_head * stride_kh + offs_d[:, None] * stride_kd)
off_v = (cur_kv_head * stride_vh + offs_d[None, :] * stride_vd)
q = tl.load(Q + off_q, mask=offs_m[:, None] < q_seqlen, other=0.0)
k_ptrs = K + off_k
v_ptrs = V + off_v
block_offset_ptrs = Block_offsets + cur_batch * stride_boffb
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
block_mask = tl.where(block_start_loc < q_seqlen, 1, 0)
    # NOTE: `kv_seqlen - kv_seqlen` yields a zero with the loaded dtype
start_block_id = kv_seqlen - kv_seqlen
if window_size > 0:
start_block_id = tl.maximum(history_len - window_size, 0) // BLOCK_N
kv_min_loc = tl.maximum(history_len + offs_m - window_size, 0)
b_offset = _load_block_offsets(block_offset_ptrs, start_block_id,
num_sub_blocks, BLOCK_N)
kv_start_loc = start_block_id * BLOCK_N
for start_n in range(kv_start_loc, block_mask * kv_seqlen, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(
k_ptrs + b_offset[None, :] * stride_kbs,
mask=start_n + offs_n[None, :] < kv_seqlen,
other=0.0,
)
v = tl.load(
v_ptrs + b_offset[:, None] * stride_vbs,
mask=start_n + offs_n[:, None] < kv_seqlen,
other=0.0,
)
if start_n + BLOCK_N < kv_seqlen:
start_block_id = start_n // BLOCK_N + 1
b_offset = _load_block_offsets(block_offset_ptrs, start_block_id,
num_sub_blocks, BLOCK_N)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
        # NOTE: inf - inf = nan, and nan will lead to errors
qk_mask = (history_len + offs_m[:, None]) >= (start_n +
offs_n[None, :])
if window_size > 0:
qk_mask = qk_mask and (
(start_n + offs_n[None, :]) >= kv_min_loc[:, None])
qk = tl.where(
qk_mask,
qk,
float(-1e30),
)
# -- compute p, m_i and l_i
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
p = tl.exp(qk - m_i_new[:, None])
alpha = tl.exp(m_i - m_i_new)
l_i_new = alpha * l_i + tl.sum(p, 1)
# -- update output accumulator --
# scale acc
acc = acc * alpha[:, None]
# update acc
p, v = _convert_pv(p, v)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
acc = acc / l_i[:, None]
# initialize pointers to output
off_o = ((q_start_loc + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs, acc, mask=offs_m[:, None] < q_seqlen)
@torch.inference_mode()
def paged_attention_fwd(
q: Tensor,
k: Tensor,
v: Tensor,
o: Tensor,
block_offsets: Tensor,
q_start_loc: Tensor,
q_seqlens: Tensor,
kv_seqlens: Tensor,
max_seqlen: int,
window_size: int = -1,
):
"""Paged Attention forward.
Args:
q (Tensor): Query state.
k (Tensor): Key state caches.
v (Tensor): Value state caches.
o (Tensor): Output state.
block_offsets (Tensor): The block offset of key and value.
        q_start_loc (Tensor): Start token location of each sequence in batch.
        q_seqlens (Tensor): Query length of each sequence in batch.
        kv_seqlens (Tensor): Key/Value length of each sequence in batch.
        max_seqlen (int): The max input length.
        window_size (int): The attention window size. ``-1`` means no
            sliding window is applied.
    """
global _convert_pv
if _convert_pv is None:
nv_cap = torch.cuda.get_device_capability()
_convert_pv = _get_convert_pv(nv_cap)
def _kernel_meta():
"""kernel meta."""
device = q.device
device_idx = device.index
device_type = device.type
stream = get_cuda_stream(device_idx)
return dict(device=device, device_type=device_type, stream=stream)
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128, 256}
    sm_scale = 1.0 / (Lq**0.5)  # compute the attention scale factor
batch, head = q_seqlens.shape[0], q.shape[-2]
kv_group_num = q.shape[-2] // k.shape[-2]
num_warps = 4 if Lk <= 64 else 8
BLOCK = 64 if k.size(1) < 16 else k.size(1)
num_sub_blocks = BLOCK // k.size(1)
kernel_meta = _kernel_meta()
is_decoding = q.shape[-3] == q_seqlens.size(0)
if not is_decoding:
grid = (batch, head, triton.cdiv(max_seqlen, BLOCK))
_fwd_kernel[grid](q,
k,
v,
sm_scale,
q_start_loc,
q_seqlens,
kv_seqlens,
block_offsets,
o,
stride_qbs=q.stride(-3),
stride_qh=q.stride(-2),
stride_qd=q.stride(-1),
stride_kbs=k.stride(-3),
stride_kh=k.stride(-2),
stride_kd=k.stride(-1),
stride_vbs=v.stride(-3),
stride_vh=v.stride(-2),
stride_vd=v.stride(-1),
stride_obs=o.stride(-3),
stride_oh=o.stride(-2),
stride_od=o.stride(-1),
stride_boffb=block_offsets.stride(0),
kv_group_num=kv_group_num,
window_size=window_size,
num_sub_blocks=num_sub_blocks,
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
**kernel_meta)
else:
SPLIT_K = 4
grid = (batch, head, SPLIT_K)
block_per_cta = triton.cdiv(block_offsets.size(-1), SPLIT_K)
acc = q.new_empty(batch, head, SPLIT_K, Lq + 2, dtype=torch.float32)
_fwd_split_kernel[grid](q,
k,
v,
sm_scale,
kv_seqlens,
block_offsets,
acc,
stride_qbs=q.stride(-3),
stride_qh=q.stride(-2),
stride_qd=q.stride(-1),
stride_kbs=k.stride(-3),
stride_kh=k.stride(-2),
stride_kd=k.stride(-1),
stride_vbs=v.stride(-3),
stride_vh=v.stride(-2),
stride_vd=v.stride(-1),
stride_ok=acc.stride(-2),
stride_obs=acc.stride(-4),
stride_oh=acc.stride(-3),
stride_od=acc.stride(-1),
stride_boffb=block_offsets.stride(0),
kv_group_num=kv_group_num,
block_per_cta=block_per_cta,
window_size=window_size,
num_sub_blocks=num_sub_blocks,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=4,
num_stages=1,
**kernel_meta)
grid = (batch, head)
_reduce_split_kernel[grid](acc,
o,
stride_ak=acc.stride(-2),
stride_abs=acc.stride(-4),
stride_ah=acc.stride(-3),
stride_ad=acc.stride(-1),
stride_obs=o.stride(-3),
stride_oh=o.stride(-2),
stride_od=o.stride(-1),
SPLIT_K=SPLIT_K,
BLOCK_DMODEL=Lk,
num_warps=num_warps,
num_stages=1,
**kernel_meta)
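# The sketch below is illustrative only and not part of the original module.
# The assumed layouts are: `q`/`o` packed as (total_tokens, num_heads,
# head_dim), the caches `k`/`v` as (num_blocks, block_size, num_kv_heads,
# head_dim) and `block_offsets` as a (batch, blocks_per_seq) table of cache
# block ids; a prefill with two sequences and no history is shown.
def _example_paged_attention_fwd():
    """Hypothetical usage sketch of `paged_attention_fwd` (never called)."""
    import torch
    device = 'cuda'
    num_heads, head_dim, block_size = 8, 64, 16
    q_seqlens = torch.tensor([12, 20], device=device)
    kv_seqlens = q_seqlens.clone()
    q_start_loc = q_seqlens.cumsum(0) - q_seqlens
    total_tokens = int(q_seqlens.sum())
    max_seqlen = int(q_seqlens.max())
    blocks_per_seq = (max_seqlen + block_size - 1) // block_size
    q = torch.randn(total_tokens, num_heads, head_dim,
                    dtype=torch.float16, device=device)
    o = torch.empty_like(q)
    k = torch.randn(2 * blocks_per_seq, block_size, num_heads, head_dim,
                    dtype=torch.float16, device=device)
    v = torch.randn_like(k)
    block_offsets = torch.arange(2 * blocks_per_seq,
                                 device=device).view(2, -1)
    paged_attention_fwd(q, k, v, o, block_offsets, q_start_loc, q_seqlens,
                        kv_seqlens, max_seqlen)
    return o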
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl
from triton.runtime.jit import get_cuda_stream
@triton.jit
def _rearange_all_gather_kernel(X, StartLoc, SeqLen, AdapterIds, Ranks, Out,
stride_x, stride_o, world_size,
BLOCK: tl.constexpr, BLOCK_P: tl.constexpr):
"""rearange all gather kernel."""
batch_id = tl.program_id(0)
block_id = tl.program_id(1)
start_loc = tl.load(StartLoc + batch_id) + block_id * BLOCK
seq_len = tl.load(SeqLen + batch_id)
if block_id * BLOCK >= seq_len:
return
block_off = start_loc + tl.arange(0, BLOCK)
block_mask = block_id * BLOCK + tl.arange(0, BLOCK) < seq_len
adapter_id = tl.load(AdapterIds + batch_id)
rank = tl.load(Ranks + adapter_id)
prank = rank // world_size
p_off = tl.arange(0, BLOCK_P)
for p_id in range(world_size):
ip_off = p_id * BLOCK_P + p_off
i_mask = block_mask[:, None] and (p_off < prank)[None, :]
i_off = block_off[:, None] * stride_x + ip_off[None, :]
x = tl.load(X + i_off, mask=i_mask)
op_off = p_id * prank + p_off
o_mask = i_mask
o_off = block_off[:, None] * stride_o + op_off[None, :]
tl.store(Out + o_off, x, mask=o_mask)
@triton.jit
def _rearange_all_gather_decoding_kernel(X, AdapterIds, Ranks, Out, stride_x,
stride_o, world_size, seq_len,
BLOCK: tl.constexpr,
BLOCK_P: tl.constexpr):
"""rearange all gather kernel."""
block_id = tl.program_id(0)
block_off = block_id * BLOCK + tl.arange(0, BLOCK)
block_mask = block_off < seq_len
adapter_ids = tl.load(AdapterIds + block_off, mask=block_mask)
ranks = tl.load(Ranks + adapter_ids)
pranks = ranks // world_size
p_off = tl.arange(0, BLOCK_P)
for p_id in range(world_size):
ip_off = p_id * BLOCK_P + p_off
i_mask = block_mask[:, None] and (p_off[None, :] < pranks[:, None])
i_off = block_off[:, None] * stride_x + ip_off[None, :]
x = tl.load(X + i_off, mask=i_mask)
op_off = p_id * pranks[:, None] + p_off[None, :]
o_mask = i_mask
o_off = block_off[:, None] * stride_o + op_off
tl.store(Out + o_off, x, mask=o_mask)
def rearange_all_gather(x: torch.Tensor,
b_start_loc: torch.Tensor,
b_seq_lens: torch.Tensor,
adapter_ids: torch.LongTensor,
ranks: torch.Tensor,
world_size: int,
max_seq_len: int,
output: torch.Tensor = None):
"""rearange all gather."""
def _kernel_meta():
device = x.device
device_idx = device.index
device_type = device.type
stream = get_cuda_stream(device_idx)
return dict(device=device, device_type=device_type, stream=stream)
max_rank = x.size(1)
batch_size = len(b_seq_lens)
partition_size = max_rank // world_size
if output is None:
output = torch.empty_like(x)
num_warps = 4
kernel_meta = _kernel_meta()
is_decoding = batch_size == x.size(0)
if not is_decoding:
BLOCK = 128
BLOCK_P = partition_size
grid = (batch_size, triton.cdiv(max_seq_len, BLOCK))
_rearange_all_gather_kernel[grid](x,
b_start_loc,
b_seq_lens,
adapter_ids,
ranks,
output,
stride_x=x.stride(0),
stride_o=output.stride(0),
world_size=world_size,
BLOCK=BLOCK,
BLOCK_P=BLOCK_P,
num_warps=num_warps,
num_stages=1,
**kernel_meta)
else:
BLOCK = 64
BLOCK_P = partition_size
seq_len = x.size(0)
grid = (triton.cdiv(seq_len, BLOCK), )
_rearange_all_gather_decoding_kernel[grid](x,
adapter_ids,
ranks,
output,
stride_x=x.stride(0),
stride_o=output.stride(0),
world_size=world_size,
seq_len=seq_len,
BLOCK=BLOCK,
BLOCK_P=BLOCK_P,
num_warps=num_warps,
num_stages=1,
**kernel_meta)
return output
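# The sketch below is illustrative only and not part of the original module.
# It assumes the all-gathered LoRA activation `x` is laid out as `world_size`
# partitions of `max_rank // world_size` columns each; the kernel packs the
# first `rank // world_size` columns of every partition next to each other.
# The decoding path (one token per sequence) is shown.
def _example_rearange_all_gather():
    """Hypothetical usage sketch of `rearange_all_gather` (never called)."""
    import torch
    device = 'cuda'
    world_size, max_rank, rank, batch = 2, 16, 8, 4
    x = torch.randn(batch, max_rank, dtype=torch.float16, device=device)
    b_start_loc = torch.arange(batch, device=device)
    b_seq_lens = torch.ones(batch, dtype=torch.long, device=device)
    adapter_ids = torch.zeros(batch, dtype=torch.long, device=device)
    ranks = torch.tensor([rank], device=device)
    return rearange_all_gather(x, b_start_loc, b_seq_lens, adapter_ids, ranks,
                               world_size=world_size, max_seq_len=1)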
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl
from packaging import version
assert version.parse(triton.__version__) >= version.parse('2.1.0')
# bugfix from https://gist.github.com/chu-tianxiang/4307937fd94b49c75b61a6967716bae9#file-rerope-py # noqa: E501
@triton.jit
def _rerope_fwd_kernel(
Q1,
Q2,
K1,
K2,
V,
sm_scale,
# L,
Out,
stride_qz,
stride_qh,
stride_qm,
stride_qk,
stride_kz,
stride_kh,
stride_kn,
stride_kk,
stride_vz,
stride_vh,
stride_vk,
stride_vn,
stride_oz,
stride_oh,
stride_om,
stride_on,
Z,
H,
N_CTX,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
IS_CAUSAL: tl.constexpr,
WINDOW: tl.constexpr,
):
"""rerope attention triton kernel."""
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
q_offset = off_hz * stride_qh
kv_offset = off_hz * stride_kh
Q1_block_ptr = tl.make_block_ptr(base=Q1 + q_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0))
Q2_block_ptr = tl.make_block_ptr(base=Q2 + q_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_qm, stride_qk),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0))
K1_block_ptr = tl.make_block_ptr(base=K1 + kv_offset,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1))
K2_block_ptr = tl.make_block_ptr(base=K2 + kv_offset,
shape=(BLOCK_DMODEL, N_CTX),
strides=(stride_kk, stride_kn),
offsets=(0, 0),
block_shape=(BLOCK_DMODEL, BLOCK_N),
order=(0, 1))
V_block_ptr = tl.make_block_ptr(base=V + kv_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_vk, stride_vn),
offsets=(0, 0),
block_shape=(BLOCK_N, BLOCK_DMODEL),
order=(1, 0))
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# scale sm_scale by log_2(e) and use
# 2^x instead of exp in the loop because CSE and LICM
# don't work as expected with `exp` in the loop
qk_scale = sm_scale * 1.44269504
# load q: it will stay in SRAM throughout
q1 = tl.load(Q1_block_ptr, boundary_check=(0, 1))
dtype = q1.dtype
q1 = (q1 * qk_scale).to(dtype)
q2 = tl.load(Q2_block_ptr, boundary_check=(0, 1))
q2 = (q2 * qk_scale).to(dtype)
# loop over k, v and update accumulator
lo = 0
hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX
for start_n in range(lo, hi, BLOCK_N):
# -- compute qk ---
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
if IS_CAUSAL:
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float('-inf'))
if start_n <= start_m * BLOCK_M - WINDOW - BLOCK_N or start_n >= (
start_m + 1) * BLOCK_M + WINDOW:
k2 = tl.load(K2_block_ptr)
v = tl.load(V_block_ptr)
qk += tl.dot(q2, k2, out_dtype=tl.float32)
elif start_n > (
start_m + 1
) * BLOCK_M - WINDOW and start_n < start_m * BLOCK_M + WINDOW - BLOCK_N: # noqa: E501
k1 = tl.load(K1_block_ptr)
v = tl.load(V_block_ptr)
qk += tl.dot(q1, k1, out_dtype=tl.float32)
else:
k1 = tl.load(K1_block_ptr)
k2 = tl.load(K2_block_ptr)
v = tl.load(V_block_ptr)
qk1 = tl.dot(q1, k1, out_dtype=tl.float32)
qk2 = tl.dot(q2, k2, out_dtype=tl.float32)
qk += tl.where(
tl.abs(offs_m[:, None] - (start_n + offs_n[None, :])) < WINDOW,
qk1, qk2)
# -- compute scaling constant ---
m_i_new = tl.maximum(m_i, tl.max(qk, 1))
alpha = tl.math.exp2(m_i - m_i_new)
p = tl.math.exp2(qk - m_i_new[:, None])
# -- scale and update acc --
acc_scale = l_i * 0 + alpha # workaround some compiler bug
acc *= acc_scale[:, None]
acc += tl.dot(p, v.to(tl.float32))
# -- update m_i and l_i --
l_i = l_i * alpha + tl.sum(p, 1)
m_i = m_i_new
# update pointers
K1_block_ptr = tl.advance(K1_block_ptr, (0, BLOCK_N))
K2_block_ptr = tl.advance(K2_block_ptr, (0, BLOCK_N))
V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
# write back l and m
acc = acc / l_i[:, None]
# debug softmax output
# l_ptrs = L + off_hz * N_CTX + offs_m
# tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O
O_block_ptr = tl.make_block_ptr(base=Out + q_offset,
shape=(N_CTX, BLOCK_DMODEL),
strides=(stride_om, stride_on),
offsets=(start_m * BLOCK_M, 0),
block_shape=(BLOCK_M, BLOCK_DMODEL),
order=(1, 0))
tl.store(O_block_ptr, acc.to(dtype), boundary_check=(0, 1))
def rerope_attention_fwd(q1,
q2,
k1,
k2,
v,
causal,
sm_scale,
window,
BLOCK_M=64):
"""rerope attention forward."""
# shape constraints
Lq, Lk, Lv = q1.shape[-1], k1.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
o = torch.empty_like(q1)
BLOCK_N = 64 if Lk <= 64 else 32
num_stages = 4 if Lk <= 64 else 3
num_warps = 4
grid = (triton.cdiv(q1.shape[2], BLOCK_M), q1.shape[0] * q1.shape[1], 1)
# L = torch.empty((q1.shape[0] * q1.shape[1], q1.shape[2]),
# device=q1.device,
# dtype=torch.float32)
_rerope_fwd_kernel[grid](
q1,
q2,
k1,
k2,
v,
sm_scale,
# L,
o,
q1.stride(0),
q1.stride(1),
q1.stride(2),
q1.stride(3),
k1.stride(0),
k1.stride(1),
k1.stride(2),
k1.stride(3),
v.stride(0),
v.stride(1),
v.stride(2),
v.stride(3),
o.stride(0),
o.stride(1),
o.stride(2),
o.stride(3),
q1.shape[0],
q1.shape[1],
q1.shape[2],
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_DMODEL=Lk,
IS_CAUSAL=causal,
WINDOW=window,
num_warps=num_warps,
num_stages=num_stages)
return o
if __name__ == '__main__':
def test_rerope():
import torch.utils.benchmark as benchmark
Z = 1
H = 40
N_CTX = 2176
D_HEAD = 128
WINDOW = 512
sm_scale = 0.0883883
def torch_attention(q1, q2, k1, k2, v, causal, sm_scale, window):
# reference implementation
M = torch.tril(torch.ones((N_CTX, N_CTX), device='cuda'))
p1 = torch.matmul(q1, k1.transpose(2, 3)) * sm_scale
p2 = torch.matmul(q2, k2.transpose(2, 3)) * sm_scale
if causal:
p1[:, :, M == 0] = float('-inf')
p2[:, :, M == 0] = float('-inf')
x = torch.arange(N_CTX, dtype=torch.int, device='cuda')
M2 = ((x[:, None] - x[None, :]).abs() < window)[None, None, :]
p = torch.where(M2, p1, p2)
p = torch.softmax(p.float(), dim=-1).half()
ref_out = torch.matmul(p, v)
return ref_out
def torch_attention2(query_states1, query_states2, key_states1,
key_states2, value_states, causal, sm_scale,
window):
query_states1 = query_states1.squeeze(0).contiguous()
query_states2 = query_states2.squeeze(0).contiguous()
key_states1 = key_states1.squeeze(0).contiguous()
key_states2 = key_states2.squeeze(0).contiguous()
value_states = value_states.squeeze(0).contiguous()
attn_weights1 = torch.matmul(
query_states1, key_states1.transpose(1, 2)) * sm_scale
attn_weights2 = torch.matmul(
query_states2, key_states2.transpose(1, 2)) * sm_scale
position_ids = torch.arange(
query_states1.shape[1],
device=query_states1.device).unsqueeze(0)
rectified_mask = (position_ids[:, -N_CTX:, None] -
position_ids[:, None]).abs() < window
attn_weights = torch.where(rectified_mask, attn_weights1,
attn_weights2)
if causal:
tgt_len = attn_weights.shape[-1]
dtype = attn_weights.dtype
device = attn_weights.device
mask = torch.full((tgt_len, tgt_len),
torch.finfo(dtype).min,
device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(
mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
attn_weights = attn_weights + mask
# upcast attention to fp32
attn_weights = torch.nn.functional.softmax(attn_weights,
dim=-1,
dtype=torch.float32).to(
query_states1.dtype)
attn_output = torch.matmul(attn_weights, value_states)
return attn_output
q1 = torch.empty((Z, H, N_CTX, D_HEAD),
dtype=torch.float16,
device='cuda').normal_(mean=0., std=0.5).contiguous()
q2 = torch.empty((Z, H, N_CTX, D_HEAD),
dtype=torch.float16,
device='cuda').normal_(mean=0., std=0.5).contiguous()
k1 = torch.empty((Z, H, N_CTX, D_HEAD),
dtype=torch.float16,
device='cuda').normal_(mean=0., std=0.5).contiguous()
k2 = torch.empty((Z, H, N_CTX, D_HEAD),
dtype=torch.float16,
device='cuda').normal_(mean=0., std=0.5).contiguous()
v = torch.empty((Z, H, N_CTX, D_HEAD),
dtype=torch.float16,
device='cuda').normal_(mean=0., std=0.5).contiguous()
# q1 = torch.load('/workspace/GitProjects/lmdeploy/q1.pt',
# map_location='cuda').contiguous()
# q2 = torch.load('/workspace/GitProjects/lmdeploy/q2.pt',
# map_location='cuda').contiguous()
# k1 = torch.load('/workspace/GitProjects/lmdeploy/k1.pt',
# map_location='cuda').contiguous()
# k2 = torch.load('/workspace/GitProjects/lmdeploy/k2.pt',
# map_location='cuda').contiguous()
# v = torch.load('/workspace/GitProjects/lmdeploy/v.pt',
# map_location='cuda').contiguous()
torch_output = torch_attention(q1, q2, k1, k2, v, True, sm_scale,
WINDOW)
torch_output2 = torch_attention2(q1, q2, k1, k2, v, True, sm_scale,
WINDOW)
assert torch.allclose(torch_output, torch_output2, atol=1e-2, rtol=0)
for _ in range(100):
triton_output = rerope_attention_fwd(q1, q2, k1, k2, v, True,
sm_scale, WINDOW)
assert torch.allclose(
torch_output, triton_output, atol=2e-2, rtol=0) is True
def f(fn, q1, q2, k1, k2, v, sm_scale, window):
fn(q1, q2, k1, k2, v, True, sm_scale, window)
t0 = benchmark.Timer(stmt='f(fn, q1, q2, k1, k2, v, sm_scale, window)',
globals={
'f': f,
'fn': torch_attention2,
'q1': q1,
'q2': q2,
'k1': k1,
'k2': k2,
'v': v,
'sm_scale': sm_scale,
'window': WINDOW
},
num_threads=torch.get_num_threads())
print(t0.timeit(20))
import time
begin = time.time()
LOOP = 100
for _ in range(LOOP):
rerope_attention_fwd(q1, q2, k1, k2, v, True, sm_scale, WINDOW)
print(time.time() - begin)
test_rerope()
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl
from torch import Tensor
from triton.runtime.jit import get_cuda_stream
@triton.jit
def rms_norm_kernel(input, weight, output, input_row_stride, n_cols, eps,
N_COLS: tl.constexpr, BLOCK_N: tl.constexpr):
"""rms norm kernel."""
prog_id = tl.program_id(0)
offsets = tl.arange(0, BLOCK_N)
w = tl.load(weight + offsets, mask=offsets < n_cols)
x_ptr = input + prog_id * input_row_stride
x = tl.load(x_ptr + offsets, mask=offsets < n_cols)
xf = x.to(tl.float32)
var = tl.sum(xf * xf, 0) * float(1.0 / N_COLS)
out = xf / tl.sqrt(var + eps)
out = (w * out).to(x.dtype)
out_ptr = output + prog_id * input_row_stride
tl.store(out_ptr + offsets, out, mask=offsets < n_cols)
@torch.inference_mode()
def rms_norm(hidden_states: Tensor, weight: Tensor, eps: float = 1e-6):
"""rms norm."""
def _kernel_meta():
device = hidden_states.device
device_idx = device.index
device_type = device.type
stream = get_cuda_stream(device_idx)
return dict(device=device, device_type=device_type, stream=stream)
feat_size = weight.shape[0]
seq_len = hidden_states.numel() // hidden_states.size(-1)
input_stride = hidden_states.stride(-2)
BLOCK_N = triton.next_power_of_2(feat_size)
out = torch.empty_like(hidden_states)
kernel_meta = _kernel_meta()
grid = (seq_len, )
rms_norm_kernel[grid](hidden_states,
weight,
out,
input_stride,
feat_size,
eps,
feat_size,
BLOCK_N,
num_warps=4,
num_stages=2,
**kernel_meta)
return out
if __name__ == '__main__':
import time
def torch_forward(hidden_states, weight, variance_epsilon=1e-6):
"""pytorch forward."""
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance +
variance_epsilon)
return weight * hidden_states.to(input_dtype)
def test_rms_norm(bsz, ctx_len, feat_len, dtype):
"""test rms norm."""
input = torch.empty((bsz, ctx_len, feat_len),
dtype=dtype,
device='cuda').normal_(mean=0.,
std=0.5).contiguous()
weight = torch.empty((feat_len), dtype=dtype,
device='cuda').normal_(mean=0.,
std=0.5).contiguous()
triton_output = rms_norm(hidden_states=input, weight=weight)
torch_output = torch_forward(hidden_states=input, weight=weight)
assert torch.allclose(torch_output, triton_output, atol=1e-2, rtol=0)
N_REPEATS = 20
t0 = time.time()
for _ in range(N_REPEATS):
torch_forward(hidden_states=input, weight=weight)
t1 = time.time()
for _ in range(N_REPEATS):
rms_norm(hidden_states=input, weight=weight)
t2 = time.time()
torch_cost = (t1 - t0) / N_REPEATS * 1000
triton_cost = (t2 - t1) / N_REPEATS * 1000
print(
'input {} weight {} dtype {}\n torch {:.3f} triton {:.3f} (ms)\n'.
format(input.shape, weight.shape, dtype, torch_cost, triton_cost))
test_rms_norm(1, 8128, 5120, torch.float16)
test_rms_norm(1, 8128, 5120, torch.float32)
test_rms_norm(1, 992, 128, torch.float16)
test_rms_norm(1, 65537, 128, torch.float32)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
def per_channel_quant(x, n_bits, dtype):
"""Quantize the input tensor 'x' channel-wise using the given number of
bits.
Args:
x (torch.Tensor): The input tensor to be quantized. Must be a
2-dimensional tensor.
n_bits (int): The number of bits to use for quantization.
dtype (torch.dtype): The data type to which the quantized tensor should
be converted.
Returns:
tuple: A tuple containing two items -- the quantized tensor and
the scale used for quantization.
"""
assert x.ndim == 2
x = x.to(torch.float32)
x_absmax = x.view(x.shape[0], -1).abs().max(dim=1, keepdim=True)[0]
q_max = 2**(n_bits - 1) - 1
q_min = -2**(n_bits - 1)
scale = x_absmax / (2**(n_bits - 1) - 1)
x_q = torch.round(x / scale).clamp(q_min, q_max).to(dtype)
return x_q, scale
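# Worked example (illustrative, not part of the original module): for the row
# [0.5, -1.0, 2.0] with n_bits=8 the per-channel scale is 2.0 / 127, so the
# quantized row is round(x / scale) = [32, -64, 127]:
#
#   x_q, scale = per_channel_quant(
#       torch.tensor([[0.5, -1.0, 2.0]]), n_bits=8, dtype=torch.int8)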
@triton.autotune(
configs=[
triton.Config({
'BLOCK_M': 16,
'BLOCK_N': 128,
'BLOCK_K': 256,
},
num_stages=4,
num_warps=4),
triton.Config({
'BLOCK_M': 32,
'BLOCK_N': 64,
'BLOCK_K': 128,
},
num_stages=4,
num_warps=4),
triton.Config({
'BLOCK_M': 64,
'BLOCK_N': 64,
'BLOCK_K': 128,
},
num_stages=4,
num_warps=4),
triton.Config({
'BLOCK_M': 64,
'BLOCK_N': 128,
'BLOCK_K': 128,
},
num_stages=4,
num_warps=4),
triton.Config({
'BLOCK_M': 128,
'BLOCK_N': 128,
'BLOCK_K': 128,
},
num_stages=4,
num_warps=4)
],
key=['M', 'N', 'K'],
)
@triton.jit
def _linear(
A,
B,
C,
M,
N,
K,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
rms_scale_ptr,
linear_scale_ptr,
):
"""Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B`, and store the result in output
tensor `C`.
The function applies auto-tuning for optimal performance and uses Just-in-
Time compilation.
"""
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
offs_k = tl.arange(0, BLOCK_K)
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
c = accumulator.to(tl.float32)
rms_scale = tl.load(rms_scale_ptr + offs_am)[:, None]
linear_scale = tl.load(linear_scale_ptr + offs_bn)[None, :]
c = c * rms_scale * linear_scale
offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
@triton.autotune(
configs=[
triton.Config({
'BLOCK_M': 16,
'BLOCK_N': 128,
'BLOCK_K': 256,
},
num_stages=4,
num_warps=4),
triton.Config({
'BLOCK_M': 32,
'BLOCK_N': 64,
'BLOCK_K': 128,
},
num_stages=4,
num_warps=4),
triton.Config({
'BLOCK_M': 64,
'BLOCK_N': 64,
'BLOCK_K': 128,
},
num_stages=4,
num_warps=4),
triton.Config({
'BLOCK_M': 64,
'BLOCK_N': 128,
'BLOCK_K': 128,
},
num_stages=4,
num_warps=4),
triton.Config({
'BLOCK_M': 128,
'BLOCK_N': 128,
'BLOCK_K': 128,
},
num_stages=4,
num_warps=4)
],
key=['M', 'N', 'K'],
)
@triton.jit
def _linear_add(
A,
B,
C,
residual_ptr,
M,
N,
K,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
rms_scale_ptr,
linear_scale_ptr,
):
"""Triton-accelerated function used to perform a linear operation (dot
product) on input tensors `A` and `B`, with addition of residual.
The result is stored in tensor `C`. The function applies auto-tuning for
optimal performance and uses Just-in-Time compilation.
"""
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_M)
num_pid_n = tl.cdiv(N, BLOCK_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
offs_k = tl.arange(0, BLOCK_K)
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
for k in range(0, tl.cdiv(K, BLOCK_K)):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
c = accumulator.to(tl.float32)
rms_scale = tl.load(rms_scale_ptr + offs_am)[:, None]
linear_scale = tl.load(linear_scale_ptr + offs_bn)[None, :]
c = c * rms_scale * linear_scale
c = c.to(residual_ptr.dtype.element_ty)
offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
residual_ptrs = (residual_ptr + stride_cm * offs_cm[:, None] +
stride_cn * offs_cn[None, :])
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
residual = tl.load(residual_ptrs, mask=c_mask, other=0.)
c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
tl.store(c_ptrs, c + residual, mask=c_mask)
def matmul_kernel_dynamic_quant(a,
b,
rms_scale,
linear_scale,
residual=None,
bias=None,
output_dtype=torch.float16):
"""This function performs matrix multiplication with dynamic quantization.
It takes two input tensors `a` and `b`, scales them with `rms_scale` and
`linear_scale`, and optionally adds a `residual` tensor and a `bias`. The
output is returned in the specified `output_dtype`.
"""
assert a.shape[-1] == b.shape[-1]
assert b.ndim == 2 and b.is_contiguous()
b = b.t() # (K, N)
M = a.numel() // a.shape[-1]
K, N = b.shape
c_shape = a.shape[:-1] + (N, )
if residual is not None:
assert residual.shape == c_shape
assert residual.is_contiguous()
c = torch.empty(c_shape, device=a.device, dtype=output_dtype)
def grid(META):
return (triton.cdiv(M, META['BLOCK_M']) *
triton.cdiv(N, META['BLOCK_N']), )
if residual is not None:
_linear_add[grid](a,
b,
c,
residual,
M,
N,
K,
a.stride(-2),
a.stride(-1),
b.stride(0),
b.stride(1),
c.stride(-2),
c.stride(-1),
GROUP_SIZE_M=8,
rms_scale_ptr=rms_scale,
linear_scale_ptr=linear_scale)
else:
_linear[grid](a,
b,
c,
M,
N,
K,
a.stride(-2),
a.stride(-1),
b.stride(0),
b.stride(1),
c.stride(-2),
c.stride(-1),
GROUP_SIZE_M=8,
rms_scale_ptr=rms_scale,
linear_scale_ptr=linear_scale)
if bias is not None:
c += bias
return c
@triton.jit
def _per_token_quant_int8(
y_ptr,
y_q_ptr,
y_s_ptr,
y_stride,
N, # number of columns in X
eps, # epsilon to avoid division by zero
BLOCK: tl.constexpr,
):
"""A Triton-accelerated function to perform per-token quantization on a
tensor.
This function converts the tensor values into signed 8-bit integers.
"""
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
y_ptr += row * y_stride
y_q_ptr += row * y_stride
y_s_ptr += row
cols = tl.arange(0, BLOCK) # N <= BLOCK
mask = cols < N
y = tl.load(y_ptr + cols, mask=mask, other=0.).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / 127
y_q = tl.maximum(tl.minimum(tl.math.round(y / y_s), 127), -128).to(tl.int8)
tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
def per_token_quant_int8(x, eps):
"""Function to perform per-token quantization on an input tensor `x`.
It converts the tensor values into signed 8-bit integers and returns the
quantized tensor along with the scaling factor used for quantization.
"""
x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
M = x.numel() // x.shape[-1]
N = x.shape[-1]
x_s = torch.empty(x.shape[:-1] + (1, ),
device=x.device,
dtype=torch.float32)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
# enqueue kernel
_per_token_quant_int8[(M, )](x,
x_q,
x_s,
x.stride(-2),
N,
eps,
BLOCK=BLOCK,
num_warps=num_warps)
return x_q, x_s
@triton.jit
def _rms_norm_fwd_fused_dynamic_symmetric(
X, # pointer to the input
Y, # pointer to the output
W, # pointer to the weights
Scale, # pointer to the scales of the output activation
stride, # how much to increase the pointer when moving by 1 row
N, # number of columns in X
eps, # epsilon to avoid division by zero
BLOCK_SIZE: tl.constexpr,
):
"""A Triton kernel that calculates Root Mean Square (RMS) normalization
with fused dynamic symmetric quantization."""
row = tl.program_id(0)
Y += row * stride
X += row * stride
_var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
cols = tl.arange(0, BLOCK_SIZE)
x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
_var += x * x
var = tl.sum(_var, axis=0) / N
rstd = 1 / tl.sqrt(var + eps)
cols = tl.arange(0, BLOCK_SIZE)
mask = cols < N
w = tl.load(W + cols, mask=mask)
x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
x_hat = x * rstd
y = x_hat * w
scale = tl.max(tl.abs(y)).to(tl.float32) / 127
tl.store(Scale + row, scale)
y = tl.math.round(y / scale)
y = tl.minimum(y, 127)
y = tl.maximum(y, -128)
tl.store(Y + cols, y, mask=mask)
def rms_norm_dynamic_quant(x, w, eps):
"""Performs RMS normalization with dynamic quantization.
The function reshapes the input tensor `x`, creates an empty tensor `y`
with the same shape as `x`, and calculates RMS normalization on the
reshaped `x` using a Triton kernel `_rms_norm_fwd_fused_dynamic_symmetric`.
"""
x_arg = x.reshape(-1, x.shape[-1])
y = torch.empty_like(x, dtype=torch.int8)
M, K = x_arg.shape
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(K))
if K > BLOCK_SIZE:
raise RuntimeError(
"This rms norm doesn't support feature dim >= 64KB.")
num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
scale = torch.empty(x.shape[:-1] + (1, ),
dtype=torch.float32,
device=x.device)
_rms_norm_fwd_fused_dynamic_symmetric[(M, )](
x_arg,
y,
w,
scale,
x_arg.stride(0),
K,
eps,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=num_warps,
)
return y, scale
def test_rms_and_linear(x,
rms_weight,
linear_weight,
dtype=torch.float16,
eps=1e-5):
"""Test quantized rms norm and quantized linear layer."""
def rms_norm_torch(x, w, eps):
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + eps)
return w * x
def linear_torch(x, b):
return F.linear(x, b)
linear_weight_quant, linear_scale = per_channel_quant(
linear_weight, 8, torch.int8)
rms_out, rms_scale = rms_norm_dynamic_quant(x, rms_weight, eps)
assert rms_out.shape == x.shape and rms_scale.shape[:-1] == x.shape[:-1]
linear_out = matmul_kernel_dynamic_quant(rms_out,
linear_weight_quant,
rms_scale,
linear_scale,
output_dtype=dtype)
rms_out_torch = rms_norm_torch(x, rms_weight, eps).half()
linear_out_torch = linear_torch(rms_out_torch, linear_weight)
print(f'linear_out.abs().mean() = {linear_out.abs().mean()}')
print(f'linear_out_torch.abs().mean() = {linear_out_torch.abs().mean()}')
print('perchannel error: ', (linear_out - linear_out_torch).abs().mean())
cos = torch.nn.CosineSimilarity(0)
print(
'Output cos',
cos(linear_out.flatten().to(torch.float32),
linear_out_torch.flatten().to(torch.float32)))
def test_per_token_quant(x, eps):
"""Test per-token quantization."""
def per_token_quant_int8_torch(x, eps):
_absmax = torch.clamp(x.abs().max(dim=-1, keepdim=True)[0], min=eps)
x_s = _absmax / 127
x_q = torch.clamp((x / x_s).round(), min=-128, max=127)
return x_q, x_s
x_q, x_s = per_token_quant_int8(x, eps)
x_q_torch, x_s_torch = per_token_quant_int8_torch(x, eps)
assert x_q.shape == x_q_torch.shape and x_s.shape == x_s_torch.shape
cos = torch.nn.CosineSimilarity(0)
print(
'x_q cos',
cos(x_q.flatten().to(torch.float32),
x_q_torch.flatten().to(torch.float32)))
print(
'x_s cos',
cos(x_s.flatten().to(torch.float32),
x_s_torch.flatten().to(torch.float32)))
@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['M'],
x_vals=[1, 16, 32, 64, 128, 256] + [512 * i * 2 for i in range(1, 17)],
line_arg='provider',
line_vals=['int8_dynamic_triton_op', 'float_torch'],
line_names=['int8_dynamic_triton_op', 'float_torch'],
        styles=[('blue', '-'), ('green', '-')],
ylabel='GB/s',
plot_name='forward',
args={
'dtype': torch.float16,
}))
def bench_rms_and_linear(M, dtype, provider, eps=1e-5, device='cuda'):
def rms_norm_torch(x, w, eps):
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + eps)
return w * x
def linear_torch(x, b):
return F.linear(x, b)
N = 4096
K = 4096
x_shape = (M, K)
rms_w_shape = (x_shape[-1], )
rms_weight = torch.randn(rms_w_shape,
dtype=dtype,
device='cuda',
requires_grad=True)
x = torch.randn(x_shape, dtype=dtype, device='cuda')
linear_weight = torch.randn((N, K),
dtype=dtype,
device='cuda',
requires_grad=True)
linear_weight_quant, linear_scale = per_channel_quant(
linear_weight, 8, torch.int8)
alpha = max(x.max().abs(), x.min().abs())
rms_scale = alpha / 127
if provider == 'int8_dynamic_triton_op':
rms_out, rms_scale = rms_norm_dynamic_quant(x, rms_weight, eps)
def y_fwd():
matmul_kernel_dynamic_quant(rms_out,
linear_weight_quant,
rms_scale,
linear_scale,
output_dtype=dtype)
elif provider == 'float_torch':
rms_out_torch = rms_norm_torch(x, rms_weight, eps).half()
def y_fwd():
linear_torch(rms_out_torch, linear_weight)
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(y_fwd,
quantiles=quantiles,
rep=500)
return ms, max_ms, min_ms
if __name__ == '__main__':
torch.manual_seed(0)
dtype = torch.float16
# test (bs, seq_len, dim) x (dim, out_dim)
x = torch.randn((2, 2048, 4096), dtype=dtype, device='cuda')
rms_weight = torch.randn((4096, ),
dtype=dtype,
device='cuda',
requires_grad=True)
linear_weight = torch.randn((11008, 4096),
dtype=dtype,
device='cuda',
requires_grad=True)
test_rms_and_linear(x, rms_weight, linear_weight)
# test (M, K) x (K, N)
x = torch.randn((4, 4096), dtype=dtype, device='cuda')
rms_weight = torch.randn((4096, ),
dtype=dtype,
device='cuda',
requires_grad=True)
linear_weight = torch.randn((2048, 4096),
dtype=dtype,
device='cuda',
requires_grad=True)
test_rms_and_linear(x, rms_weight, linear_weight)
# test per-token quant
x = torch.randn((4, 2048, 4096), dtype=dtype, device='cuda')
eps = 1e-7
test_per_token_quant(x, eps)
bench_rms_and_linear.run(print_data=True)
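    # Optionally persist the benchmark table/plots as well; in recent Triton
    # releases `run` also accepts `save_path` (the path below is hypothetical):
    # bench_rms_and_linear.run(print_data=True, save_path='./bench_rms_linear')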
# Copyright (c) OpenMMLab. All rights reserved.
import enum
import time
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Dict, List
import torch
from torch import Tensor
from lmdeploy.messages import EngineGenerationConfig
from lmdeploy.utils import get_logger
from .block import LogicalTokenBlocks
logger = get_logger('lmdeploy')
@dataclass
class SamplingParam:
"""Sampling parameter."""
top_p: float = 1.0
top_k: int = 1
temperature: float = 0.8
repetition_penalty: float = 1.0
ignore_eos: bool = False
random_seed: int = None
stop_words: List[int] = field(default_factory=list)
bad_words: List[int] = field(default_factory=list)
max_new_tokens: int = 512
min_new_tokens: int = 0
def logical_sampling_param(self):
"""create a SamplingParam for logical sampling."""
return SamplingParam(top_p=self.top_p,
top_k=self.top_k,
temperature=self.temperature,
repetition_penalty=self.repetition_penalty,
ignore_eos=self.ignore_eos,
random_seed=self.random_seed,
bad_words=self.bad_words)
@classmethod
    def from_gen_config(cls, gen_config: EngineGenerationConfig):
"""from gen config."""
min_new_tokens = gen_config.min_new_tokens or 0
stop_words = gen_config.stop_words or []
bad_words = gen_config.bad_words or []
if gen_config.ignore_eos:
bad_words += stop_words
top_k = gen_config.top_k
top_p = gen_config.top_p
temperature = gen_config.temperature
repetition_penalty = gen_config.repetition_penalty
max_new_tokens = gen_config.max_new_tokens
if top_k <= 0:
logger.warning('`top_k` has to be a strictly'
f' positive value, but is {top_k}')
top_k = 1
if top_p < 0 or top_p > 1.0:
            logger.warning('`top_p` has to be a float in the range [0, 1],'
                           f' but is {top_p}')
top_p = 1.0
if temperature <= 0:
logger.warning('`temperature` has to be a strictly'
f' positive value, but is {temperature}')
temperature = 1.0
if repetition_penalty <= 0:
logger.warning('`repetition_penalty` has to be a strictly'
f' positive value, but is {repetition_penalty}')
repetition_penalty = 1.0
if max_new_tokens < 0:
            logger.warning('`max_new_tokens` has to be a'
                           f' non-negative value, but is {max_new_tokens}')
max_new_tokens = 512
if min_new_tokens < 0 or min_new_tokens > max_new_tokens:
            logger.warning('`min_new_tokens` has to be '
                           'an int >= 0 and <= `max_new_tokens`,'
                           f' but is {min_new_tokens}')
min_new_tokens = 0
return SamplingParam(top_p=top_p,
top_k=top_k,
temperature=temperature,
repetition_penalty=repetition_penalty,
ignore_eos=gen_config.ignore_eos,
random_seed=gen_config.random_seed,
stop_words=stop_words,
bad_words=bad_words,
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens)
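# Illustrative usage (a sketch, not engine code): `from_gen_config` sanitizes
# out-of-range fields and falls back to safe defaults, e.g. a hypothetical
# config with top_k=0 and temperature=0.0 comes back with top_k=1 and
# temperature=1.0 after the warnings above:
#
#   gen_config = EngineGenerationConfig(top_k=0, temperature=0.0)
#   param = SamplingParam.from_gen_config(gen_config)
#   assert param.top_k == 1 and param.temperature == 1.0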
class MessageStatus(enum.Enum):
"""Status of a sequence."""
WAITING = enum.auto()
RUNNING = enum.auto()
STOPPED = enum.auto()
ENDED = enum.auto()
ABORTED = enum.auto()
_SEQ_COUNT = 0
def _new_msg_id():
"""get a new message id."""
global _SEQ_COUNT
seq_id = _SEQ_COUNT
_SEQ_COUNT += 1
return seq_id
class SchedulerSession:
"""Scheduler session."""
def __init__(self, session_id: int, block_size: int) -> None:
self.session_id = session_id
self.block_size = block_size
self.status: MessageStatus = MessageStatus.RUNNING
self.sequences: Dict[int, SchedulerSequence] = dict()
def add_sequence(self,
token_ids: Tensor,
sampling_param: SamplingParam = None,
adapter_name: str = None,
return_logits: bool = False) -> 'SchedulerSequence':
"""Add a new message."""
if not isinstance(token_ids, Tensor):
token_ids = torch.tensor(token_ids)
if token_ids.dim() == 0:
token_ids = token_ids.unsqueeze(0)
if sampling_param is None:
sampling_param = SamplingParam()
seq = SchedulerSequence(seq_id=_new_msg_id(),
token_ids=token_ids,
session=self,
block_size=self.block_size,
status=MessageStatus.WAITING,
num_new_tokens=0,
sampling_param=sampling_param,
adapter_name=adapter_name,
arrive_time=time.time(),
return_logits=return_logits)
self.sequences[seq.seq_id] = seq
return seq
def fork_sequence(
self,
token_ids: Tensor,
seq: 'SchedulerSequence',
sampling_param: SamplingParam = None) -> 'SchedulerSequence':
"""Fork a new message from exist message."""
if sampling_param is None:
sampling_param = deepcopy(seq.sampling_param)
if not isinstance(token_ids, Tensor):
token_ids = torch.tensor(token_ids)
if token_ids.dim() == 0:
token_ids = token_ids.unsqueeze(0)
assert seq.session == self
new_msg = SchedulerSequence(
seq_id=_new_msg_id(),
token_ids=token_ids,
session=self,
block_size=self.block_size,
history_token_ids=seq.history_token_ids.copy(),
num_new_tokens=0,
sampling_param=sampling_param,
status=seq.status,
logical_blocks=seq.logical_blocks.clone(),
adapter_name=seq.adapter_name,
arrive_time=time.time(),
meta=deepcopy(seq.meta),
return_logits=seq.return_logits,
random_offsets=seq.random_offsets + 1)
self.sequences[new_msg.seq_id] = new_msg
return new_msg
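# Illustrative usage (a sketch): a session owns its sequences; `fork_sequence`
# copies the history and clones the logical blocks of an existing sequence,
# while the fork gets its own seq_id and fresh input token ids.
#
#   session = SchedulerSession(session_id=0, block_size=64)
#   seq = session.add_sequence([1, 2, 3])
#   fork = session.fork_sequence([4], seq)
#   assert fork.seq_id != seq.seq_id
#   assert fork.history_token_ids == seq.history_token_ids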
@dataclass
class SchedulerSequence:
"""Scheduler message."""
seq_id: int
token_ids: Tensor
session: SchedulerSession
block_size: int
history_token_ids: list = field(default_factory=list)
num_new_tokens: int = 0
sampling_param: SamplingParam = field(default_factory=SamplingParam)
status: MessageStatus = MessageStatus.WAITING
logical_blocks: LogicalTokenBlocks = field(
default_factory=LogicalTokenBlocks)
sender_id: int = -1
req_id: int = -1
adapter_name: str = None
arrive_time: float = 0.0
meta: Any = None
return_logits: bool = False
random_offsets: int = 0
@property
def history_len(self) -> int:
"""get history length."""
return len(self.history_token_ids)
@property
def session_id(self) -> int:
"""get session id."""
return self.session.session_id
def num_all_tokens(self) -> int:
"""num all tokens."""
return len(self.token_ids) + self.history_len
def update_token_ids(self, token_ids: Tensor, update_history: bool = True):
"""Update token ids, old token ids will be added to history."""
if update_history:
self.history_token_ids += self.token_ids.tolist()
if not isinstance(token_ids, Tensor):
token_ids = self.token_ids.new_tensor(token_ids)
if token_ids.dim() == 0:
token_ids = token_ids.unsqueeze(0)
self.token_ids = token_ids
self.random_offsets += 1
self.arrive_time = time.time()
def set_step(self, step: int):
"""set step."""
assert step <= self.history_len
history_token_ids = torch.tensor(self.history_token_ids,
dtype=torch.long)
new_history_ids = self.history_token_ids[:step]
new_token_ids = torch.cat([history_token_ids[step:], self.token_ids])
self.history_token_ids = new_history_ids
self.token_ids = new_token_ids
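# Illustrative behaviour of `set_step` (a sketch with made-up token ids): with
# history_token_ids == [10, 11, 12] and token_ids == tensor([13]), calling
# set_step(1) keeps [10] as history and moves the remaining tokens back into
# token_ids -> tensor([11, 12, 13]).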
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from lmdeploy.pytorch.models import QLinear, QRMSNorm
LAYER_TYPE_MAP = {
'InternLMForCausalLM': 'InternLMDecoderLayer',
'InternLM2ForCausalLM': 'InternLM2DecoderLayer',
'QWenLMHeadModel': 'QWenBlock',
'BaiChuanForCausalLM': 'DecoderLayer',
'LlamaForCausalLM': 'LlamaDecoderLayer',
}
NORM_TYPE_MAP = {
'InternLMForCausalLM': 'InternLMRMSNorm',
'InternLM2ForCausalLM': 'InternLM2RMSNorm',
'QWenLMHeadModel': 'RMSNorm',
'BaiChuanForCausalLM': 'RMSNorm',
'LlamaForCausalLM': 'LlamaRMSNorm',
}
def convert_decoder_layer(module, norm_type):
"""Converts a given module's child layers from regular Linear or RMSNorm to
their Quantized versions (QLinear, QRMSNorm).
The conversion is done in place.
"""
for name, child in module.named_children():
if isinstance(child, nn.Linear):
new_child = QLinear.from_float(child, initialization=False)
setattr(module, name, new_child)
elif type(child).__name__ == norm_type:
new_child = QRMSNorm.from_float(child, initialization=False)
setattr(module, name, new_child)
else:
convert_decoder_layer(child, norm_type)
def convert(module, layer_type, norm_type):
"""Recursively traverses through given PyTorch module and identifies child
layers that match the specified layer_type and norm_type for conversion to
their Quantized counterparts.
The conversion is done using the `convert_decoder_layer` function.
"""
for child in module.children():
if type(child).__name__ == layer_type:
convert_decoder_layer(child, norm_type)
else:
convert(child, layer_type, norm_type)
def convert_to_qmodules(model):
"""Convert all Linear and RMSNorm in the decoder layers of the model into
their Quantized versions (QLinear, QRMSNorm)."""
layer_type = LAYER_TYPE_MAP[type(model).__name__]
norm_type = NORM_TYPE_MAP[type(model).__name__]
convert(model, layer_type, norm_type)
return
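# Illustrative usage (a sketch; the checkpoint name is hypothetical):
#
#   from transformers import AutoModelForCausalLM
#   model = AutoModelForCausalLM.from_pretrained('some/llama-checkpoint')
#   convert_to_qmodules(model)  # swaps Linear/RMSNorm in the decoder layers
#                               # for QLinear/QRMSNorm, keyed by the class
#                               # names in LAYER_TYPE_MAP and NORM_TYPE_MAP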