OpenDAS / Lmdeploy · Commits

Commit fe851fbc, authored Mar 24, 2024 by zhouxiang
Commit message: supplementary new files for the 0.2.6 release
Parent: e2d98ddc
Showing 20 changed files, with 7037 additions and 0 deletions (+7037, -0).
lmdeploy/pytorch/engine/engine.py (+1229, -0)
lmdeploy/pytorch/engine/logits_process.py (+298, -0)
lmdeploy/pytorch/engine/model_agent.py (+1175, -0)
lmdeploy/pytorch/engine/request.py (+592, -0)
lmdeploy/pytorch/kernels/__init__.py (+15, -0)
lmdeploy/pytorch/kernels/alibi_pagedattention.py (+530, -0)
lmdeploy/pytorch/kernels/apply_rotary_pos_emb.py (+215, -0)
lmdeploy/pytorch/kernels/fill_kv_cache.py (+153, -0)
lmdeploy/pytorch/kernels/fused_rotary_emb.py (+126, -0)
lmdeploy/pytorch/kernels/mbgmm.py (+310, -0)
lmdeploy/pytorch/kernels/mbgmv.py (+257, -0)
lmdeploy/pytorch/kernels/multinomial_sampling.py (+100, -0)
lmdeploy/pytorch/kernels/pagedattention.py (+529, -0)
lmdeploy/pytorch/kernels/rearange_all_gather.py (+134, -0)
lmdeploy/pytorch/kernels/rerope_attention.py (+352, -0)
lmdeploy/pytorch/kernels/rms_norm.py (+111, -0)
lmdeploy/pytorch/kernels/w8a8_triton_kernels.py (+610, -0)
lmdeploy/pytorch/messages.py (+241, -0)
lmdeploy/pytorch/modeling/__init__.py (+1, -0)
lmdeploy/pytorch/modeling/convert_to_qmodules.py (+59, -0)
lmdeploy/pytorch/engine/engine.py (new file, mode 0 → 100644)
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import os
from dataclasses import dataclass
from typing import Any, Dict, List

import torch

from lmdeploy.messages import (EngineGenerationConfig, PytorchEngineConfig,
                               ResponseType)
from lmdeploy.tokenizer import Tokenizer
from lmdeploy.utils import get_logger, get_model, logging_timer

from ..adapter.adapter import ADAPTER_MANAGER, SchedulerAdapter
from ..check_env import check_adapters, check_env, check_model
from ..config import CacheConfig, SchedulerConfig
from ..messages import MessageStatus, SamplingParam, SchedulerSequence
from ..paging import Scheduler
from .logits_process import FusedLogitsProcessor, SamplingInputs
from .model_agent import AutoModelAgent, ModelInputs
from .request import (Request, RequestManager, RequestSender, RequestType,
                      Response)

logger = get_logger('lmdeploy')

SeqList = List[SchedulerSequence]
AdapterList = List[SchedulerAdapter]


def _div_up(x, n):
    """perform div up."""
    return (x + n - 1) // n


@dataclass
class InferOutput:
    """The output of the model inference."""
    session_id: int
    token_ids: List[int]
    sender_id: int
    req_id: int
    meta: Any = None
    finish: bool = False
    logits: torch.Tensor = None


def _paging_adapters(adapters: dict, model_agent: AutoModelAgent,
                     scheduler: Scheduler):
    adapters = adapters or dict()
    weight_maps = []
    for name, path in adapters.items():
        weight_map = scheduler.add_adapter(path, name)
        weight_map.block_table = torch.tensor(weight_map.block_table)
        weight_maps.append(weight_map)
    model_agent.paging_adapters(weight_maps)


def _tensorlize_block_offsets(block_offsets):
    """tensorlize block_offsets."""
    from torch.nn.utils.rnn import pad_sequence
    block_offsets = [torch.from_numpy(off) for off in block_offsets]
    block_offsets = pad_sequence(block_offsets, batch_first=True)
    return block_offsets


def _get_adapter_ids(seqs: SeqList, adapters: AdapterList):
    """get adapter ids."""
    adapter_names_map = dict(
        (ada.name, idx) for idx, ada in enumerate(adapters))
    adapter_ids = [adapter_names_map[seq.adapter_name] for seq in seqs]
    return adapter_ids


def _check_resp(resp: Response, state: ResponseType, warning_msg: str = None):
    """check if response has state."""
    if isinstance(state, ResponseType):
        state = [state]
    ret = resp.type in state
    if not ret and warning_msg is not None:
        logger.warning(warning_msg)
    return ret


def _check_resp_success(resp: Response, warning_msg: str = None):
    """check if response success."""
    return _check_resp(resp, ResponseType.SUCCESS, warning_msg)


async def async_try_add_session(req_sender: RequestSender, session_id: int):
    """Add new session.

    Args:
        session_id (int): The session id to add.
    """
    resp = await req_sender.async_send(RequestType.ADD_SESSION,
                                       dict(session_id=session_id))
    _check_resp(resp, [ResponseType.SUCCESS, ResponseType.SESSION_REPEAT],
                (f'Can not add session {session_id} '
                 f'with error: {resp.type}'))


async def async_end(req_sender: RequestSender, session_id: int):
    """End the given session."""
    resp = await req_sender.async_send(RequestType.END_SESSION,
                                       dict(session_id=session_id))
    _check_resp_success(resp, (f'Failed to end session: {session_id}. '
                               f'Error: {resp.type}.'))


async def async_cancel(req_sender: RequestSender, session_id: int):
    """Stop current streaming inference."""
    resp = await req_sender.async_send(RequestType.STOP_SESSION,
                                       dict(session_id=session_id))
    _check_resp_success(resp, (f'Failed to cancel session: {session_id}. '
                               f'Error: {resp.type}.'))


def try_add_session(req_sender: RequestSender, session_id: int):
    """Add new session.

    Args:
        session_id (int): The session id to add.
    """
    resp = req_sender.send(RequestType.ADD_SESSION,
                           dict(session_id=session_id))
    _check_resp(resp, [ResponseType.SUCCESS, ResponseType.SESSION_REPEAT],
                (f'Can not add session {session_id} '
                 f'with error: {resp.type}'))


def end(req_sender: RequestSender, session_id: int):
    """End the given session."""
    resp = req_sender.send(RequestType.END_SESSION,
                           dict(session_id=session_id))
    _check_resp_success(resp, (f'Failed to end session: {session_id}. '
                               f'Error: {resp.type}.'))


def cancel(req_sender: RequestSender, session_id: int):
    """Stop current streaming inference."""
    resp = req_sender.send(RequestType.STOP_SESSION,
                           dict(session_id=session_id))
    _check_resp_success(resp, (f'Failed to cancel session: {session_id}. '
                               f'Error: {resp.type}.'))


class Engine:
    """The inference engine of lmdeploy pytorch.

    Args:
        model_path (str): The hugging face model path.
        engine_config (PytorchEngineConfig): The config of the Engine.
        trust_remote_code (bool): Trust remote code.
    """

    def __init__(self,
                 model_path: str,
                 engine_config: PytorchEngineConfig = None,
                 trust_remote_code: bool = True) -> None:
        check_env()
        check_model(model_path, trust_remote_code)
        if engine_config.adapters is not None:
            check_adapters(list(engine_config.adapters.values()))
        if engine_config is None:
            engine_config = PytorchEngineConfig()

        self.engine_config = engine_config
        model_name = engine_config.model_name
        tp = engine_config.tp
        self.tp = tp
        self.model_name = model_name

        scheduler_config = SchedulerConfig(
            max_batches=engine_config.max_batch_size,
            max_session_len=engine_config.session_len,
            eviction_type=engine_config.eviction_type,
            prefill_interval=engine_config.prefill_interval)

        # block_size = 1 to enable unified paging
        adapters = engine_config.adapters
        cache_config = CacheConfig(
            block_size=engine_config.block_size,
            num_cpu_blocks=engine_config.num_cpu_blocks,
            num_gpu_blocks=engine_config.num_gpu_blocks,
            cache_max_entry_count=engine_config.cache_max_entry_count,
            max_prefill_token_num=engine_config.max_prefill_token_num)

        if not os.path.exists(model_path):
            model_path = get_model(model_path, engine_config.download_dir,
                                   engine_config.revision)
        self.model_agent = AutoModelAgent.from_pretrained(
            model_path,
            cache_config=cache_config,
            trust_remote_code=trust_remote_code,
            adapters=adapters,
            tp=tp)

        cache_config = self.model_agent.cache_config
        self.scheduler = Scheduler(scheduler_config, cache_config)

        if adapters:
            _paging_adapters(adapters,
                             model_agent=self.model_agent,
                             scheduler=self.scheduler)

        self.scheduler_config = scheduler_config
        self.cache_config = cache_config
        self.stream = torch.cuda.Stream()

        self.req_manager = self._bind_request_manager()

        # create main thread
        self._start_loop()
        self.req_sender = self.req_manager.build_sender()

        self._create_buffers()
        self.tokenizer = Tokenizer(model_path)

    @classmethod
    def from_pretrained(cls,
                        pretrained_model_name_or_path: str,
                        engine_config: PytorchEngineConfig = None,
                        trust_remote_code: bool = True,
                        **kwargs):
        """lmdeploy python inference engine.

        Args:
            pretrained_model_name_or_path (str):
                It could be one of the following options:
                    - i) A local directory path of a turbomind model which is
                        converted by `lmdeploy convert` command or download
                        from ii) and iii)
                    - ii) The model_id of a lmdeploy-quantized model hosted
                        inside a model repo on huggingface.co, such as
                        "InternLM/internlm-chat-20b-4bit",
                        "lmdeploy/llama2-chat-70b-4bit", etc.
                    - iii) The model_id of a model hosted inside a model repo
                        on huggingface.co, such as "InternLM/internlm-chat-7b",
                        "Qwen/Qwen-7B-Chat", "baichuan-inc/Baichuan2-7B-Chat"
                        and so on.
            engine_config (PytorchEngineConfig): Pytorch engine config.
            trust_remote_code (bool): Trust remote code.
        """
        logger.debug(f'Get unexpected kwargs: {kwargs}')
        return cls(model_path=pretrained_model_name_or_path,
                   engine_config=engine_config,
                   trust_remote_code=trust_remote_code)

    def _create_buffers(self):
        max_batches = self.scheduler_config.max_batches

        # buffers to create inputs
        self._q_start_loc_buf = torch.arange(max_batches)
        self._attention_mask_buf = torch.ones(max_batches, 1, dtype=torch.long)
        self._seq_length_buf = torch.ones(max_batches, dtype=torch.long)

    def _bind_request_manager(self):
        """bind request manager."""
        req_manager = RequestManager(self.engine_config.thread_safe)
        req_manager.bind_func(RequestType.ADD_SESSION, self._on_add_session)
        req_manager.bind_func(RequestType.STOP_SESSION, self._on_stop_session)
        req_manager.bind_func(RequestType.END_SESSION, self._on_end_session)
        req_manager.bind_func(RequestType.ADD_MESSAGE, self._on_add_message)
        return req_manager

    def _start_loop(self):
        """start loop."""
        return self.req_manager.start_loop(self.async_loop)

    def _on_add_session(self, reqs: Request, **kwargs):
        """on add session callback."""
        for req in reqs:
            session_id = req.data['session_id']
            resp_type = ResponseType.SESSION_REPEAT
            if session_id not in self.scheduler.sessions:
                self.scheduler.add_session(session_id)
                resp_type = ResponseType.SUCCESS
            self.req_manager.response(
                Response(type=resp_type,
                         sender_id=req.sender_id,
                         req_id=req.req_id))

    def _on_stop_session(self, reqs: Request, **kwargs):
        """on stop session callback."""
        for req in reqs:
            session_id = req.data['session_id']
            resp_type = ResponseType.SESSION_NOT_EXIST
            if session_id in self.scheduler.sessions:
                self.scheduler.stop_session(session_id)
                resp_type = ResponseType.SUCCESS
            self.req_manager.response(
                Response(type=resp_type,
                         sender_id=req.sender_id,
                         req_id=req.req_id))
        self.scheduler.update()

    def _on_end_session(self, reqs: Request, **kwargs):
        """on end session callback."""
        for req in reqs:
            session_id = req.data['session_id']
            resp_type = ResponseType.SESSION_NOT_EXIST
            if session_id in self.scheduler.sessions:
                self.scheduler.end_session(session_id)
                resp_type = ResponseType.SUCCESS
            self.req_manager.response(
                Response(type=resp_type,
                         sender_id=req.sender_id,
                         req_id=req.req_id))
        self.scheduler.update()

    def _on_add_message(self, reqs: Request, **kwargs):
        """on add message callback."""

        def __update_bad_words(msg):
            """update bad words."""
            sampling_param = msg.sampling_param
            eos_token_id = self.model_config.eos_token_id
            if eos_token_id not in sampling_param.stop_words:
                sampling_param.stop_words.append(eos_token_id)
            if sampling_param.ignore_eos:
                sampling_param.bad_words.append(eos_token_id)

        for req in reqs:
            session_id = req.data['session_id']
            if session_id not in self.scheduler.sessions:
                self.req_manager.response(
                    Response(type=ResponseType.SESSION_NOT_EXIST,
                             sender_id=req.sender_id,
                             req_id=req.req_id))
                continue
            session_id = req.data['session_id']
            sess = self.scheduler.sessions[session_id]
            # TODO: support 1 session n sequence
            if len(sess.sequences) == 0:
                assert len(req.data['token_ids']) > 0, (
                    'Empty input is not allowed.')
                sess.add_sequence(
                    req.data['token_ids'],
                    sampling_param=req.data['sampling_param'],
                    adapter_name=req.data['adapter_name'],
                    return_logits=req.data.get('return_logits', False))
                msg = next(iter(sess.sequences.values()))
                __update_bad_words(msg)
                self.scheduler.add_sequence(msg)
            else:
                msg = next(iter(sess.sequences.values()))
                msg.update_token_ids(req.data['token_ids'])
                msg.num_new_tokens = 0
                msg.sampling_param = req.data['sampling_param']
                msg.return_logits = req.data.get('return_logits', False)
                msg.status = MessageStatus.WAITING
                __update_bad_words(msg)

            msg.sender_id = req.sender_id
            msg.req_id = req.req_id
        self.scheduler.update()

    @property
    def model_config(self):
        """model config."""
        return self.model_agent.model_config

    @property
    def gpu_count(self):
        return self.tp

    @property
    def session_len(self):
        return self.scheduler_config.max_session_len

    def create_instance(self, cuda_stream_id=0):
        """Create a turbomind instance.

        Args:
            cuda_stream_id(int): identity of a cuda stream
        Returns:
            EngineInstance: an instance of turbomind
        """
        return EngineInstance(self)

    async def async_add_session(self, session_id: int):
        """Add new session."""
        return await async_try_add_session(self.req_sender, session_id)

    def add_session(self, session_id: int):
        """Add new session."""
        return try_add_session(self.req_sender, session_id)

    async def async_stop_session(self, session_id: int):
        """Stop the given session."""
        return await async_cancel(self.req_sender, session_id)

    def stop_session(self, session_id: int):
        """Stop the given session."""
        return cancel(self.req_sender, session_id)

    async def async_end_session(self, session_id: int):
        """End the given session."""
        return await async_end(self.req_sender, session_id)

    def end_session(self, session_id: int):
        """End the given session."""
        return end(self.req_sender, session_id)

    @logging_timer('CreateModelInputs', logger)
    @torch.inference_mode()
    def create_model_inputs(self, messages: SeqList, adapters: AdapterList):
        """create model inputs from messages.

        Args:
            messages (SeqList): The input messages.
            adapters (AdapterList): Adapters.
        """

        def __get_history_length():
            """get history length."""
            if self.model_config.sliding_window > 0:
                history_lengths = []
                for msg in messages:
                    num_real_blocks = len(msg.logical_blocks)
                    num_all_blocks = _div_up(msg.num_all_tokens(),
                                             msg.block_size)
                    num_drop_blocks = num_all_blocks - num_real_blocks
                    num_drop_tokens = num_drop_blocks * msg.block_size
                    history_lengths.append(msg.history_len - num_drop_tokens)
                return history_lengths
            else:
                return [msg.history_len for msg in messages]

        history_lengths = __get_history_length()

        token_ids = [msg.token_ids for msg in messages]

        meta = messages[0].meta

        if isinstance(token_ids[0], int):
            token_ids = [token_ids]

        batch_size = len(messages)
        input_ids = torch.cat(token_ids)

        is_decoding = input_ids.size(0) == batch_size
        if not is_decoding:
            seq_length = [tokens.size(0) for tokens in token_ids]
            seq_length = torch.tensor(seq_length, dtype=torch.long)
            max_seq_len = max(seq_length)
            q_start_loc = seq_length.cumsum(0) - seq_length
            mask_range = torch.arange(max_seq_len)[None, :]
            attention_mask = (mask_range < seq_length[:, None]).long()
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids += position_ids.new_tensor(
                history_lengths).unsqueeze(-1)
        else:
            q_start_loc = self._q_start_loc_buf[:batch_size]
            attention_mask = self._attention_mask_buf[:batch_size]
            seq_length = self._seq_length_buf[:batch_size]
            position_ids = q_start_loc.new_tensor(
                history_lengths).unsqueeze(-1)

        # TODO: get block offsets is slow when block_size = 1
        block_offsets = self.scheduler.get_block_tables(messages)
        block_offsets = _tensorlize_block_offsets(block_offsets)

        local_adapter_ids = None
        global_adapter_ids = None
        adapter_offsets = None
        max_rank = 0
        if ADAPTER_MANAGER.num_adapters() > 1:
            local_adapter_ids = _get_adapter_ids(messages, adapters)
            local_adapter_ids = seq_length.new_tensor(local_adapter_ids)
            adapter_offsets = self.scheduler.get_block_tables(adapters)
            adapter_offsets = _tensorlize_block_offsets(adapter_offsets)
            global_adapter_ids = [ada.idx for ada in adapters]
            global_adapter_ids = seq_length.new_tensor(global_adapter_ids)
            ranks = [ada.rank for ada in adapters]
            max_rank = max(ranks)

        # add batch dim [bs=1, seq_len]
        if input_ids.ndim == 1:
            input_ids = input_ids.unsqueeze(0)

        return ModelInputs(input_ids=input_ids,
                           seq_length=seq_length,
                           attention_mask=attention_mask,
                           block_offsets=block_offsets,
                           position_ids=position_ids,
                           q_start_loc=q_start_loc,
                           history_lengths=history_lengths,
                           is_decoding=is_decoding,
                           local_adapter_ids=local_adapter_ids,
                           global_adapter_ids=global_adapter_ids,
                           adapter_offsets=adapter_offsets,
                           max_rank=max_rank,
                           meta=meta)

    def _stopping_criteria(self, msg: SchedulerSequence, next_token_id: int):
        """Check if the message should stop.

        Args:
            msg (SchedulerSequence): The input message.
            next_token_id (int): The next token id from inference result.

        Returns:
            bool: Whether the message should be stopped.
        """

        def _check_stop_word(sampling_param, next_token_id):
            if sampling_param.ignore_eos:
                return False
            return (sampling_param.stop_words is not None
                    and next_token_id in sampling_param.stop_words)

        def _check_request_len(msg):
            return msg.num_new_tokens >= msg.sampling_param.max_new_tokens

        def _check_session_len(msg, max_session_len):
            if max_session_len is None:
                return False
            session_len = msg.num_all_tokens() + 1
            return session_len >= max_session_len

        sampling_param = msg.sampling_param
        if _check_stop_word(sampling_param, next_token_id):
            return True
        if _check_request_len(msg):
            return True
        if _check_session_len(msg, self.scheduler_config.max_session_len):
            return True
        return False

    @logging_timer('SamplingLogits', logger)
    async def async_sampling_logits(self, logits: torch.Tensor,
                                    running: SeqList, inputs: ModelInputs):
        """sampling logits."""

        def _gather_history(seqs: SeqList, device: torch.device):
            """gather history."""
            batch = len(seqs)
            max_len = max(seq.history_len for seq in seqs)
            output = torch.full((batch, max_len),
                                self.model_config.bos_token_id,
                                dtype=torch.int64)
            for idx, seq in enumerate(seqs):
                h_len = seq.history_len
                h_ids = output.new_tensor(seq.history_token_ids)
                output[idx, :h_len] = h_ids
            return output.to(device)

        is_decoding = inputs.is_decoding
        # TODO: support repetition_penalty
        if not is_decoding:
            seq_length = inputs.seq_length
            last_idx = seq_length.cumsum(-1) - 1
            split_logits = logits[last_idx, :]
        else:
            # most step share the same sampling parameters
            split_logits = logits
        split_logits = split_logits.cuda()

        sampling_inputs = SamplingInputs.from_sampling_params(running)
        sampling_inputs = sampling_inputs.to_device(split_logits.device)

        input_ids = None
        if sampling_inputs.repetition_penalty is not None:
            input_ids = _gather_history(running, split_logits.device)

        logits_processor = FusedLogitsProcessor(sampling_inputs)
        with torch.inference_mode(), torch.cuda.stream(self.stream):
            logits = logits_processor(input_ids, split_logits)
            next_token_ids = logits_processor.sampling(logits)
        await asyncio.get_event_loop().run_in_executor(
            None, self.stream.synchronize)
        next_token_ids = next_token_ids.cpu()

        return next_token_ids, split_logits

    @logging_timer('UpdateRunning', logger)
    def update_running(self, running: SeqList, next_token_ids: torch.Tensor,
                       meta: Any):
        """update scheduler."""
        for token, msg in zip(next_token_ids, running):
            msg.meta = meta
            msg.update_token_ids(token)
            msg.num_new_tokens += 1
            if msg.num_new_tokens > msg.sampling_param.max_new_tokens:
                msg.token_ids = torch.empty((0, ), dtype=torch.long)
            if self._stopping_criteria(msg, token):
                msg.status = MessageStatus.STOPPED

    def _can_output_token(self, token: torch.Tensor, msg: SchedulerSequence):
        """check if output is necessary."""
        if isinstance(token, torch.Tensor):
            token = token.item()
        stop_words = msg.sampling_param.stop_words
        if stop_words is not None and token in stop_words:
            return False

        return True

    @logging_timer('ModelForward', logger)
    async def _async_model_forward(self, inputs: ModelInputs,
                                   swap_in_map: Dict, swap_out_map: Dict):
        """model forward."""
        max_prefill_token_num = self.cache_config.max_prefill_token_num
        swap_done = False

        class _LogitsGather:
            """logits gather."""

            def __init__(self, max_seq_len):
                self._max_seq_len = max_seq_len
                self._start = 0
                self._out_logits = None

            def gather(self, output):
                """gather."""
                logits = output['logits']
                out_logits = self._out_logits
                start = self._start
                seq_len = logits.size(-2)
                if out_logits is None:
                    out_logits = logits.new_empty(1,
                                                  self._max_seq_len,
                                                  logits.size(-1),
                                                  device='cpu')
                out_logits[:, start:start + seq_len].copy_(logits,
                                                           non_blocking=True)
                self._start = start + seq_len
                self._out_logits = out_logits

            def get_logits(self):
                """get logits."""
                torch.cuda.synchronize()
                return self._out_logits

        async def __forward(inputs):
            """forward."""
            nonlocal swap_done, swap_in_map, swap_out_map
            if swap_done:
                return await self.model_agent.async_forward(
                    inputs, swap_in_map=dict(), swap_out_map=dict())
            else:
                swap_done = True
                return await self.model_agent.async_forward(
                    inputs, swap_in_map=swap_in_map, swap_out_map=swap_out_map)

        async def __long_context_single_forward(inputs, index):
            """one large sequence."""
            new_input = inputs.slice(index, index + 1)
            max_seq_len = new_input.seq_length[0]
            new_inputs = new_input.split(max_prefill_token_num,
                                         self.cache_config.block_size)

            logits_gather = _LogitsGather(max_seq_len)
            for inp in new_inputs:
                tmp_out = await __forward(inp)
                logits_gather.gather(tmp_out)
            tmp_out['logits'] = logits_gather.get_logits()
            return tmp_out

        async def __long_context_batched_forward(inputs, start, end):
            """batched."""
            new_inputs = inputs.slice(start, end)
            return await __forward(new_inputs)

        async def __long_context_forward(inputs):
            """forward for long context."""
            seq_len = inputs.seq_length
            max_seq_len = inputs.input_ids.size(1)
            batch_size = seq_len.size(0)

            indices = []
            token_count = 0
            idx = 0
            logits_gather = _LogitsGather(max_seq_len)
            while idx < batch_size:
                slen = seq_len[idx]
                if token_count == 0 and slen > max_prefill_token_num:
                    tmp_out = await __long_context_single_forward(inputs, idx)
                    logits_gather.gather(tmp_out)
                    tmp_out.pop('logits', None)
                    idx += 1
                elif token_count + slen > max_prefill_token_num:
                    tmp_out = await __long_context_batched_forward(
                        inputs, indices[0], idx)
                    logits_gather.gather(tmp_out)
                    tmp_out.pop('logits', None)
                    indices = []
                    token_count = 0
                else:
                    indices.append(idx)
                    token_count += slen
                    idx += 1

            if token_count > 0:
                tmp_out = await __long_context_batched_forward(
                    inputs, indices[0], idx)
                logits_gather.gather(tmp_out)
            tmp_out['logits'] = logits_gather.get_logits()
            return tmp_out

        if inputs.input_ids.numel() < max_prefill_token_num:
            return await __forward(inputs)
        else:
            return await __long_context_forward(inputs)

    @logging_timer('AsyncStep', logger)
    async def async_step(self, is_prefill: bool, return_logits: bool = False):
        """one step inference. Used to perform streaming chat.

        Returns:
            Dict[int, InferOutput]: The output of each session.
        """
        # schedule
        schedule_output = self.scheduler.schedule(is_prefill=is_prefill)

        running: SeqList = schedule_output.running
        swap_in_map = schedule_output.swap_in_map
        swap_out_map = schedule_output.swap_out_map
        adapters = schedule_output.adapters
        if len(running) == 0:
            return dict()

        inputs = self.create_model_inputs(running, adapters)

        logger.debug(f'<AsyncStep>: batch_size={len(running)} '
                     f'num_tokens={inputs.input_ids.size(-1)}')

        # inference
        output = await self._async_model_forward(inputs,
                                                 swap_in_map=swap_in_map,
                                                 swap_out_map=swap_out_map)
        custom_outputs = output['custom_outputs']
        logits = output['logits']
        logits = logits[0]  # [bs, seq, prob] -> [seq, prob]

        next_token_ids, _ = await self.async_sampling_logits(
            logits, running, inputs)

        self.update_running(running, next_token_ids, custom_outputs)
        self.scheduler.update()

        # generate output
        outputs: Dict[int, InferOutput] = dict()
        for idx, msg in enumerate(running):
            next_id = next_token_ids[idx]
            session_id = msg.session_id
            if self._can_output_token(next_id, msg):
                out_token_ids = [next_id.item()]
            else:
                out_token_ids = []
            out = InferOutput(
                session_id=session_id,
                sender_id=msg.sender_id,
                req_id=msg.req_id,
                finish=(msg.status == MessageStatus.STOPPED),
                token_ids=out_token_ids,
            )
            outputs[session_id] = out

            if msg.return_logits:
                start = inputs.q_start_loc[idx]
                seqlen = inputs.seq_length[idx]
                outputs[msg.session_id].logits = logits[start:start + seqlen]
        return outputs

    async def async_batched_infer(self,
                                  session_ids: List[int],
                                  token_ids: List[List[int]] = None,
                                  gen_config: EngineGenerationConfig = None,
                                  adapter_names: List[str] = None,
                                  keep_cache: bool = False):
        """Send inference request.

        Args:
            session_ids (List[int]): The session id.
            token_ids (List[int]): The input token ids.
            gen_config (EngineGenerationConfig): The sampling parameters.
            adapter_names (List[str]): The name of the adapters.
            keep_cache (bool): Keep kv cache after infer.

        Returns:
            int: Error flags. 0 if success.
            List[int]: The streaming output tokens.
            int: The number of the output tokens.
        """
        batch_size = len(token_ids)
        assert len(session_ids) == batch_size
        if adapter_names is not None:
            assert len(adapter_names) == batch_size
        else:
            adapter_names = [None for _ in range(batch_size)]

        async def _add_sessions(session_ids):
            for session_id in session_ids:
                await self.async_add_session(session_id)

        async def _add_messages(session_ids, token_ids):
            add_msgs = []
            sampling_param = SamplingParam.from_gen_config(gen_config)
            for session_id, token_id, adapter_name in zip(
                    session_ids, token_ids, adapter_names):
                msg = dict(token_ids=token_id,
                           session_id=session_id,
                           sampling_param=sampling_param,
                           adapter_name=adapter_name)
                add_msgs.append(msg)
            req_types = [RequestType.ADD_MESSAGE] * batch_size
            req_ids = await self.req_sender.async_batched_send_async(
                req_types, data=add_msgs)
            return req_ids

        await _add_sessions(session_ids)
        req_ids = await _add_messages(session_ids, token_ids)

        # receive messages
        req_idx_map = dict(zip(req_ids, range(len(req_ids))))
        output_token_ids = [list() for _ in req_ids]
        status = 0
        finish_count = batch_size
        while finish_count:
            if not self.req_manager.is_loop_alive():
                logger.error('Engine loop is not alive.')
                status = 1
                break

            resp = await self.req_sender.async_recv_any()
            if resp.req_id not in req_ids:
                continue
            idx = req_idx_map[resp.req_id]
            token_ids = output_token_ids[idx]
            if resp.type == ResponseType.SUCCESS:
                token_ids += resp.data['token_ids']
            elif resp.type == ResponseType.FINISH:
                token_ids += resp.data['token_ids']
                if not keep_cache:
                    session_id = session_ids[idx]
                    await self.async_end_session(session_id=session_id)
                finish_count -= 1
            else:
                logger.error(f'Unexpected response: {resp.type}')
                status = 1
                break

        output_token_len = [len(token_ids) for token_ids in output_token_ids]
        return (status, output_token_ids, output_token_len)

    def batched_infer(self,
                      session_ids: List[int],
                      token_ids: List[List[int]] = None,
                      gen_config: EngineGenerationConfig = None,
                      adapter_names: List[str] = None,
                      keep_cache: bool = False):
        """batched infer."""
        coro = self.async_batched_infer(session_ids,
                                        token_ids,
                                        gen_config=gen_config,
                                        adapter_names=adapter_names,
                                        keep_cache=keep_cache)
        return self.req_sender.run_until_complete(coro)

    def decode(self,
               input_ids,
               steps: List[int] = None,
               sequence_start: bool = True,
               sequence_end: bool = True,
               adapter_names: List[str] = None):
        """Perform context decode on input tokens.

        Args:
            input_ids (numpy.ndarray): the batch of input token ids
            steps (List[int]): the offset of the k/v cache
            sequence_start (bool): indicator for starting a sequence
            sequence_end (bool): indicator for ending a sequence
            adapter_names (List[str]): The name of the adapters.
        """
        from torch.nn.utils.rnn import pad_sequence
        logger.debug('Decoding logits.')
        batch_size = len(input_ids)

        def __add_messages(session_ids, input_ids, adapter_names):
            add_msgs = []
            sampling_param = SamplingParam(max_new_tokens=0)
            for session_id, token_id, adapter_name in zip(
                    session_ids, input_ids, adapter_names):
                msg = dict(token_ids=token_id,
                           session_id=session_id,
                           sampling_param=sampling_param,
                           adapter_name=adapter_name,
                           return_logits=True)
                add_msgs.append(msg)
            req_types = [RequestType.ADD_MESSAGE] * batch_size
            req_ids = self.req_sender.batched_send_async(req_types,
                                                         data=add_msgs)
            return req_ids

        if steps is not None:
            assert batch_size == len(steps)

        if adapter_names is None:
            adapter_names = [None] * batch_size
        assert batch_size == len(adapter_names)

        session_ids = tuple(range(batch_size))
        if sequence_start:
            for sid in session_ids:
                self.req_sender.send(RequestType.END_SESSION,
                                     dict(session_id=sid))
                self.add_session(sid)

        req_ids = __add_messages(session_ids, input_ids, adapter_names)
        req_idx_map = dict(zip(req_ids, range(len(req_ids))))

        finish_count = batch_size
        ret = [None] * batch_size
        while finish_count > 0:
            resp = self.req_sender.recv_any()
            if resp.req_id not in req_ids:
                continue

            assert resp.type == ResponseType.FINISH
            idx = req_idx_map[resp.req_id]
            ret[idx] = resp.data['logits']
            finish_count -= 1

        ret = pad_sequence(ret, True)

        if sequence_end:
            for sid in session_ids:
                self.end_session(sid)

        return ret

    async def async_loop(self):
        """Main loop of the engine.

        Each engine instance would communicate with the engine by queue.
        """

        def _send_resp(step_tokens):
            """send response callback."""
            for _, out in step_tokens.items():
                if out.finish:
                    resp_type = ResponseType.FINISH
                else:
                    resp_type = ResponseType.SUCCESS
                self.req_manager.response(
                    Response(
                        type=resp_type,
                        sender_id=out.sender_id,
                        req_id=out.req_id,
                        data=dict(token_ids=out.token_ids, logits=out.logits),
                    ))

        prefill_interval = self.scheduler_config.prefill_interval
        prefill_counter = prefill_interval

        while True:
            if not self.req_manager.has_requests(
            ) and not self.scheduler.has_unfinished():
                await asyncio.sleep(0.01)
                continue

            self.req_manager.step()

            # forward
            if self.scheduler.has_unfinished():
                has_running = self.scheduler.has_running()
                is_prefill = not prefill_counter or not has_running
                if is_prefill:
                    prefill_counter = prefill_interval
                with torch.inference_mode():
                    step_tokens: Dict[int, InferOutput] = await self.async_step(
                        is_prefill=is_prefill)
                prefill_counter -= 1

                # send response
                _send_resp(step_tokens)


class EngineInstance:
    """Instance of TurboMind.

    Args:
        engine (Engine): engine
    """

    def __init__(self, engine: Engine):
        self.engine = engine
        self.req_sender = engine.req_manager.build_sender()

    def __del__(self):
        """Destructor."""
        self.engine.req_manager.senders.pop(self.req_sender.sender_id)

    async def _async_try_add_session(self, session_id: int):
        """Add new session.

        Args:
            session_id (int): The session id to add.
        """
        return await async_try_add_session(self.req_sender, session_id)

    def _try_add_session(self, session_id: int):
        """Add new session.

        Args:
            session_id (int): The session id to add.
        """
        return try_add_session(self.req_sender, session_id)

    async def async_stream_infer(self,
                                 session_id: int,
                                 input_ids: List[int],
                                 gen_config: EngineGenerationConfig = None,
                                 adapter_name: str = None,
                                 **kwargs):
        """Send stream inference request.

        Args:
            session_id (int): The session id.
            input_ids (List[int]): The input token ids.
            gen_config (EngineGenerationConfig): The sampling parameters.
            adapter_name (str): The lora adapter name.

        Yields:
            int: Error flags. 0 if success.
            List[int]: The streaming output tokens.
            int: The number of the output tokens.
        """
        gen_config = gen_config or EngineGenerationConfig()
        sampling_param = SamplingParam.from_gen_config(gen_config=gen_config)
        await async_try_add_session(self.req_sender, session_id)
        msg = dict(
            token_ids=input_ids,
            session_id=session_id,
            sampling_param=sampling_param,
            adapter_name=adapter_name,
        )
        req_id = await self.req_sender.async_send_async(
            RequestType.ADD_MESSAGE, msg)

        token_ids = []
        while True:
            if not self.req_sender.is_loop_alive():
                yield (ResponseType.ENGINE_STOP_ERROR, [], 0)
                break

            resp = await self.req_sender.async_recv(req_id)
            if resp.req_id != req_id:
                continue
            if resp.type == ResponseType.SUCCESS:
                token_ids += resp.data['token_ids']
                yield (resp.type, token_ids, len(token_ids))
            elif resp.type == ResponseType.FINISH:
                token_ids += resp.data['token_ids']
                yield (resp.type, token_ids, len(token_ids))
                break
            else:
                yield (resp.type, [], 0)
                break

    async def async_infer(self,
                          session_id: int,
                          input_ids: List[int] = None,
                          gen_config: EngineGenerationConfig = None,
                          **kwargs):
        """Send inference request.

        Args:
            session_id (int): The session id.
            input_ids (List[int]): The input token ids.
            gen_config (EngineGenerationConfig): The sampling parameters.

        Returns:
            int: Error flags. 0 if success.
            List[int]: The streaming output tokens.
            int: The number of the output tokens.
        """
        token_ids = []
        async for outputs in self.async_stream_infer(session_id,
                                                     input_ids,
                                                     gen_config=gen_config,
                                                     **kwargs):
            status, tmp_ids, _ = outputs
            if status not in [ResponseType.SUCCESS, ResponseType.FINISH]:
                return (status, token_ids, len(token_ids))
            token_ids = tmp_ids

        return (0, token_ids, len(token_ids))

    def stream_infer(self,
                     session_id: int,
                     input_ids: List[int],
                     gen_config: EngineGenerationConfig = None,
                     adapter_name: str = None,
                     **kwargs):
        """Send stream inference request.

        Args:
            session_id (int): The session id.
            input_ids (List[int]): The input token ids.
            gen_config (EngineGenerationConfig): The sampling parameters.
            adapter_name (str): The lora adapter name.

        Yields:
            int: Error flags. 0 if success.
            List[int]: The streaming output tokens.
            int: The number of the output tokens.
        """

        def __call_async():
            """call async."""
            coro_gen = self.async_stream_infer(session_id, input_ids,
                                               gen_config, adapter_name,
                                               **kwargs)
            while True:
                try:
                    yield self.req_sender.run_until_complete(
                        coro_gen.__anext__())
                except StopAsyncIteration:
                    break

        if not self.req_sender.is_thread_safe():
            yield from __call_async()
            return

        gen_config = gen_config or EngineGenerationConfig()
        sampling_param = SamplingParam.from_gen_config(gen_config=gen_config)
        try_add_session(self.req_sender, session_id)
        msg = dict(
            token_ids=input_ids,
            session_id=session_id,
            sampling_param=sampling_param,
            adapter_name=adapter_name,
        )
        req_id = self.req_sender.send_async(RequestType.ADD_MESSAGE, msg)

        token_ids = []
        while True:
            if not self.req_sender.is_loop_alive():
                yield (ResponseType.ENGINE_STOP_ERROR, [], 0)
                break

            resp = self.req_sender.recv(req_id)
            if resp.req_id != req_id:
                continue
            if resp.type == ResponseType.SUCCESS:
                token_ids += resp.data['token_ids']
                yield (resp.type, token_ids, len(token_ids))
            elif resp.type == ResponseType.FINISH:
                token_ids += resp.data['token_ids']
                yield (resp.type, token_ids, len(token_ids))
                break
            else:
                yield (resp.type, [], 0)
                break

    def infer(self,
              session_id: int,
              input_ids: List[int] = None,
              gen_config: EngineGenerationConfig = None,
              **kwargs):
        """Send inference request.

        Args:
            session_id (int): The session id.
            input_ids (List[int]): The input token ids.
            gen_config (EngineGenerationConfig): The sampling parameters.

        Returns:
            int: Error flags. 0 if success.
            List[int]: The streaming output tokens.
            int: The number of the output tokens.
        """
        token_ids = []
        for outputs in self.stream_infer(session_id,
                                         input_ids,
                                         gen_config=gen_config,
                                         **kwargs):
            status, tmp_ids, _ = outputs
            if status not in [ResponseType.SUCCESS, ResponseType.FINISH]:
                return (status, token_ids, len(token_ids))
            token_ids = tmp_ids

        return (0, token_ids, len(token_ids))

    async def async_end(self, session_id: int):
        """End the given session."""
        return await async_end(self.req_sender, session_id)

    def end(self, session_id: int):
        """End the given session."""
        return end(self.req_sender, session_id)

    async def async_cancel(self, session_id: int):
        """Stop current streaming inference."""
        return await async_cancel(self.req_sender, session_id)

    def cancel(self, session_id: int):
        """Stop current streaming inference."""
        return cancel(self.req_sender, session_id)

    def decode(self,
               input_ids,
               steps: List[int] = None,
               sequence_start: bool = True,
               sequence_end: bool = True,
               adapter_names: List[str] = None):
        """Perform context decode on input tokens.

        Args:
            input_ids (numpy.ndarray): the batch of input token ids
            steps (List[int]): the offset of the k/v cache
            sequence_start (bool): indicator for starting a sequence
            sequence_end (bool): indicator for ending a sequence
            adapter_names (List[str]): The name of the adapters.
        """
        return self.engine.decode(input_ids,
                                  steps=steps,
                                  sequence_start=sequence_start,
                                  sequence_end=sequence_end,
                                  adapter_names=adapter_names)
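As a usage note: `EngineInstance.stream_infer` above yields `(status, token_ids, num_tokens)` tuples until a `FINISH` response arrives, and the engine owns a `Tokenizer` built from the model path. A minimal sketch of driving it follows, assuming a placeholder model id and prompt, default `PytorchEngineConfig` values that fit the local GPU, and the `encode`/`decode` interface of `lmdeploy.tokenizer.Tokenizer`.

# Minimal usage sketch (assumed placeholders: model id, prompt, and a GPU
# large enough for the default PytorchEngineConfig).
from lmdeploy.messages import EngineGenerationConfig, PytorchEngineConfig
from lmdeploy.pytorch.engine.engine import Engine

engine = Engine('internlm/internlm-chat-7b',
                engine_config=PytorchEngineConfig(max_batch_size=1))
generator = engine.create_instance()

prompt_ids = engine.tokenizer.encode('Hello, world!')
gen_config = EngineGenerationConfig(max_new_tokens=32, top_k=1)

# stream_infer yields (status, token_ids, num_tokens) until FINISH.
for status, token_ids, _ in generator.stream_infer(session_id=0,
                                                   input_ids=prompt_ids,
                                                   gen_config=gen_config):
    print(engine.tokenizer.decode(token_ids))

generator.end(session_id=0)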
lmdeploy/pytorch/engine/logits_process.py (new file, mode 0 → 100644)
# Copyright (c) OpenMMLab. All rights reserved.
from dataclasses import asdict, dataclass
from typing import Dict, List

import torch
from transformers.generation.logits_process import LogitsWarper

from ..messages import SchedulerSequence


def _process_temperature(scores: torch.Tensor,
                         temperature: torch.Tensor,
                         inplace: bool = True):
    """process temperature."""
    temperature = temperature.to(scores.dtype)
    if not inplace:
        scores = scores / temperature[:, None]
    else:
        scores /= temperature[:, None]
    return scores


def _process_bad_words(scores: torch.Tensor,
                       bad_words: torch.LongTensor,
                       filter_value: float = -float('inf'),
                       inplace: bool = True):
    """process bad words."""
    batch_size = scores.size(0)
    batch_idx = torch.arange(batch_size, device=scores.device)
    filtered_scores = scores[batch_idx[:, None], bad_words]
    filtered_scores[bad_words >= 0] = filter_value
    if not inplace:
        scores = scores.clone()
    scores[batch_idx[:, None], bad_words] = filtered_scores
    return scores


def _process_repetition_penalty(scores: torch.Tensor,
                                input_ids: torch.LongTensor,
                                penalty: torch.Tensor,
                                inplace: bool = True):
    """process repetition penalty."""
    score = torch.gather(scores, 1, input_ids)
    penalty = penalty.to(score.dtype)
    score = torch.where(score < 0, score * penalty[:, None],
                        score / penalty[:, None])
    if not inplace:
        scores = scores.clone()
    scores.scatter_(1, input_ids, score)
    return scores


def _filter_topk_sorted(scores: torch.Tensor,
                        topk: torch.LongTensor,
                        filter_value: float = -float('inf'),
                        inplace: bool = True):
    """filter topk on sorted scores."""
    filter_value = -float('inf')
    num_tokens = scores.size(1)
    token_idx = torch.arange(num_tokens, device=scores.device)
    mask = token_idx[None, :] >= topk[:, None]
    if inplace:
        scores.masked_fill_(mask, filter_value)
    else:
        scores = scores.masked_fill(mask, filter_value)
    return scores


def _filter_topp_sorted(scores: torch.Tensor,
                        topp: torch.Tensor,
                        filter_value: float = -float('inf'),
                        inplace: bool = True):
    """filter topp on sorted scores."""
    softmax_scores = scores.softmax(-1)
    cum_scores = softmax_scores.cumsum(1) - softmax_scores
    mask = cum_scores > topp[:, None]
    mask[:, 0] = False  # keep at least one
    if inplace:
        scores.masked_fill_(mask, filter_value)
    else:
        scores = scores.masked_fill(mask, filter_value)
    return scores


def _multinomial_sampling(scores: torch.Tensor,
                          seeds: torch.LongTensor,
                          offsets: torch.LongTensor,
                          indices: torch.LongTensor = None):
    """sampling."""
    from lmdeploy.pytorch.kernels import multinomial_sampling
    return multinomial_sampling(scores, seeds, offsets, indices)


@dataclass
class SamplingInputs:
    temperature: torch.Tensor = None
    bad_words: torch.LongTensor = None
    repetition_penalty: torch.Tensor = None
    top_k: torch.LongTensor = None
    top_p: torch.Tensor = None
    random_seeds: int = None
    random_offsets: int = None
    max_top_k: int = 1
    min_top_p: float = 1.0

    @classmethod
    def from_sampling_params(cls, seqs: List[SchedulerSequence]):
        """from sampling params."""
        batch_size = len(seqs)
        temperature = [None] * batch_size
        repetition_penalty = [None] * batch_size
        top_k = [None] * batch_size
        top_p = [None] * batch_size
        bad_words = [None] * batch_size
        random_seeds = [torch.seed() & 0xffffffff] * batch_size
        random_offsets = [None] * batch_size

        def __gather_params():
            """gather params."""
            for idx, seq in enumerate(seqs):
                param = seq.sampling_param
                temperature[idx] = param.temperature
                repetition_penalty[idx] = param.repetition_penalty
                top_k[idx] = param.top_k
                top_p[idx] = param.top_p
                random_offsets[idx] = seq.random_offsets
                if param.random_seed is not None:
                    random_seeds[idx] = param.random_seed & 0xffffffff

                bw = param.bad_words
                if (not param.ignore_eos
                        and seq.num_new_tokens < param.min_new_tokens):
                    bw = bw + param.stop_words
                bad_words[idx] = bw

        def __get_topp(top_p):
            """get topp."""
            min_top_p = min(top_p)
            if min_top_p == 1.0:
                top_p = None
            else:
                top_p = torch.tensor(top_p)
            return top_p, min_top_p

        def __get_bad_words(bad_words, max_bw_len):
            """get bad words."""
            ret = torch.full((batch_size, max_bw_len), -1, dtype=torch.int64)
            for idx, bw in enumerate(bad_words):
                bw_len = len(bw)
                if bw_len == 0:
                    continue
                bw = ret.new_tensor(bw)
                ret[idx, :bw_len] = bw
            return ret

        __gather_params()

        if all(rp == 1.0 for rp in repetition_penalty):
            repetition_penalty = None
        else:
            repetition_penalty = torch.tensor(repetition_penalty)

        temperature = torch.tensor(temperature)

        max_bw_len = max(len(bw) for bw in bad_words)
        if max_bw_len == 0:
            bad_words = None
        else:
            if all(len(bw) == max_bw_len for bw in bad_words):
                bad_words = torch.tensor(bad_words)
            else:
                bad_words = __get_bad_words(bad_words, max_bw_len)

        max_top_k = max(top_k)
        if max_top_k == 1:
            top_k = None
            top_p, min_top_p = None, 1.0
            random_seeds = None
            random_offsets = None
        else:
            top_k = torch.tensor(top_k)
            top_p, min_top_p = __get_topp(top_p)
            random_seeds = torch.tensor(random_seeds)
            random_offsets = torch.tensor(random_offsets)

        sampling_input = cls(
            temperature=temperature,
            bad_words=bad_words,
            repetition_penalty=repetition_penalty,
            top_k=top_k,
            top_p=top_p,
            random_seeds=random_seeds,
            random_offsets=random_offsets,
            max_top_k=max_top_k,
            min_top_p=min_top_p,
        )
        return sampling_input

    def to_device(self, device: str):
        """to device."""
        input_dict = asdict(self)
        out_dict = dict()
        for k, v in input_dict.items():
            if isinstance(v, torch.Tensor):
                v = v.to(device)
            out_dict[k] = v

        return SamplingInputs(**out_dict)


class SeedManager:
    """random seed manager."""

    def __init__(self):
        self._generators: Dict[int, torch.Generator] = dict()

    def new_generator(self, seed: int, device: str = 'cuda'):
        """new generator."""
        return torch.Generator(device=device).manual_seed(seed)

    def get(self, seed: int, device: str = 'cuda'):
        """get generator."""
        if seed not in self._generators:
            generator = self.new_generator(seed, device)
            self._generators[seed] = generator
        return self._generators[seed]


SEED_MANAGER = SeedManager()


class FusedLogitsProcessor(LogitsWarper):
    """Custom logits processor."""

    def __init__(self, sampling_inputs: SamplingInputs):
        self.sampling_inputs: SamplingInputs = sampling_inputs

    def __call__(self, input_ids: torch.LongTensor,
                 scores: torch.FloatTensor) -> torch.FloatTensor:
        r"""
        Args:
            input_ids (torch.LongTensor):
                Indices of input sequence tokens in the vocabulary.
            scores (torch.FloatTensor):
                Prediction scores of a language modeling head.
                These can be logits for each vocabulary when not using
                beam search or log softmax for each vocabulary token
                when using beam search

        Return:
            torch.FloatTensor: The processed prediction scores.
        """
        sampling_inputs = self.sampling_inputs
        scores = scores.clone()

        repetition_penalty = sampling_inputs.repetition_penalty
        if repetition_penalty is not None:
            scores = _process_repetition_penalty(scores, input_ids,
                                                 repetition_penalty)

        temperature = sampling_inputs.temperature
        if temperature is not None:
            scores = _process_temperature(scores, temperature)

        bad_words = sampling_inputs.bad_words
        if bad_words is not None:
            scores = _process_bad_words(scores, bad_words)

        return scores

    def sampling(self, logits: torch.Tensor):
        """sampling."""
        sampling_inputs = self.sampling_inputs

        def __random_sampling(scores: torch.Tensor,
                              indices: torch.LongTensor):
            """random sampling."""
            top_k = sampling_inputs.top_k
            if top_k is not None:
                scores = _filter_topk_sorted(scores, top_k)

            top_p = sampling_inputs.top_p
            if top_p is not None:
                scores = _filter_topp_sorted(scores, top_p)

            softmax_scores = scores.softmax(1)

            seeds = sampling_inputs.random_seeds
            offsets = sampling_inputs.random_offsets
            return _multinomial_sampling(softmax_scores, seeds, offsets,
                                         indices)

        if sampling_inputs.max_top_k == 1:
            return logits.argmax(-1)
        else:
            scores, indices = logits.sort(1, descending=True)
            return __random_sampling(scores, indices)
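To illustrate how the processor is driven (the engine does this inside `async_sampling_logits` in engine.py above), here is a small hypothetical sketch. It builds `SamplingInputs` by hand rather than via `from_sampling_params`, sticks to the greedy path (`max_top_k == 1`) so the Triton `multinomial_sampling` kernel is not touched, and runs on CPU tensors purely for illustration.

import torch

from lmdeploy.pytorch.engine.logits_process import (FusedLogitsProcessor,
                                                    SamplingInputs)

# Hypothetical hand-built inputs for a batch of 2 sequences: per-row
# temperature scaling plus one banned token id per row (-1 is the
# "no bad word" padding value used by __get_bad_words above).
sampling_inputs = SamplingInputs(
    temperature=torch.tensor([0.7, 1.0]),
    bad_words=torch.tensor([[3], [-1]]),
    max_top_k=1,  # greedy: FusedLogitsProcessor.sampling() will argmax
)

logits = torch.randn(2, 8)  # [batch, vocab]
processor = FusedLogitsProcessor(sampling_inputs)

# __call__ applies repetition penalty / temperature / bad words in order;
# input_ids is only consulted when repetition_penalty is set, so pass None.
scores = processor(None, logits)
next_tokens = processor.sampling(scores)
print(next_tokens)  # one token id per sequence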
lmdeploy/pytorch/engine/model_agent.py (new file, mode 0 → 100644)
# Copyright (c) OpenMMLab. All rights reserved.
import
asyncio
import
os
from
dataclasses
import
asdict
,
dataclass
,
field
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Union
import
torch
import
torch.distributed
as
dist
from
torch
import
multiprocessing
as
mp
from
torch.distributed._tensor
import
DeviceMesh
,
Replicate
,
distribute_tensor
from
transformers
import
AutoModelForCausalLM
from
lmdeploy.pytorch.accel
import
LoadNoInit
from
lmdeploy.utils
import
get_logger
from
..adapter.adapter
import
(
AdapterWeightMap
,
get_indexed_lora_linears
,
get_max_lora_weight_size
,
update_lora_linears
)
from
..config
import
CacheConfig
,
ModelConfig
from
..models
import
patch
from
..utils
import
get_gpu_memory
from
.cache_engine
import
CacheEngine
logger
=
get_logger
(
'lmdeploy'
)
_PATCH_ARG_NAMES
=
[
'context'
,
'use_origin'
]
def
_infer_block_size
(
model
:
torch
.
nn
.
Module
,
model_config
:
ModelConfig
,
cache_config
:
CacheConfig
,
world_size
:
int
=
1
):
"""infer block size."""
max_weight_dim
=
get_max_lora_weight_size
(
model
)
if
max_weight_dim
==
0
:
return
cache_config
.
block_size
per_token_size
=
model_config
.
get_head_size
(
)
*
model_config
.
num_key_value_heads
//
world_size
block_size
=
1
while
block_size
*
per_token_size
<
max_weight_dim
:
block_size
*=
2
return
block_size
*
world_size
def
_update_cache_config
(
model_config
:
ModelConfig
,
cache_config
:
CacheConfig
,
gpu_id
:
int
=
0
,
host_mem_size
:
int
=
4
*
(
1
<<
30
),
world_size
:
int
=
1
):
"""Update the gpu mem and cpu mem according to model info.
Args:
model_config (ModelConfig): The config of the model.
cache_config (CacheConfig): The config of the cache info.
gpu_id (int): The GPU id to use.
"""
def
__get_free_gpu_mem_size
():
"""get free gpu memory size."""
torch
.
cuda
.
empty_cache
()
gpu_mem_physical_free
,
_
=
get_gpu_memory
(
gpu_id
)
logger
.
debug
(
f
'device<
{
gpu_id
}
> free gpu memory:'
f
'
{
gpu_mem_physical_free
>>
20
}
mb'
)
vocal_size
=
model_config
.
vocab_size
max_prefill_token_num
=
cache_config
.
max_prefill_token_num
# lm_head output(2) + to float(4) + estimated misc(1) = 7
intermediate_cache_size
=
int
(
max_prefill_token_num
*
vocal_size
*
7
)
logger
.
debug
(
'estimated max runtime memory:'
f
'
{
intermediate_cache_size
>>
20
}
mb'
)
gpu_mem_physical_free
-=
intermediate_cache_size
return
gpu_mem_physical_free
*
cache_config
.
cache_max_entry_count
gpu_mem
=
__get_free_gpu_mem_size
()
cpu_mem
=
host_mem_size
cache_block_size
=
CacheEngine
.
get_cache_block_size
(
cache_config
.
block_size
,
model_config
,
world_size
)
if
cache_config
.
num_cpu_blocks
==
0
:
cache_config
.
num_cpu_blocks
=
int
(
cpu_mem
/
cache_block_size
)
if
cache_config
.
num_gpu_blocks
==
0
:
cache_config
.
num_gpu_blocks
=
int
(
gpu_mem
/
cache_block_size
)
cache_config
.
window_size
=
model_config
.
sliding_window
logger
.
debug
(
'block num: {}'
.
format
(
cache_config
.
num_gpu_blocks
))
@
dataclass
class
ModelInputs
:
"""Input of the model."""
input_ids
:
torch
.
LongTensor
seq_length
:
torch
.
LongTensor
attention_mask
:
torch
.
Tensor
block_offsets
:
torch
.
LongTensor
position_ids
:
torch
.
LongTensor
q_start_loc
:
torch
.
LongTensor
history_lengths
:
List
[
int
]
is_decoding
:
bool
local_adapter_ids
:
torch
.
LongTensor
=
None
global_adapter_ids
:
torch
.
LongTensor
=
None
adapter_offsets
:
torch
.
LongTensor
=
None
max_rank
:
int
=
0
meta
:
Any
=
None
def
slice
(
self
,
start
:
int
,
end
:
int
):
"""select by indices."""
sli
=
slice
(
start
,
end
)
start_loc
=
self
.
q_start_loc
[
sli
]
seq_length
=
self
.
seq_length
[
sli
]
end_loc
=
start_loc
[
-
1
]
+
seq_length
[
-
1
]
input_ids
=
self
.
input_ids
[:,
start_loc
[
0
]:
end_loc
]
start_loc
=
start_loc
-
start_loc
[
0
]
history_lengths
=
self
.
history_lengths
[
sli
]
local_adapter_ids
=
self
.
local_adapter_ids
if
local_adapter_ids
is
not
None
:
local_adapter_ids
=
local_adapter_ids
[
sli
]
return
ModelInputs
(
input_ids
=
input_ids
,
seq_length
=
seq_length
,
attention_mask
=
self
.
attention_mask
[
sli
],
block_offsets
=
self
.
block_offsets
[
sli
],
position_ids
=
self
.
position_ids
[
sli
],
q_start_loc
=
start_loc
,
history_lengths
=
history_lengths
,
is_decoding
=
self
.
is_decoding
,
local_adapter_ids
=
local_adapter_ids
,
global_adapter_ids
=
self
.
global_adapter_ids
,
adapter_offsets
=
self
.
adapter_offsets
,
max_rank
=
self
.
max_rank
,
meta
=
self
.
meta
)
def
split
(
self
,
split_size
:
int
,
block_size
:
int
):
"""split inputs."""
assert
len
(
self
.
seq_length
)
==
1
,
(
'Can not perform split on batched input.'
)
assert
split_size
%
block_size
==
0
,
(
'split_size should be multi of block_size.'
)
input_ids
=
self
.
input_ids
if
input_ids
.
numel
()
<
split_size
:
return
self
num_blocks
=
split_size
//
block_size
overlap
=
(
self
.
history_lengths
[
0
]
%
block_size
!=
0
)
max_seq_len
=
self
.
seq_length
[
0
].
item
()
ret
=
[]
block_start
=
0
history_len
=
self
.
history_lengths
[
0
]
for
i
in
range
(
0
,
max_seq_len
,
split_size
):
start
=
i
end
=
min
(
max_seq_len
,
i
+
split_size
)
block_end
=
block_start
+
num_blocks
if
overlap
:
block_end
+=
1
local_adapter_ids
=
self
.
local_adapter_ids
if
local_adapter_ids
is
not
None
:
local_adapter_ids
=
local_adapter_ids
[:,
start
:
end
]
inp
=
ModelInputs
(
input_ids
=
self
.
input_ids
[:,
start
:
end
],
seq_length
=
input_ids
.
new_tensor
([
end
-
start
]),
attention_mask
=
self
.
attention_mask
[:,
start
:
end
],
block_offsets
=
self
.
block_offsets
[:,
:
block_end
],
position_ids
=
self
.
position_ids
[:,
start
:
end
],
q_start_loc
=
input_ids
.
new_zeros
(
1
),
history_lengths
=
[
history_len
+
start
],
is_decoding
=
self
.
is_decoding
,
local_adapter_ids
=
local_adapter_ids
,
global_adapter_ids
=
self
.
global_adapter_ids
,
adapter_offsets
=
self
.
adapter_offsets
,
max_rank
=
self
.
max_rank
,
meta
=
self
.
meta
,
)
ret
.
append
(
inp
)
block_start
+=
num_blocks
return
ret
def
to_device
(
self
,
device
:
str
):
"""to device."""
input_dict
=
asdict
(
self
)
out_dict
=
dict
()
for
k
,
v
in
input_dict
.
items
():
if
isinstance
(
v
,
torch
.
Tensor
):
v
=
v
.
to
(
device
)
out_dict
[
k
]
=
v
return
ModelInputs
(
**
out_dict
)
@dataclass
class StepContext:
    """context of Model.

    patched model might need extra information to perform inference. This
    dataclass provide these infos and tools.
    """
    inputs: ModelInputs
    block_offsets: torch.LongTensor
    position_ids: torch.LongTensor
    position_ids_1d: torch.LongTensor
    q_start_loc: torch.LongTensor
    history_lengths: torch.LongTensor
    q_seq_length: torch.LongTensor
    kv_seq_length: torch.LongTensor
    max_q_seq_length: int
    max_kv_seq_length: int
    kv_caches: List
    is_decoding: bool
    world_size: int = 1
    json_config: Dict = None
    local_adapter_ids: torch.LongTensor = None
    global_adapter_ids: torch.LongTensor = None
    adapter_offsets: torch.LongTensor = None
    max_rank: int = 0

    _outputs: Dict = field(default_factory=dict)

    @classmethod
    def new(
        cls,
        inputs: ModelInputs,
        world_size: int = 1,
        device: str = 'cuda',
        json_config: dict = None,
        kv_caches: List = None,
    ):
        """build step context.

        Args:
            inputs (ModelInputs): packaged model inputs.
            world_size (int): The distribution world size.
            device (str): The device of the tensors.
        """
        position_ids = inputs.position_ids
        max_q_seq_length = position_ids.size(-1)

        # seq_len + history_length
        kv_seq_length = position_ids[..., -1] + 1

        # position ids 1d
        q_seq_length = inputs.seq_length
        position_ids_1d = cls.get_position_ids_1d(position_ids, q_seq_length,
                                                  device)

        max_kv_seq_length = max_q_seq_length + max(inputs.history_lengths)

        ret = StepContext(inputs=inputs,
                          block_offsets=inputs.block_offsets,
                          position_ids=inputs.position_ids,
                          position_ids_1d=position_ids_1d,
                          q_start_loc=inputs.q_start_loc,
                          history_lengths=inputs.history_lengths,
                          q_seq_length=inputs.seq_length,
                          kv_seq_length=kv_seq_length,
                          max_q_seq_length=max_q_seq_length,
                          max_kv_seq_length=max_kv_seq_length,
                          kv_caches=kv_caches,
                          is_decoding=inputs.is_decoding,
                          world_size=world_size,
                          json_config=json_config,
                          local_adapter_ids=inputs.local_adapter_ids,
                          global_adapter_ids=inputs.global_adapter_ids,
                          adapter_offsets=inputs.adapter_offsets,
                          max_rank=inputs.max_rank)
        return ret

    @classmethod
    def tensorlize_block_offsets(cls, block_offsets, device):
        """tensorlize block_offsets."""
        import numpy as np
        offset_len = [len(offset) for offset in block_offsets]
        max_offsets_len = max(offset_len)
        batch_size = len(offset_len)
        pad_block_offsets = np.zeros((batch_size, max_offsets_len),
                                     dtype=np.int64)

        for pad_offset, offset, off_len in zip(pad_block_offsets,
                                               block_offsets, offset_len):
            pad_offset[:off_len] = offset

        block_offsets = torch.from_numpy(pad_block_offsets).to(device)
        return block_offsets

    @classmethod
    def get_position_ids_1d(cls,
                            position_ids: torch.LongTensor,
                            seq_length: torch.LongTensor,
                            device: str = 'cuda'):
        """get 1d position_ids."""
        if position_ids.size(1) == 1:
            position_ids_1d = position_ids.flatten()
        else:
            position_ids_1d = [
                ids[:l] for ids, l in zip(position_ids.cpu(),
                                          seq_length.cpu())
            ]
            position_ids_1d = torch.cat(position_ids_1d).to(device)
        return position_ids_1d

    def get_block_offsets(self):
        """return block offsets."""
        return self.block_offsets

    def set_output(self, key, value):
        """set output."""
        self._outputs[key] = value

    def get_output(self, key):
        """get output."""
        if key in self._outputs:
            return self._outputs[key]
        return None
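
# Illustrative sketch (not part of the original file): `tensorlize_block_offsets`
# pads ragged per-sequence block tables into one rectangular int64 tensor. The
# values below are made up for illustration only.
def _example_tensorlize_block_offsets():
    block_offsets = [[0, 1, 2], [3, 4]]  # two sequences, ragged block tables
    padded = StepContext.tensorlize_block_offsets(block_offsets, device='cpu')
    # padded -> tensor([[0, 1, 2], [3, 4, 0]]), shorter rows are zero padded
    return padded
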
def cache_swapping(cache_engine: CacheEngine, swap_in_map: dict,
                   swap_out_map: dict):
    """perform cache swapping."""
    issued_cache_op = False
    if len(swap_in_map) > 0:
        cache_engine.swap_in(swap_in_map)
        issued_cache_op = True
    if len(swap_out_map) > 0:
        cache_engine.swap_out(swap_out_map)
        issued_cache_op = True

    if issued_cache_op:
        cache_events = cache_engine.events
        for event in cache_events:
            event.wait()
def model_forward(
    patched_model: torch.nn.Module,
    inputs: ModelInputs,
    cache_engine: CacheEngine,
    json_config: dict = None,
    world_size: int = 1,
    stream: torch.cuda.Stream = None,
):
    """perform model forward."""
    stream = stream or torch.cuda.current_stream()
    with torch.inference_mode(), torch.cuda.stream(stream):
        # forward
        inputs = inputs.to_device('cuda')
        context = StepContext.new(
            inputs=inputs,
            world_size=world_size,
            json_config=json_config,
            kv_caches=cache_engine.gpu_cache,
        )
        output = patched_model.patched_forward(
            input_ids=inputs.input_ids,
            position_ids=inputs.position_ids,
            attention_mask=inputs.attention_mask,
            past_key_values=cache_engine.gpu_cache,
            return_dict=True,
            output_attentions=False,
            output_hidden_states=False,
            use_origin=False,
            context=context,
        )
    return dict(logits=output['logits'], custom_outputs=context._outputs)
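
# Illustrative sketch (not part of the original file): a single engine step is
# "swap caches, then run the patched model on the agent's stream". The agent
# classes below wrap exactly this pattern; `agent` is an assumed
# BaseModelAgent instance and `inputs` an assumed ModelInputs.
def _example_single_step(agent, inputs):
    cache_swapping(agent.cache_engine, swap_in_map={}, swap_out_map={})
    out = model_forward(agent.patched_model,
                        inputs,
                        agent.cache_engine,
                        agent.model_config.json_config,
                        stream=agent.stream)
    agent.stream.synchronize()
    return out['logits']
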
def _load_adapters(hf_model: torch.nn.Module,
                   adapters: Dict[str, str],
                   device_map: str = 'cpu'):
    """load adapters."""
    if not adapters:
        return
    for name, path in adapters.items():
        logger.info(f'load adapter <{name}> from "{path}".')
        hf_model.load_adapter(path, name, device_map=device_map)


def _add_adapters(hf_model: torch.nn.Module, adapters: Dict[str, str]):
    """add adapters."""
    if not adapters:
        return
    from peft import PeftConfig, inject_adapter_in_model
    for name, path in adapters.items():
        config = PeftConfig.from_pretrained(path)
        inject_adapter_in_model(config, model=hf_model, adapter_name=name)


def _unparam_lora_weight(model: torch.nn.Module):
    """unparam lora weight.

    We don't want to move weight of lora to gpu.
    """
    from peft.tuners.lora import Linear as LoRALinear

    def _tensorize_weight(linear):
        """tensorize weight."""
        w = linear.weight
        del linear.weight
        linear.weight = w.data

    for _, mod in model.named_modules():
        if isinstance(mod, LoRALinear):
            lora_A = mod.lora_A
            lora_B = mod.lora_B
            for linear in lora_A.values():
                _tensorize_weight(linear)
            for linear in lora_B.values():
                _tensorize_weight(linear)


SwapMap = Dict[int, int]
class AutoModelAgent:
    """Base model agent."""

    def __init__(self, model_config: ModelConfig, cache_config: CacheConfig):
        self.model_config = model_config
        self.cache_config = cache_config

    def paging_adapters(self, weight_maps: List[AdapterWeightMap]):
        """paging adapter."""
        raise NotImplementedError('Not implemented.')

    async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
                            swap_out_map: SwapMap):
        """model forward.

        Args:
            inputs (Dict): The input data comes from _make_inputs.
            swap_in_map (SwapMap): Cache maps to swap in.
            swap_out_map (SwapMap): Cache maps to swap out.
        """
        raise NotImplementedError('Not implemented.')

    def forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
                swap_out_map: SwapMap):
        """model forward.

        Args:
            inputs (Dict): The input data comes from _make_inputs.
            swap_in_map (SwapMap): Cache maps to swap in.
            swap_out_map (SwapMap): Cache maps to swap out.
        """
        raise NotImplementedError('Not implemented.')

    @classmethod
    def from_pretrained(cls,
                        pretrained_model_name_or_path: str,
                        cache_config: CacheConfig,
                        trust_remote_code: bool,
                        adapters: Dict[str, str] = None,
                        tp: int = 1):
        """from pretrained."""
        return build_model_agent(pretrained_model_name_or_path,
                                 cache_config=cache_config,
                                 trust_remote_code=trust_remote_code,
                                 adapters=adapters,
                                 tp=tp)
class BaseModelAgent(AutoModelAgent):
    """Base model agent.

    load model on local gpu

    Args:
        model_path (str): The hugging face model path.
        model_config (ModelConfig): The config of the model.
        cache_config (CacheConfig): The config of the cache info.
        trust_remote_code (bool): Trust remote code
    """

    def __init__(self,
                 model_path: str,
                 model_config: ModelConfig,
                 cache_config: CacheConfig,
                 adapters: Dict[str, str] = None,
                 trust_remote_code: bool = True):
        super().__init__(model_config=model_config, cache_config=cache_config)
        torch_dtype = model_config.dtype

        self.patched_model = self._build_model(
            model_path,
            torch_dtype=torch_dtype,
            adapters=adapters,
            trust_remote_code=trust_remote_code)

        block_size = _infer_block_size(self.patched_model, model_config,
                                       cache_config)
        if block_size != cache_config.block_size:
            cache_config.block_size = block_size
            logger.warning(f'inferred block size: {block_size}')
        _update_cache_config(model_config, cache_config)

        self.cache_engine = CacheEngine(cache_config, model_config)
        self.stream = torch.cuda.Stream()

    def _build_model(self,
                     model_path: str,
                     torch_dtype: torch.dtype,
                     adapters: Dict[str, str] = None,
                     trust_remote_code: bool = True):
        """build patched model."""
        with LoadNoInit():
            hf_model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
                **self.model_config.init_kwargs)
            hf_model.eval()
            hf_model.config.use_cache = True

        if adapters:
            _load_adapters(hf_model, adapters)

        patched_model = patch(hf_model, _PATCH_ARG_NAMES)

        if adapters:
            _unparam_lora_weight(patched_model)

        patched_model = patched_model.cuda()
        return patched_model

    def paging_adapters(self, weight_maps: List[AdapterWeightMap]):
        """paging adapter."""
        logger.info('paging adapters.')
        lora_linears = get_indexed_lora_linears(self.patched_model)
        cpu_caches = self.cache_engine.cpu_cache
        num_blocks = self.cache_engine.num_cpu_blocks
        cpu_caches = [(kcache.view(num_blocks, -1),
                       vcache.view(num_blocks, -1))
                      for kcache, vcache in cpu_caches]
        for weight_map in weight_maps:
            weight_map.cache_adapter(lora_linears, cpu_caches)
        update_lora_linears(lora_linears, weight_maps, device='cuda')

    def _forward_impl(self, inputs: ModelInputs, swap_in_map: SwapMap,
                      swap_out_map: SwapMap):
        cache_swapping(self.cache_engine,
                       swap_in_map=swap_in_map,
                       swap_out_map=swap_out_map)
        output = model_forward(
            self.patched_model,
            inputs,
            self.cache_engine,
            self.model_config.json_config,
            world_size=1,
            stream=self.stream,
        )
        return output

    def forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
                swap_out_map: SwapMap):
        """model forward.

        Args:
            inputs (Dict): The input data comes from _make_inputs.
            swap_in_map (SwapMap): Cache maps to swap in.
            swap_out_map (SwapMap): Cache maps to swap out.
        """
        output = self._forward_impl(inputs,
                                    swap_in_map=swap_in_map,
                                    swap_out_map=swap_out_map)
        self.stream.synchronize()
        return output

    async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
                            swap_out_map: SwapMap):
        """model forward.

        Args:
            inputs (Dict): The input data comes from _make_inputs.
            swap_in_map (SwapMap): Cache maps to swap in.
            swap_out_map (SwapMap): Cache maps to swap out.
        """
        output = self._forward_impl(inputs,
                                    swap_in_map=swap_in_map,
                                    swap_out_map=swap_out_map)
        await asyncio.get_event_loop().run_in_executor(
            None, self.stream.synchronize)
        return output
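
# Illustrative sketch (not part of the original file): the async variant above
# only differs in how it waits on the CUDA stream, so the engine loop can await
# it without blocking the event loop. `agent` and `inputs` are assumed objects.
async def _example_async_forward(agent: BaseModelAgent, inputs: ModelInputs):
    output = await agent.async_forward(inputs, swap_in_map={}, swap_out_map={})
    return output['logits']
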
@dataclass
class TPResponse:
    ret_code: int
    error: Union[Exception, List[Exception]] = None
    data: Any = None

    def gather_error(self):
        """gather error."""
        rank = dist.get_rank()
        world_size = dist.get_world_size()

        # gather errors
        error_count = torch.tensor(self.ret_code).cuda(rank)
        dist.all_reduce(error_count)

        if error_count.item() > 0:
            all_errors = [None] * world_size
            dist.all_gather_object(all_errors, self.error)
            self.ret_code = 1
            self.error = all_errors

    def raise_error(self, default_error: Exception):
        """raise error."""
        if self.error is None:
            raise default_error
        elif isinstance(self.error, Exception):
            raise self.error
        else:
            assert isinstance(self.error, List), ('expect error type list, '
                                                  f'got {type(self.error)}')
            rank = dist.get_rank()
            err = self.error[rank]
            if err is None:
                raise default_error
            else:
                raise err
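
# Illustrative sketch (not part of the original file): every TP worker builds a
# TPResponse, all-reduces the error count, and only rank 0 reports back through
# the queue; a non-zero ret_code re-raises on every rank. This assumes the
# process group is already initialized and `out_que` is the worker output queue.
def _example_tp_response_protocol(out_que, rank):
    resp = TPResponse(0)
    resp.gather_error()
    if rank == 0:
        out_que.put(resp)
    if resp.ret_code != 0:
        resp.raise_error(RuntimeError('worker failed.'))
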
def _get_model_memory_usage(model: torch.nn.Module) -> int:
    """get model memory usage."""
    size = 0
    for _, param in model.named_parameters():
        size += param.element_size() * param.numel()
    for _, buf in model.named_buffers():
        size += buf.element_size() * buf.numel()
    return size


def _create_device_map(model: torch.nn.Module,
                       world_size: int,
                       device_map: dict = None):
    """Distribute params to each device."""
    if device_map is None:
        device_map = dict()
    device_id = 0
    for name, _ in model.named_parameters():
        device_map[name] = device_id
        device_id = (device_id + 1) % world_size
    for name, _ in model.named_buffers():
        device_map[name] = device_id
        device_id = (device_id + 1) % world_size
    return device_map
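
# Illustrative sketch (not part of the original file): the device map is a plain
# round robin over parameter and buffer names, e.g. with world_size=2 the first
# parameter lands on device 0, the second on device 1, and so on.
def _example_device_map():
    import torch.nn as nn
    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
    return _create_device_map(model, world_size=2)
    # e.g. {'0.weight': 0, '0.bias': 1, '1.weight': 0, '1.bias': 1}
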
def _tp_build_model(
    rank: int,
    model_path: str,
    model_config: ModelConfig,
    cache_config: CacheConfig,
    adapters: Dict[str, str],
    out_que: mp.Queue,
    world_size: int,
    trust_remote_code=True,
):
    """build tensor parallel model."""
    from accelerate import init_empty_weights

    error_code = 0
    error_type = None
    patched_model = None
    cache_engine = None

    def __get_device_map(model, device_map=None):
        """get device map of model."""
        import psutil
        model_size = _get_model_memory_usage(model)
        if psutil.virtual_memory().available < model_size:
            logger.debug('Preload model on GPU.')
            return device_map
        else:
            logger.debug('Preload model on CPU.')
            return 'cpu'

    def __load_params_and_buffers(param_mod, mod):
        """load param and buffer."""
        for name, param in param_mod.named_parameters(recurse=False):
            mod.register_parameter(name, param)
        for name, buffer in param_mod.named_buffers(recurse=False):
            mod.register_buffer(name, buffer)

    def __load_state_dict_assign(param_model, model):
        """load state dict assign."""
        try:
            model.load_state_dict(param_model.state_dict(), assign=True)
        except Exception:
            __load_params_and_buffers(param_model, model)
            mods = dict(model.named_modules())
            for mod_name, param_mod in param_model.named_modules():
                mod = mods[mod_name]
                __load_params_and_buffers(param_mod, mod)

    def _broadcast_config(cache_config):
        """broadcast cache config, use minimum cache."""
        if rank == 0:
            gathered_configs = [None] * world_size
            dist.gather_object(cache_config, gathered_configs)
            num_gpu_blocks_list = [
                config.num_gpu_blocks for config in gathered_configs
            ]
            num_cpu_blocks_list = [
                config.num_cpu_blocks for config in gathered_configs
            ]
            min_num_gpu_blocks = min(num_gpu_blocks_list)
            min_num_cpu_blocks = min(num_cpu_blocks_list)
            cache_config.num_cpu_blocks = min_num_cpu_blocks
            cache_config.num_gpu_blocks = min_num_gpu_blocks
            config_list = [cache_config]
        else:
            gathered_configs = None
            dist.gather_object(cache_config, gathered_configs)
            config_list = [None]
        dist.broadcast_object_list(config_list)
        return config_list[0]

    try:
        config = model_config.hf_config
        torch_dtype = model_config.dtype
        device_map = None
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(
                config,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
                **model_config.init_kwargs)
            if rank == 0:
                device_map = _create_device_map(model, world_size)
            _add_adapters(model, adapters)
            if rank == 0:
                # adapter would remove weight of linear.
                device_map = _create_device_map(model, world_size, device_map)

        model.eval()
        model.config.use_cache = True

        if rank == 0:
            with LoadNoInit():
                device_map = __get_device_map(model, device_map)
                param_model = AutoModelForCausalLM.from_pretrained(
                    model_path,
                    torch_dtype=torch_dtype,
                    device_map=device_map,
                    trust_remote_code=trust_remote_code,
                    **model_config.init_kwargs)
                _load_adapters(param_model, adapters, device_map=device_map)
                __load_state_dict_assign(param_model, model)
                param_model = param_model.to('meta')
                del param_model

        patched_model = patch(
            model,
            extra_args=_PATCH_ARG_NAMES,
            rank=rank,
            world_size=world_size,
        )

        block_size = _infer_block_size(patched_model, model_config,
                                       cache_config, world_size)
        if block_size != cache_config.block_size:
            cache_config.block_size = block_size
            if rank == 0:
                logger.warning(f'inferred block size: {block_size}')
        _update_cache_config(model_config,
                             cache_config,
                             gpu_id=rank,
                             world_size=world_size)
        cache_config = _broadcast_config(cache_config)
        cache_engine = CacheEngine(cache_config,
                                   model_config,
                                   rank=rank,
                                   world_size=world_size)
    except Exception as e:
        logger.error(f'rank[{rank}] failed with error: {e}')
        error_code = 1
        error_type = e

    # response
    resp = TPResponse(error_code, error_type, cache_config)
    resp.gather_error()
    if rank == 0:
        out_que.put(resp)
    if resp.ret_code != 0:
        resp.raise_error(RuntimeError('failed to init model.'))

    return patched_model, cache_engine
def _tp_get_input(rank: int, in_que: mp.Queue, world_size: int):
    """get input tensor parallel."""
    device_mesh = DeviceMesh('cuda', list(range(world_size)))

    # broadcast meta info
    if rank == 0:
        inputs, swap_in_map, swap_out_map = in_que.get()
        inputs = asdict(inputs)
        input_tensors = dict(
            (k, v) for k, v in inputs.items() if isinstance(v, torch.Tensor))
        tensor_metas = dict(
            (name, (t.shape, t.dtype)) for name, t in input_tensors.items())
        other_metas = dict((k, v) for k, v in inputs.items()
                           if not isinstance(v, torch.Tensor))
        input_metas = (tensor_metas, other_metas)
        objs = [input_metas, swap_in_map, swap_out_map]
    else:
        objs = [None, None, None]

    dist.broadcast_object_list(objs)

    if rank != 0:
        input_metas = objs[0]
        tensor_metas, other_metas = input_metas
        input_tensors = dict(
            (name, torch.empty(meta[0], dtype=meta[1]))
            for name, meta in tensor_metas.items())

    updated_inputs = dict()
    for name, t in input_tensors.items():
        updated_inputs[name] = distribute_tensor(
            t, device_mesh=device_mesh, placements=[Replicate()]).to_local()

    torch.cuda.synchronize()
    inputs = updated_inputs
    inputs.update(other_metas)
    inputs = ModelInputs(**inputs)

    swap_in_map = objs[1]
    swap_out_map = objs[2]
    return inputs, swap_in_map, swap_out_map
def _tp_paging_adapters(
    rank: int,
    patched_model: torch.nn.Module,
    cache_engine: CacheEngine,
    in_que: mp.Queue,
    out_que: mp.Queue,
):
    """tp paging adapters."""

    def __get_weight_map():
        """get weight map."""
        if rank == 0:
            weight_maps = in_que.get()
            dist_obj = [weight_maps]
        else:
            dist_obj = [None]
        dist.broadcast_object_list(dist_obj)
        return dist_obj[0]

    def __paging(weight_maps):
        """paging."""
        lora_linears = get_indexed_lora_linears(patched_model)
        cpu_caches = cache_engine.cpu_cache
        num_blocks = cache_engine.num_cpu_blocks
        cpu_caches = [(kcache.view(num_blocks, -1),
                       vcache.view(num_blocks, -1))
                      for kcache, vcache in cpu_caches]
        for weight_map in weight_maps:
            weight_map.cache_adapter(lora_linears, cpu_caches)
        update_lora_linears(lora_linears, weight_maps, device='cuda')

    weight_maps = __get_weight_map()

    resp = TPResponse(0)
    try:
        if rank == 0:
            logger.info('tp paging adapters.')
        if len(weight_maps) > 0:
            __paging(weight_maps)
    except Exception as e:
        resp.ret_code = 1
        resp.error = e

    resp.gather_error()
    if rank == 0:
        out_que.put(resp)
    if resp.ret_code != 0:
        resp.raise_error(RuntimeError('tp paging adapters failed.'))
def _tp_model_loop(
    rank: int,
    model_path: str,
    model_config: ModelConfig,
    cache_config: CacheConfig,
    adapters: Dict[str, str],
    in_que: mp.Queue,
    out_que: mp.Queue,
    world_size: int,
    trust_remote_code=True,
):
    """Start model loops for tensor parallel model inference.

    Args:
        rank (int): Distribution rank.
        model_path (str): Path of the hugging face model. Could be
            local or online.
        model_config (ModelConfig): The config of the model.
        cache_config (CacheConfig): The config of the cache.
        in_que (mp.Queue): Input queue. Used to receive model input.
        out_que (mp.Queue): Output queue. Used to send the model output.
        world_size (int): The distribution world size.
    """
    stream = torch.cuda.Stream()
    patched_model, cache_engine = _tp_build_model(
        rank,
        model_path,
        model_config,
        cache_config,
        adapters,
        out_que=out_que,
        world_size=world_size,
        trust_remote_code=trust_remote_code)

    if adapters:
        _tp_paging_adapters(rank,
                            patched_model,
                            cache_engine=cache_engine,
                            in_que=in_que,
                            out_que=out_que)

    while True:
        inputs, swap_in_map, swap_out_map = _tp_get_input(
            rank, in_que, world_size)

        cache_swapping(cache_engine,
                       swap_in_map=swap_in_map,
                       swap_out_map=swap_out_map)

        output = model_forward(
            patched_model,
            inputs,
            cache_engine,
            model_config.json_config,
            world_size=world_size,
            stream=stream,
        )
        stream.synchronize()
        if rank == 0:
            resp_output = output
            out_que.put(TPResponse(0, None, resp_output))
def _start_tp_process(rank: int,
                      world_size: int,
                      func: Callable,
                      args: List = None,
                      kwargs: Dict = None,
                      port: int = 29500):
    """Start the tensor parallel process.

    Args:
        rank (int): The distribution rank.
        world_size (int): The distribution world size.
        func (Callable): The function to be called in the process.
        args (List): The arguments of the func.
        kwargs (Dict): The keyword arguments of the func.
    """
    try:
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = str(port)
        dist.init_process_group('nccl', rank=rank, world_size=world_size)

        with torch.cuda.device(rank), torch.no_grad():
            args = args or tuple()
            kwargs = kwargs or dict()
            func(rank, *args, **kwargs)
    except Exception as e:
        from traceback import print_exc
        logger.error(f'Rank[{rank}] failed.')
        print_exc()
        raise e
def _check_context_alive(mp_context: mp.ProcessContext):
    """check context alive."""
    procs = mp_context.processes
    for idx, p in enumerate(procs):
        if not p.is_alive():
            raise RuntimeError(f'Rank[{idx}] failed.')


def _queue_get_response(que: mp.Queue,
                        mp_context: mp.ProcessContext,
                        interval: float = 1.0):
    """get response."""
    from multiprocessing.queues import Empty
    while True:
        try:
            return que.get(timeout=interval)
        except Empty:
            _check_context_alive(mp_context)


async def _async_queue_get_response(que: mp.Queue,
                                    mp_context: mp.ProcessContext,
                                    interval: float = 1.0):
    """get response."""
    from multiprocessing.queues import Empty

    def __try_que_get():
        """try que get."""
        try:
            return que.get(timeout=interval)
        except Empty:
            return None

    while True:
        ret = await asyncio.get_event_loop().run_in_executor(
            None, __try_que_get)
        if ret is not None:
            return ret
        _check_context_alive(mp_context)
class TPModelAgent(AutoModelAgent):
    """Tensor Parallelism model agent.

    load model on multiple GPUs

    Args:
        model_path (str): The hugging face model path.
        model_config (ModelConfig): The config of the model.
        cache_config (CacheConfig): The config of the cache info.
        trust_remote_code (bool): Trust remote code
    """

    def __init__(self,
                 model_path: str,
                 model_config: ModelConfig,
                 cache_config: CacheConfig,
                 world_size: int,
                 adapters: Dict[str, str] = None,
                 trust_remote_code: bool = True) -> None:
        self.mp_ctx = mp.get_context('spawn')
        super().__init__(model_config=model_config, cache_config=cache_config)
        self.world_size = world_size
        self.tp_model_in_que = self.mp_ctx.Queue(10)
        self.tp_model_out_que = self.mp_ctx.Queue(10)

        self.patch_model_tp(model_path,
                            model_config=model_config,
                            cache_config=cache_config,
                            adapters=adapters,
                            in_que=self.tp_model_in_que,
                            out_que=self.tp_model_out_que,
                            world_size=world_size,
                            trust_remote_code=trust_remote_code)

    def patch_model_tp(self, model_path: str, model_config: ModelConfig,
                       cache_config: CacheConfig, adapters: Dict[str, str],
                       in_que: mp.Queue, out_que: mp.Queue, world_size: int,
                       trust_remote_code: bool):
        """Start tensor parallel sub process.

        Args:
            model_path (str): Path of the hugging face model.
                Could be local or online.
            extra_args (List[str]): The extra arguments to add to the
                patched model.
            model_config (ModelConfig): The config of the model.
            cache_config (CacheConfig): The config of the cache.
            in_que (mp.Queue): Input queue. Used to receive model input.
            out_que (mp.Queue): Output queue. Used to send the model output.
            world_size (int): The distribution world size.
        """

        def __find_available_port() -> int:
            """find available port."""
            import socket
            port = 29500
            while True:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    if s.connect_ex(('localhost', port)) != 0:
                        return port
                    port += 1

        self.mp_context = mp.spawn(
            _start_tp_process,
            args=(
                world_size,
                _tp_model_loop,
                (model_path, ),
                dict(model_config=model_config,
                     cache_config=cache_config,
                     adapters=adapters,
                     in_que=in_que,
                     out_que=out_que,
                     world_size=world_size,
                     trust_remote_code=trust_remote_code),
                __find_available_port(),
            ),
            nprocs=world_size,
            join=False,
            daemon=True,
        )
        resp: TPResponse = _queue_get_response(out_que, self.mp_context)
        if resp.ret_code != 0:
            logger.error(f'Init tp model failed with error: {resp.error}')
            raise next(err for err in resp.error if err is not None)
        self.cache_config = resp.data

    def paging_adapters(self, weight_maps: List[AdapterWeightMap]):
        """load adapter."""
        if not weight_maps:
            return
        self.tp_model_in_que.put(weight_maps)
        resp: TPResponse = self.tp_model_out_que.get()
        if resp.ret_code != 0:
            logger.error(f'paging adapters failed with error: {resp.error}')
            raise next(err for err in resp.error if err is not None)

    def forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
                swap_out_map: SwapMap):
        """model forward.

        Args:
            inputs (Dict): The input data comes from _make_inputs.
            swap_in_map (Dict[int, int]): Cache maps to swap in.
            swap_out_map (Dict[int, int]): Cache maps to swap out.
        """
        with torch.no_grad():
            self.tp_model_in_que.put((inputs, swap_in_map, swap_out_map))

        resp: TPResponse = _queue_get_response(self.tp_model_out_que,
                                               self.mp_context)
        if resp.ret_code != 0:
            raise RuntimeError('tp forward failed.')

        return resp.data

    async def async_forward(self, inputs: ModelInputs, swap_in_map: SwapMap,
                            swap_out_map: SwapMap):
        """model forward.

        Args:
            inputs (Dict): The input data comes from _make_inputs.
            swap_in_map (Dict[int, int]): Cache maps to swap in.
            swap_out_map (Dict[int, int]): Cache maps to swap out.
        """
        with torch.no_grad():
            self.tp_model_in_que.put((inputs, swap_in_map, swap_out_map))

        resp: TPResponse = await _async_queue_get_response(
            self.tp_model_out_que, self.mp_context)
        if resp.ret_code != 0:
            raise RuntimeError('tp forward failed.')

        return resp.data
def build_model_agent(model_path: str,
                      cache_config: CacheConfig,
                      trust_remote_code: bool,
                      adapters: Dict[str, str] = None,
                      tp: int = 1):
    """create model agent."""
    model_config = ModelConfig.from_pretrained(
        model_path, trust_remote_code=trust_remote_code)
    if tp == 1:
        model_agent = BaseModelAgent(model_path,
                                     model_config=model_config,
                                     cache_config=cache_config,
                                     adapters=adapters,
                                     trust_remote_code=trust_remote_code)
    else:
        model_agent = TPModelAgent(model_path,
                                   model_config=model_config,
                                   cache_config=cache_config,
                                   world_size=tp,
                                   adapters=adapters,
                                   trust_remote_code=trust_remote_code)
    return model_agent
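
# Illustrative sketch (not part of the original file): how an engine might build
# an agent. The CacheConfig field values here are assumptions for illustration;
# with tp=1 a BaseModelAgent is returned, otherwise a TPModelAgent.
def _example_build_agent(model_path: str):
    cache_config = CacheConfig(block_size=64,
                               num_cpu_blocks=0,
                               num_gpu_blocks=0)
    return build_model_agent(model_path,
                             cache_config=cache_config,
                             trust_remote_code=True,
                             tp=1)
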
lmdeploy/pytorch/engine/request.py
0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
import enum
from dataclasses import dataclass, field
from queue import Empty, Queue
from threading import Lock, Thread
from typing import Any, Awaitable, Callable, Dict, List

from lmdeploy.messages import ResponseType
from lmdeploy.utils import get_logger

logger = get_logger('lmdeploy')
def _raise_exception_on_finish(task: asyncio.Task) -> None:
    try:
        task.result()
    except asyncio.CancelledError:
        return
    except Exception as e:
        logger.exception(f'Engine loop failed with error: {e}')


def _ignore_exception_on_finish(task: asyncio.Task) -> None:
    try:
        task.result()
    except asyncio.CancelledError:
        return
    except Exception as exc:
        logger.info(f'task: {task.get_name()} ended.')
        logger.debug(f'task: {task.get_name()} exception: {exc}')
class RequestType(enum.Enum):
    """Request type."""

    ADD_SESSION = enum.auto()
    ADD_MESSAGE = enum.auto()
    STOP_SESSION = enum.auto()
    END_SESSION = enum.auto()
    STOP_ENGINE = enum.auto()
    RESUME_ENGINE = enum.auto()


@dataclass
class Request:
    """Request."""

    type: RequestType
    sender_id: int
    req_id: int
    data: Any = None


@dataclass
class Response:
    """Response."""

    type: ResponseType
    sender_id: int
    req_id: int
    data: Any = None
    err_msg: str = ''


ReqList = List[Request]
def _run_until_complete(future: Awaitable):
    """run until complete."""
    try:
        event_loop = asyncio.get_event_loop()
    except Exception:
        logger.warning('Can not find event loop in current thread.'
                       ' Create a new event loop.')
        event_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(event_loop)
    return event_loop.run_until_complete(future)
@dataclass
class RequestSender:
    """Request sender.

    Args:
        sender_id (int): The id of the sender
    """
    sender_id: int
    manager: 'RequestManager'
    resp_dict: Dict[int, List[Response]] = field(default_factory=dict)
    _next_req_id: int = 0
    _resp_que: asyncio.Queue = None
    _resp_thread_que: Queue = None

    @classmethod
    def new(cls, sender_id: int, manager: 'RequestManager'):
        """new."""
        return cls(sender_id=sender_id, manager=manager)

    @property
    def resp_que(self):
        """response queue."""
        if self.is_thread_safe():
            return self.manager.responses
        if self.manager._loop_task is None and not self.is_thread_safe():
            self.manager.create_loop_task()
        if self._resp_que is None:
            self._resp_que = asyncio.Queue()
        return self._resp_que

    @property
    def req_que(self):
        """request queue."""
        return self.manager.requests

    @property
    def resp_thread_que(self):
        """response threadsafe queue."""
        if self._resp_thread_que is None:
            self._resp_thread_que = Queue()
        return self._resp_thread_que

    @property
    def req_thread_que(self):
        """request threadsafe queue."""
        return self.manager.thread_requests

    @property
    def event_loop(self):
        """get event loop."""
        return self.manager.event_loop

    def is_thread_safe(self):
        """is thread safe."""
        return self.manager.is_thread_safe()

    def is_loop_alive(self):
        """is loop alive."""
        return self.manager.is_loop_alive()

    def run_until_complete(self, future: Awaitable):
        """run until complete."""
        return self.manager.run_until_complete(future)

    def _resp_get(self):
        """resp_que.get."""
        timeout = 1
        while True:
            if not self.manager.is_loop_alive():
                logger.debug('Engine loop is not alive.')
                exit(1)
            try:
                ret = self.resp_thread_que.get(timeout=timeout)
                return ret
            except Empty:
                continue
            except Exception as e:
                logger.exception(
                    f'sender[{self.sender_id}] get response failed: {e}')
                raise e

    async def _async_resp_get(self):
        """get resp.

        Different behavior in threadsafe mode.
        """
        timeout = 1

        async def __no_threadsafe_get():
            while True:
                if not self.manager.is_loop_alive():
                    logger.debug('Engine loop is not alive.')
                    exit(1)
                try:
                    return await asyncio.wait_for(self.resp_que.get(),
                                                  timeout)
                except asyncio.TimeoutError:
                    continue
                except Exception as e:
                    logger.exception(
                        f'sender[{self.sender_id}] get response failed: {e}')
                    raise e

        if self.is_thread_safe():
            ret = self._resp_get()
            await asyncio.sleep(0)
            return ret
        else:
            return await __no_threadsafe_get()

    def _req_put(self, reqs: Any):
        """req put."""
        self.req_thread_que.put(reqs)

    async def _async_req_put(self, reqs: Any):
        """async req_que put.

        Different behavior in threadsafe mode.
        """
        if self.is_thread_safe():
            self._req_put(reqs)
            await asyncio.sleep(0)
        else:
            await self.req_que.put(reqs)

    def _prefetch_resps(self):
        """prefetch from resp que.

        Different behavior in threadsafe mode.
        """
        if self.is_thread_safe():
            resp_que = self.resp_thread_que
        else:
            resp_que = self.resp_que
        num_resps = resp_que.qsize()
        for _ in range(num_resps):
            resp: Response = resp_que.get_nowait()
            req_id = resp.req_id
            self._push_resp(req_id, resp)

    def _push_resp(self, req_id: int, resp: Response):
        """push response."""
        self.resp_dict.setdefault(req_id, [])
        self.resp_dict[req_id].append(resp)

    def _pop_resp(self, req_id: int, default: Any = None):
        """pop response."""
        if req_id not in self.resp_dict:
            return default
        resps = self.resp_dict[req_id]
        ret = resps.pop(0)
        if len(resps) == 0:
            self.resp_dict.pop(req_id)
        return ret

    def _gather_request(self, req_types: List[RequestType], data: List[Any]):
        """gather requests."""
        if self.manager._loop_task is None and not self.is_thread_safe():
            self.manager.create_loop_task()
        if not self.is_loop_alive():
            logger.error('Engine main loop stopped.')
            exit(1)
        assert len(req_types) == len(data)
        batch_size = len(req_types)

        req_ids = list(
            range(self._next_req_id, self._next_req_id + batch_size))
        self._next_req_id += batch_size

        reqs = [
            Request(type=rtype,
                    sender_id=self.sender_id,
                    req_id=req_id,
                    data=rdata)
            for req_id, rtype, rdata in zip(req_ids, req_types, data)
        ]
        return req_ids, reqs

    async def async_batched_send_async(self, req_types: List[RequestType],
                                       data: List[Any]):
        """Batched send request asynchronously."""
        req_ids, reqs = self._gather_request(req_types, data)
        await self._async_req_put(reqs)
        return req_ids

    async def async_send_async(self, req_type: RequestType, data: Any):
        """send request asynchronously."""
        return (await self.async_batched_send_async(req_types=[req_type],
                                                    data=[data]))[0]

    def batched_send_async(self, req_types: List[RequestType],
                           data: List[Any]) -> List[int]:
        """Batched send request asynchronously.

        Different behavior in threadsafe mode.
        """
        if not self.is_thread_safe():
            coro = self.async_batched_send_async(req_types, data)
            return self.run_until_complete(coro)

        req_ids, reqs = self._gather_request(req_types, data)
        self._req_put(reqs)
        return req_ids

    def send_async(self, req_type: RequestType, data: Any) -> int:
        """send request asynchronously."""
        return self.batched_send_async(req_types=[req_type], data=[data])[0]

    async def async_recv_any(self, que_timeout: float = None) -> Response:
        """receive any response."""
        self._prefetch_resps()
        for req_id in self.resp_dict:
            ret = self._pop_resp(req_id, default=None)
            if ret is not None:
                return ret
        return await self._async_resp_get()

    def recv_any(self, que_timeout: float = None) -> Response:
        """receive any response."""
        coro = self.async_recv_any(que_timeout)
        return self.run_until_complete(coro)

    def recv_all(self, req_id: int, block: bool = True):
        """receive all responses with req_id."""
        self._prefetch_resps()
        resps = self.resp_dict.pop(req_id, [])
        return resps

    async def async_recv(self,
                         req_id: int,
                         que_timeout: float = None) -> Response:
        """receive response of given request id async."""
        ret = self._pop_resp(req_id, default=None)
        if ret is not None:
            return ret

        # check resp que
        while True:
            resp: Response = await self._async_resp_get()
            if resp.req_id != req_id:
                self._push_resp(resp.req_id, resp)
            else:
                return resp

    def recv(self, req_id: int, que_timeout: float = None) -> Response:
        """receive response of given request id.

        Different behavior in threadsafe mode.
        """
        if not self.is_thread_safe():
            coro = self.async_recv(req_id, que_timeout)
            return self.run_until_complete(coro)

        ret = self._pop_resp(req_id, default=None)
        if ret is not None:
            return ret

        # check resp que
        while True:
            resp: Response = self._resp_get()
            if resp.req_id != req_id:
                self._push_resp(resp.req_id, resp)
            else:
                return resp

    async def async_send(self,
                         req_type: RequestType,
                         data: Any,
                         que_timeout: float = None):
        """send and receive synchronously."""
        req_id = await self.async_send_async(req_type, data)
        return await self.async_recv(req_id, que_timeout=que_timeout)

    def send(self,
             req_type: RequestType,
             data: Any,
             que_timeout: float = None) -> Response:
        """send and receive synchronously."""
        req_id = self.send_async(req_type, data)
        return self.recv(req_id, que_timeout=que_timeout)

    def response_callback(self, resp: Response):
        """response callback."""
        self.resp_que.put_nowait(resp)
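
# Illustrative sketch (not part of the original file): a typical round trip from
# a sender's point of view. The payload dict is an assumption for illustration;
# the real engine defines its own message schemas.
def _example_sender_round_trip(sender: RequestSender):
    req_id = sender.send_async(RequestType.ADD_SESSION, dict(session_id=0))
    resp = sender.recv(req_id)
    return resp.type, resp.data
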
class RequestManager:
    """Request manager."""

    def __init__(self, thread_safe: bool = False):
        self.senders: Dict[int, RequestSender] = dict()
        self.callbacks: Dict[RequestType, Callable] = dict()
        self.request_priority: List[RequestType] = [
            RequestType.STOP_ENGINE, RequestType.STOP_SESSION,
            RequestType.END_SESSION, RequestType.ADD_SESSION,
            RequestType.ADD_MESSAGE
        ]
        self.requests: asyncio.Queue = None
        self._loop_task: asyncio.Future = None
        self._loop_coro: Callable = None
        self._thread_safe = thread_safe
        self._next_sender_id = 0
        self._mutex = Lock()
        self._loop_thread: Thread = None

        self.thread_requests: Queue = None
        # every sender has its own responses; this queue is only used in
        # thread-safe mode.
        self.responses: asyncio.Queue = None
        if thread_safe:
            self.thread_requests = Queue()

    def create_loop_task(self):
        """create coro task."""
        logger.debug('creating engine loop task.')
        event_loop = asyncio.get_event_loop()
        assert self._loop_coro is not None, (
            'Please set loop task with manager.start_loop')
        loop_unshielded = event_loop.create_task(self._loop_coro(),
                                                 name='EngineMainLoop')
        loop_unshielded.add_done_callback(_raise_exception_on_finish)
        self._loop_task = asyncio.shield(loop_unshielded)
        self.requests = asyncio.Queue()
        return self._loop_task

    @property
    def event_loop(self):
        """get event loop."""
        if self._loop_task is None:
            return None
        else:
            return self._loop_task.get_loop()

    def is_thread_safe(self):
        """is thread safe."""
        return self._thread_safe

    def start_loop(self, loop: asyncio.Task):
        """start main loop."""
        self._loop_coro = loop

        def __get_thread_reqs():
            """get thread reqs."""
            num_reqs = self.thread_requests.qsize()
            reqs = []
            for _ in range(num_reqs):
                tmp_reqs = self.thread_requests.get_nowait()
                if isinstance(tmp_reqs, Request):
                    tmp_reqs = [tmp_reqs]
                reqs += tmp_reqs
            return reqs

        async def __req_loop():
            """req loop."""
            while True:
                # get reqs
                reqs = __get_thread_reqs()
                if len(reqs) > 0:
                    await self.requests.put(reqs)
                else:
                    await asyncio.sleep(0.02)

        def __put_thread_resps(resps: List[Response]):
            """put thread resps."""
            for resp in resps:
                sender = self.senders.get(resp.sender_id, None)
                if sender is None:
                    continue
                sender.resp_thread_que.put_nowait(resp)

        async def __resp_loop():
            """resp loop."""
            while True:
                num_resps = self.responses.qsize()
                resps = []
                for _ in range(num_resps):
                    resps.append(self.responses.get_nowait())
                if len(resps) > 0:
                    __put_thread_resps(resps)
                else:
                    await asyncio.sleep(0.02)

        def __run_forever(event_loop: asyncio.BaseEventLoop):
            """run forever."""
            logger.debug('start thread run forever.')
            asyncio.set_event_loop(event_loop)
            self.create_loop_task()
            req_loop = event_loop.create_task(__req_loop(),
                                              name='RunForeverReqLoop')
            req_loop.add_done_callback(_ignore_exception_on_finish)
            resp_loop = event_loop.create_task(__resp_loop(),
                                               name='RunForeverRespLoop')
            resp_loop.add_done_callback(_ignore_exception_on_finish)
            self.event_loop.run_forever()

        if self.is_thread_safe():
            event_loop = asyncio.new_event_loop()
            self.responses = asyncio.Queue()
            self._loop_thread = Thread(target=__run_forever,
                                       args=(event_loop, ),
                                       daemon=True)
            self._loop_thread.start()

    def is_loop_alive(self):
        """check if main loop is alive."""

        def __check_threadsafe():
            if self._loop_thread is None:
                return False
            if not self._loop_thread.is_alive():
                return False
            if self._loop_task is None:
                return False
            return not self._loop_task.done()

        if self.is_thread_safe():
            return __check_threadsafe()

        if self._loop_task is None:
            logger.debug('loop task has not been created.')
            return False
        if self._loop_task.get_loop() != asyncio.get_event_loop():
            logger.warning('Current event loop is different from'
                           ' the one bound to loop task!')
            return False
        return not self._loop_task.done()

    def build_sender(self):
        """create a new sender."""
        with self._mutex:
            sender_id = self._next_sender_id
            self._next_sender_id += 1
            new_sender = RequestSender.new(sender_id, self)
            self.senders[sender_id] = new_sender
            return new_sender

    def has_requests(self):
        """has unprocessed request."""
        if self.requests is None:
            return False
        return not self.requests.empty()

    def get_all_requests(self) -> Dict[RequestType, Request]:
        """get all requests in current queue."""
        num_reqs = self.requests.qsize()
        reqs: ReqList = []
        for _ in range(num_reqs):
            elem = self.requests.get_nowait()
            if isinstance(elem, Request):
                elem = [elem]
            reqs += elem

        # gather requests
        reqs_by_type: Dict[RequestType, Request] = dict(
            (t, []) for t in RequestType)
        for req in reqs:
            reqs_by_type[req.type].append(req)
        return reqs_by_type

    def bind_func(self, req_type: RequestType, callback: Callable):
        """bind handler for given request type."""
        self.callbacks[req_type] = callback

    def set_request_priority(self, priority: List[RequestType]):
        """set the priority of request type."""
        self.request_priority = priority

    def response(self, resp: Response):
        """send response."""
        if resp.sender_id not in self.senders:
            logger.warning(f'sender {resp.sender_id} not exist. '
                           f'Send {resp} failed.')
            return
        self.senders[resp.sender_id].response_callback(resp)

    def process_request(self, req_type: RequestType, reqs: ReqList, **kwargs):
        """process reqs with given req type."""
        # get callback
        func = self.callbacks.get(req_type, None)
        if func is not None:
            func(reqs, **kwargs)
        else:
            # TODO: send error message
            for req in reqs:
                resp = Response(ResponseType.HANDLER_NOT_EXIST,
                                sender_id=req.sender_id,
                                req_id=req.req_id,
                                err_msg=(f'callback for {req_type}'
                                         ' not exists.'))
                self.response(resp)

    def step(self, **kwargs):
        """handle requests.

        Should only be called in loop task.
        """
        reqs_by_type = self.get_all_requests()

        # handle requests
        for req_type in self.request_priority:
            # request exists
            if req_type not in reqs_by_type or len(
                    reqs_by_type[req_type]) == 0:
                continue
            reqs: ReqList = reqs_by_type[req_type]
            self.process_request(req_type, reqs, **kwargs)

    def run_until_complete(self, future: Awaitable):
        """run until complete."""
        return _run_until_complete(future)
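
# Illustrative sketch (not part of the original file): an engine wires the
# manager by binding one callback per request type, then calls `step()` from
# inside the loop coroutine passed to `start_loop`. The callback below is made
# up, and it assumes `ResponseType.SUCCESS` is available in lmdeploy.messages.
def _example_manager_wiring(manager: RequestManager):

    def _on_add_session(reqs: ReqList, **kwargs):
        # acknowledge each request back to its sender
        for req in reqs:
            manager.response(
                Response(ResponseType.SUCCESS,
                         sender_id=req.sender_id,
                         req_id=req.req_id))

    manager.bind_func(RequestType.ADD_SESSION, _on_add_session)
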
lmdeploy/pytorch/kernels/__init__.py
0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
from .alibi_pagedattention import alibi_paged_attention_fwd
from .apply_rotary_pos_emb import apply_rotary_pos_emb
from .fill_kv_cache import fill_kv_cache
from .fused_rotary_emb import fused_rotary_emb
from .multinomial_sampling import multinomial_sampling
from .pagedattention import paged_attention_fwd
from .rerope_attention import rerope_attention_fwd
from .rms_norm import rms_norm

__all__ = [
    'apply_rotary_pos_emb',
    'fused_rotary_emb',
    'paged_attention_fwd',
    'alibi_paged_attention_fwd',
    'fill_kv_cache',
    'multinomial_sampling',
    'rms_norm',
    'rerope_attention_fwd',
]
lmdeploy/pytorch/kernels/alibi_pagedattention.py
0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
# modify from: https://github.com/ModelTC/lightllm
import math

import torch
import triton
import triton.language as tl
from torch import Tensor
from triton.runtime.jit import get_cuda_stream

assert triton.__version__ >= '2.1.0'

LOG2 = math.log(2)


@triton.jit
def tl_pow(a, b):
    """triton pow."""
    return tl.exp(b * tl.log(a))


@triton.jit
def tl_2pow(b):
    """triton pow2."""
    return tl.exp(b * LOG2)


@triton.jit
def tl_log2(a):
    """triton log2."""
    return tl.log(a) / LOG2


@triton.jit
def _get_interleave_power_of_2(i, n):
    """get interleave power of 2."""
    start = -tl_2pow(3 - tl_log2(n))
    start = tl_2pow(start)
    ratio = start
    return start * tl_pow(ratio, i)


@triton.jit
def get_slope(i, n):
    """get slope."""
    closest_power_of_2 = tl_2pow(tl_log2(n).to(tl.int32))
    if i < closest_power_of_2:
        return _get_interleave_power_of_2(i, closest_power_of_2)
    else:
        return _get_interleave_power_of_2((i - closest_power_of_2) * 2,
                                          2 * closest_power_of_2)


@triton.jit
def _load_block_offsets(offset_ptr, block_id, num_sub_blocks: tl.constexpr,
                        BLOCK: tl.constexpr):
    if num_sub_blocks > 1:
        offs_sub = tl.arange(0, num_sub_blocks)
        offs_n = tl.arange(0, BLOCK // num_sub_blocks)
        ret = tl.load(offset_ptr + block_id * num_sub_blocks + offs_sub)[
            None, :] * BLOCK // num_sub_blocks + offs_n[:, None]
        return tl.ravel(ret)
    else:
        offs_n = tl.arange(0, BLOCK)
        return tl.load(offset_ptr + block_id) * BLOCK + offs_n


@triton.jit
def _fwd_split_kernel(
    Q, K, V, sm_scale, alibi_scale, B_kvlen, Block_offsets, Acc_out,
    stride_qbs, stride_qh, stride_qd,
    stride_kbs, stride_kh, stride_kd,
    stride_vbs, stride_vh, stride_vd,
    stride_ok, stride_obs, stride_oh, stride_od,
    stride_boffb,
    head_offset, num_heads, kv_group_num, block_per_cta,
    num_sub_blocks: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """first step kernel of split k attention."""
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    split_k_id = tl.program_id(2)

    cur_kv_head = cur_head // kv_group_num

    cur_batch_seq_len = 1
    cur_batch_kv_len = tl.load(B_kvlen + cur_batch)
    history_len = cur_batch_kv_len - cur_batch_seq_len

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    off_q = (cur_batch * stride_qbs + cur_head * stride_qh +
             offs_d * stride_qd)
    off_k = (cur_kv_head * stride_kh + offs_d[None, :] * stride_kd)
    off_v = (cur_kv_head * stride_vh + offs_d[None, :] * stride_vd)

    q = tl.load(Q + off_q).to(tl.float32)

    k_ptrs = K + off_k
    v_ptrs = V + off_v

    block_offset_ptrs = Block_offsets + cur_batch * stride_boffb
    head_slope = get_slope(cur_head.to(tl.float32) + head_offset,
                           num_heads.to(tl.float32))

    # initialize pointer to m and l
    m_i = -float('inf')
    l_i = float(0)
    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)

    kv_len_per_prog = block_per_cta * BLOCK_N
    loop_start = kv_len_per_prog * split_k_id
    loop_end = tl.minimum(loop_start + kv_len_per_prog, cur_batch_kv_len)

    # load block offset
    start_block_id = loop_start // BLOCK_N
    b_offset = _load_block_offsets(block_offset_ptrs, start_block_id,
                                   num_sub_blocks, BLOCK_N)

    for start_n in range(loop_start, loop_end, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)

        mask = (start_n + offs_n[:, None]) < cur_batch_kv_len

        # -- compute qk ----
        k = tl.load(
            k_ptrs + b_offset[:, None] * stride_kbs,
            mask=mask,
            other=0.0,
        )

        v = tl.load(
            v_ptrs + b_offset[:, None] * stride_vbs,
            mask=mask,
            other=0.0,
        )

        # prefetch b_offset
        if start_n + BLOCK_N < loop_end:
            start_block_id += 1
            b_offset = _load_block_offsets(block_offset_ptrs, start_block_id,
                                           num_sub_blocks, BLOCK_N)

        qk = tl.sum(q[None, :] * k, 1)
        qk *= sm_scale

        mask = start_n + offs_n
        bias = mask.to(tl.float32) * (head_slope * alibi_scale)
        qk += bias

        # NOTE: inf - inf = nan, and nan will lead to error
        qk = tl.where(
            history_len >= (start_n + offs_n),
            qk,
            -float('inf'),
        )

        # -- compute p, m_i and l_i
        m_i_new = tl.maximum(m_i, tl.max(qk, 0))
        p = tl.exp(qk - m_i_new)
        alpha = tl.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + tl.sum(p, 0)

        # -- update output accumulator --
        # scale acc
        acc = acc * alpha

        # update acc
        p_new = p.to(v.dtype)
        acc += tl.sum(p_new[:, None] * v, 0)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # initialize pointers to output
    off_acc = (cur_batch * stride_obs + split_k_id * stride_ok +
               cur_head * stride_oh + offs_d * stride_od)
    tl.store(Acc_out + off_acc, acc)

    off_meta = (cur_batch * stride_obs + split_k_id * stride_ok +
                cur_head * stride_oh + BLOCK_DMODEL)
    tl.store(Acc_out + off_meta + tl.arange(0, 1), m_i)
    tl.store(Acc_out + off_meta + 1 + tl.arange(0, 1), l_i)


@triton.jit
def _reduce_split_kernel(
    Acc, Out,
    stride_ak, stride_abs, stride_ah, stride_ad,
    stride_obs, stride_oh, stride_od,
    SPLIT_K: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
):
    """second step kernel of split k attention."""
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)

    # initialize offsets
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_k = tl.arange(0, SPLIT_K)

    offs_acc = (cur_batch * stride_abs + cur_head * stride_ah +
                offs_k[:, None] * stride_ak + offs_d[None, :] * stride_ad)
    offs_mi = (cur_batch * stride_abs + cur_head * stride_ah +
               stride_ak * offs_k + BLOCK_DMODEL)

    acc_k = tl.load(Acc + offs_acc)
    m_k = tl.load(Acc + offs_mi)
    l_k = tl.load(Acc + offs_mi + 1)

    m_max = tl.max(m_k, 0)
    alpha = tl.exp(m_k - m_max)
    acc_k = acc_k * alpha[:, None]
    l_k = l_k * alpha

    acc = tl.sum(acc_k, 0)
    l_sum = tl.sum(l_k, 0)
    acc = acc / l_sum

    out_offs = (cur_batch * stride_obs + cur_head * stride_oh +
                offs_d * stride_od)
    tl.store(Out + out_offs, acc)


@triton.jit
def _fwd_kernel(
    Q, K, V, sm_scale, alibi_scale, B_Start_Loc, B_Seqlen, B_kvlen,
    Block_offsets, Out,
    stride_qbs, stride_qh, stride_qd,
    stride_kbs, stride_kh, stride_kd,
    stride_vbs, stride_vh, stride_vd,
    stride_obs, stride_oh, stride_od,
    stride_boffb,
    head_offset, num_heads, kv_group_num,
    num_sub_blocks: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """forward kernel."""
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    cur_kv_head = cur_head // kv_group_num

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_kv_len = tl.load(B_kvlen + cur_batch)
    cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
    history_len = cur_batch_kv_len - cur_batch_seq_len

    block_start_loc = BLOCK_M * start_m
    head_slope = get_slope(cur_head.to(tl.float32) + head_offset,
                           num_heads.to(tl.float32))

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
             cur_head * stride_qh + offs_d[None, :] * stride_qd)
    off_k = (cur_kv_head * stride_kh + offs_d[:, None] * stride_kd)
    off_v = (cur_kv_head * stride_vh + offs_d[None, :] * stride_vd)

    q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len,
                other=0.0)

    k_ptrs = K + off_k
    v_ptrs = V + off_v

    block_offset_ptrs = Block_offsets + cur_batch * stride_boffb

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)

    b_offset = _load_block_offsets(block_offset_ptrs, 0, num_sub_blocks,
                                   BLOCK_N)
    for start_n in range(0, block_mask * cur_batch_kv_len, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)

        # -- compute qk ----
        k = tl.load(
            k_ptrs + b_offset[None, :] * stride_kbs,
            mask=(start_n + offs_n[None, :]) < cur_batch_kv_len,
            other=0.0,
        )

        v = tl.load(
            v_ptrs + b_offset[:, None] * stride_vbs,
            mask=(start_n + offs_n[:, None]) < cur_batch_kv_len,
            other=0.0,
        )
        if start_n + BLOCK_N < cur_batch_kv_len:
            start_block_id = start_n // BLOCK_N + 1
            b_offset = _load_block_offsets(block_offset_ptrs, start_block_id,
                                           num_sub_blocks, BLOCK_N)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale

        mask = start_n + offs_n[None, :]
        bias = mask.to(tl.float32) * (head_slope * alibi_scale)
        qk += bias

        # NOTE: inf - inf = nan, and nan will lead to error
        qk = tl.where(
            (history_len + offs_m[:, None]) >= mask,
            qk,
            float(-1e30),
        )

        # -- compute p, m_i and l_i
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        p = tl.exp(qk - m_i_new[:, None])
        alpha = tl.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + tl.sum(p, 1)
        # -- update output accumulator --
        # scale acc
        acc = acc * alpha[:, None]

        # update acc
        p = p.to(v.dtype)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    acc = acc / l_i[:, None]
    # initialize pointers to output
    off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
             cur_head * stride_oh + offs_d[None, :] * stride_od)
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)


@torch.no_grad()
def alibi_paged_attention_fwd(q: Tensor,
                              k: Tensor,
                              v: Tensor,
                              o: Tensor,
                              block_offsets: Tensor,
                              b_start_loc: Tensor,
                              b_seq_len: Tensor,
                              b_kv_seq_len: Tensor,
                              max_input_len: int,
                              head_offset: int = 0,
                              num_heads: int = -1,
                              alibi_scale: float = 1.0):
    """Paged attention forward with alibi bias.

    Args:
        q (Tensor): Query state.
        k (Tensor): Key state caches.
        v (Tensor): Value state caches.
        o (Tensor): Output state.
        block_offsets (Tensor): The block offset of key and value.
        b_start_loc (Tensor): Start token location of each data in batch.
        b_seq_len (Tensor): Query length for each data in batch.
        b_kv_seq_len (Tensor): Key/Value length for each data in batch.
        max_input_len (int): The max input length.
        head_offset (int): The offset of the start head. Head might be
            partitioned when tensor parallel inference.
        num_heads (int): The number of heads. Head might be partitioned when
            tensor parallel inference.
        BLOCK (int): The kernel block size.
    """

    def _kernel_meta():
        device = q.device
        device_idx = device.index
        device_type = device.type
        stream = get_cuda_stream(device_idx)
        return dict(device=device, device_type=device_type, stream=stream)

    # shape constraints
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    assert Lk in {16, 32, 64, 128}

    sm_scale = 1.0 / (Lq**0.5)  # compute the softmax scaling factor
    batch, head = b_seq_len.shape[0], q.shape[-2]
    kv_group_num = q.shape[-2] // k[0].shape[-2]

    if num_heads <= 0:
        num_heads = head

    BLOCK = 64 if k.size(1) < 16 else k.size(1)
    num_sub_blocks = BLOCK // k.size(1)

    grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # batch, head,

    num_warps = 4 if Lk <= 64 else 8
    kernel_meta = _kernel_meta()
    is_decoding = q.shape[-3] == b_seq_len.size(0)
    if not is_decoding:
        _fwd_kernel[grid](q, k, v, sm_scale, alibi_scale,
                          b_start_loc, b_seq_len, b_kv_seq_len,
                          block_offsets, o,
                          q.stride(-3), q.stride(-2), q.stride(-1),
                          k.stride(-3), k.stride(-2), k.stride(-1),
                          v.stride(-3), v.stride(-2), v.stride(-1),
                          o.stride(-3), o.stride(-2), o.stride(-1),
                          block_offsets.stride(0),
                          head_offset=head_offset,
                          num_heads=num_heads,
                          kv_group_num=kv_group_num,
                          num_sub_blocks=num_sub_blocks,
                          BLOCK_M=BLOCK,
                          BLOCK_DMODEL=Lk,
                          BLOCK_N=BLOCK,
                          num_warps=num_warps,
                          num_stages=1,
                          **kernel_meta)
    else:
        SPLIT_K = 4
        grid = (batch, head, SPLIT_K)
        block_per_cta = triton.cdiv(block_offsets.size(-1), SPLIT_K)
        acc = q.new_empty(batch, head, SPLIT_K, Lq + 2, dtype=torch.float32)
        _fwd_split_kernel[grid](q, k, v, sm_scale, alibi_scale,
                                b_kv_seq_len, block_offsets, acc,
                                stride_qbs=q.stride(-3),
                                stride_qh=q.stride(-2),
                                stride_qd=q.stride(-1),
                                stride_kbs=k.stride(-3),
                                stride_kh=k.stride(-2),
                                stride_kd=k.stride(-1),
                                stride_vbs=v.stride(-3),
                                stride_vh=v.stride(-2),
                                stride_vd=v.stride(-1),
                                stride_ok=acc.stride(-2),
                                stride_obs=acc.stride(-4),
                                stride_oh=acc.stride(-3),
                                stride_od=acc.stride(-1),
                                stride_boffb=block_offsets.stride(0),
                                head_offset=head_offset,
                                num_heads=num_heads,
                                kv_group_num=kv_group_num,
                                block_per_cta=block_per_cta,
                                num_sub_blocks=num_sub_blocks,
                                BLOCK_DMODEL=Lk,
                                BLOCK_N=BLOCK,
                                num_warps=4,
                                num_stages=1,
                                **kernel_meta)

        grid = (batch, head)
        _reduce_split_kernel[grid](acc, o,
                                   stride_ak=acc.stride(-2),
                                   stride_abs=acc.stride(-4),
                                   stride_ah=acc.stride(-3),
                                   stride_ad=acc.stride(-1),
                                   stride_obs=o.stride(-3),
                                   stride_oh=o.stride(-2),
                                   stride_od=o.stride(-1),
                                   SPLIT_K=SPLIT_K,
                                   BLOCK_DMODEL=Lk,
                                   num_warps=num_warps,
                                   num_stages=1,
                                   **kernel_meta)
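
# Illustrative sketch (not part of the original file): one possible way to call
# the kernel in the decoding case, with made-up sizes. The assumed layouts are
# q/o packed along the token axis as (num_tokens, num_heads, head_dim) and k/v
# as paged caches of shape (num_blocks, block_size, num_kv_heads, head_dim);
# treat the shapes and values below as assumptions, not a reference.
def _example_alibi_paged_attention():
    num_heads, head_dim, block_size = 8, 64, 16
    q = torch.rand(4, num_heads, head_dim, device='cuda', dtype=torch.float16)
    o = torch.empty_like(q)
    k = torch.rand(32, block_size, num_heads, head_dim,
                   device='cuda', dtype=torch.float16)
    v = torch.rand_like(k)
    block_offsets = torch.arange(32, device='cuda').reshape(4, 8)
    b_start_loc = torch.tensor([0, 1, 2, 3], device='cuda')
    b_seq_len = torch.tensor([1, 1, 1, 1], device='cuda')
    b_kv_seq_len = torch.tensor([9, 17, 33, 65], device='cuda')
    alibi_paged_attention_fwd(q, k, v, o, block_offsets, b_start_loc,
                              b_seq_len, b_kv_seq_len, max_input_len=1)
    return o
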
lmdeploy/pytorch/kernels/apply_rotary_pos_emb.py
0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl
from torch import Tensor
from triton.runtime.jit import get_cuda_stream


@triton.jit
def apply_rotary_pos_emb_kernel(
    Q, COS, SIN, POS, Q_EMB,
    seq_len,
    stride_qh: tl.constexpr,
    BLOCK: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """apply rotary on key OR query kernel."""
    seq_block_id = tl.program_id(0)
    head_id = tl.program_id(1)

    pos_offset = seq_block_id * BLOCK + tl.arange(0, BLOCK)
    pos_ids = tl.load(POS + pos_offset, pos_offset < seq_len, other=-1)

    feat_size = BLOCK_N * 2
    feat_offset_l = tl.arange(0, BLOCK_N)
    feat_offset_h = BLOCK_N + feat_offset_l
    cs_offset_l = pos_ids[:, None] * feat_size + feat_offset_l[None, :]
    cs_offset_h = pos_ids[:, None] * feat_size + feat_offset_h[None, :]
    pos_ids_mask = pos_ids[:, None] >= 0
    cos_l = tl.load(COS + cs_offset_l, mask=pos_ids_mask)
    cos_h = tl.load(COS + cs_offset_h, mask=pos_ids_mask)
    sin_l = tl.load(SIN + cs_offset_l, mask=pos_ids_mask)
    sin_h = tl.load(SIN + cs_offset_h, mask=pos_ids_mask)

    q_offset_seq = pos_offset[:, None] * stride_qh + head_id * feat_size
    q_offset_l = q_offset_seq + feat_offset_l[None, :]
    q_offset_h = q_offset_seq + feat_offset_h[None, :]
    pos_mask = pos_offset[:, None] < seq_len
    q_l = tl.load(Q + q_offset_l, mask=pos_mask)
    q_h = tl.load(Q + q_offset_h, mask=pos_mask)

    q_emb_l = q_l * cos_l - q_h * sin_l
    q_emb_h = q_h * cos_h + q_l * sin_h

    tl.store(Q_EMB + q_offset_l, q_emb_l, mask=pos_mask)
    tl.store(Q_EMB + q_offset_h, q_emb_h, mask=pos_mask)


@triton.jit
def apply_rotary_pos_emb_qk_kernel(
    Q, K, COS, SIN, POS, Q_EMB, K_EMB,
    seq_len,
    stride_qh: tl.constexpr,
    stride_kh: tl.constexpr,
    BLOCK: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """apply rotary on key AND query kernel."""
    seq_block_id = tl.program_id(0)
    head_id = tl.program_id(1)

    pos_offset = seq_block_id * BLOCK + tl.arange(0, BLOCK)
    pos_ids = tl.load(POS + pos_offset, pos_offset < seq_len, other=-1)

    feat_size = BLOCK_N * 2
    feat_offset_l = tl.arange(0, BLOCK_N)
    feat_offset_h = BLOCK_N + feat_offset_l
    cs_offset_l = pos_ids[:, None] * feat_size + feat_offset_l[None, :]
    cs_offset_h = pos_ids[:, None] * feat_size + feat_offset_h[None, :]
    pos_ids_mask = pos_ids[:, None] >= 0
    cos_l = tl.load(COS + cs_offset_l, mask=pos_ids_mask)
    cos_h = tl.load(COS + cs_offset_h, mask=pos_ids_mask)
    sin_l = tl.load(SIN + cs_offset_l, mask=pos_ids_mask)
    sin_h = tl.load(SIN + cs_offset_h, mask=pos_ids_mask)

    q_offset_seq = pos_offset[:, None] * stride_qh + head_id * feat_size
    q_offset_l = q_offset_seq + feat_offset_l[None, :]
    q_offset_h = q_offset_seq + feat_offset_h[None, :]
    k_offset_seq = pos_offset[:, None] * stride_kh + head_id * feat_size
    k_offset_l = k_offset_seq + feat_offset_l[None, :]
    k_offset_h = k_offset_seq + feat_offset_h[None, :]
    pos_mask = pos_offset[:, None] < seq_len
    q_l = tl.load(Q + q_offset_l, mask=pos_mask)
    q_h = tl.load(Q + q_offset_h, mask=pos_mask)
    k_l = tl.load(K + k_offset_l, mask=pos_mask)
    k_h = tl.load(K + k_offset_h, mask=pos_mask)

    q_emb_l = q_l * cos_l - q_h * sin_l
    q_emb_h = q_h * cos_h + q_l * sin_h
    k_emb_l = k_l * cos_l - k_h * sin_l
    k_emb_h = k_h * cos_h + k_l * sin_h

    tl.store(Q_EMB + q_offset_l, q_emb_l, mask=pos_mask)
    tl.store(Q_EMB + q_offset_h, q_emb_h, mask=pos_mask)
    tl.store(K_EMB + k_offset_l, k_emb_l, mask=pos_mask)
    tl.store(K_EMB + k_offset_h, k_emb_h, mask=pos_mask)


@torch.inference_mode()
def apply_rotary_pos_emb(q: Tensor,
                         k: Tensor,
                         cos: Tensor,
                         sin: Tensor,
                         position_ids: Tensor,
                         position_ids_1d: Tensor = None,
                         q_embed: Tensor = None,
                         k_embed: Tensor = None):
    """Apply rotary positional embedding on query and key.

    Args:
        q (Tensor): Query state.
        k (Tensor): Key state.
        cos (Tensor): cosine matrix (seq_len, dim).
        sin (Tensor): sine matrix (seq_len, dim).
        position_ids (Tensor): Position ids of q and k.
        position_ids_1d (Tensor): 1d Position ids.
        q_embed (Tensor): output q, can be same as q
        k_embed (Tensor): output k, can be same as k

    Returns:
        Tuple[Tensor, Tensor]: Embedded query and key.
    """
    if not q.is_contiguous():
        q = q.contiguous()
    if not k.is_contiguous():
        k = k.contiguous()
    if cos.device != q.device or cos.dtype != q.dtype:
        cos = cos.to(device=q.device, dtype=q.dtype)
    if sin.device != q.device or sin.dtype != q.dtype:
        sin = sin.to(device=q.device, dtype=q.dtype)

    if position_ids_1d is None:
        seq_length = position_ids[..., -1] + 1
        position_ids_1d = [ids[:l] for ids, l in zip(position_ids, seq_length)]
        position_ids_1d = torch.cat(position_ids_1d)

    if q_embed is None:
        q_embed = torch.empty_like(q)
    if k_embed is None:
        k_embed = torch.empty_like(k)

    seq_len = position_ids_1d.size(-1)
    BLOCK = 32
    num_heads_q = q.size(-2)
    num_heads_k = k.size(-2)
    num_warps = 4
    num_stages = 2

    device = q.device
    device_idx = device.index
    device_type = device.type
    stream = get_cuda_stream(device_idx)
    if num_heads_k == num_heads_q:
        grid = [triton.cdiv(seq_len, BLOCK), num_heads_q]
        apply_rotary_pos_emb_qk_kernel[grid](q, k, cos, sin, position_ids_1d,
                                             q_embed, k_embed,
                                             seq_len=seq_len,
                                             stride_qh=q.stride(-3),
                                             stride_kh=k.stride(-3),
                                             BLOCK=BLOCK,
                                             BLOCK_N=q.size(-1) // 2,
                                             num_warps=num_warps,
                                             num_stages=num_stages,
                                             stream=stream,
                                             device=device_idx,
                                             device_type=device_type)
    else:
        grid_q = [triton.cdiv(seq_len, BLOCK), num_heads_q]
        grid_k = [triton.cdiv(seq_len, BLOCK), num_heads_k]
        apply_rotary_pos_emb_kernel[grid_q](q, cos, sin, position_ids_1d,
                                            q_embed,
                                            seq_len=seq_len,
                                            stride_qh=q.stride(-3),
                                            BLOCK=BLOCK,
                                            BLOCK_N=q.size(-1) // 2,
                                            num_warps=num_warps,
                                            num_stages=num_stages,
                                            stream=stream,
                                            device=device_idx,
                                            device_type=device_type)
        apply_rotary_pos_emb_kernel[grid_k](k, cos, sin, position_ids_1d,
                                            k_embed,
                                            seq_len=seq_len,
                                            stride_qh=k.stride(-3),
                                            BLOCK=BLOCK,
                                            BLOCK_N=k.size(-1) // 2,
                                            num_warps=num_warps,
                                            num_stages=num_stages,
                                            stream=stream,
                                            device=device_idx,
                                            device_type=device_type)
    return q_embed, k_embed
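A minimal usage sketch for apply_rotary_pos_emb: the (seq, heads, dim) layout and the HuggingFace-style cos/sin table below are assumptions inferred from the kernel's indexing, and all sizes are hypothetical.

    import torch
    from lmdeploy.pytorch.kernels.apply_rotary_pos_emb import apply_rotary_pos_emb

    # hypothetical sizes; q/k are laid out as (..., seq, heads, dim)
    seq_len, heads, dim = 16, 8, 128
    q = torch.rand(1, seq_len, heads, dim, dtype=torch.float16, device='cuda')
    k = torch.rand(1, seq_len, heads, dim, dtype=torch.float16, device='cuda')
    pos = torch.arange(seq_len, device='cuda')
    inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2, device='cuda').float() / dim))
    freqs = pos[:, None].float() * inv_freq[None, :]
    emb = torch.cat((freqs, freqs), dim=-1)   # (seq_len, dim) table, low/high halves
    q_emb, k_emb = apply_rotary_pos_emb(q, k, emb.cos(), emb.sin(),
                                        position_ids=pos[None],
                                        position_ids_1d=pos)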
lmdeploy/pytorch/kernels/fill_kv_cache.py
0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl
from torch import Tensor
from triton.runtime.jit import get_cuda_stream


@triton.jit
def _div_up(val, other):
    return (val + other - 1) // other


@triton.jit
def _fill_kv_cache_kernel(
    KStates, VStates, KCaches, VCaches,
    QStartLoc, QSeqLens, KVSeqLens, BlockOffsets,
    num_heads: tl.constexpr,
    head_dim: tl.constexpr,
    stride_kss, stride_ksh, stride_ksd,
    stride_vss, stride_vsh, stride_vsd,
    stride_kcn: tl.constexpr, stride_kcb: tl.constexpr,
    stride_kch: tl.constexpr, stride_kcd: tl.constexpr,
    stride_vcn: tl.constexpr, stride_vcb: tl.constexpr,
    stride_vch: tl.constexpr, stride_vcd: tl.constexpr,
    stride_boff,
    BLOCK: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_H: tl.constexpr,
):
    """fill kv cache kernel."""
    batch_id = tl.program_id(0)
    block_id = tl.program_id(1)

    # initialize
    h_off = tl.arange(0, BLOCK_H)
    d_off = tl.arange(0, BLOCK_D)

    q_startloc = tl.load(QStartLoc + batch_id)
    q_seqlen = tl.load(QSeqLens + batch_id)
    kv_seqlen = tl.load(KVSeqLens + batch_id)
    history_seqlen = kv_seqlen - q_seqlen

    block0_first_tokenloc = history_seqlen % BLOCK

    state_token_offset = tl.maximum(block_id * BLOCK - block0_first_tokenloc, 0)
    kv_block_id = _div_up(history_seqlen + 1, BLOCK) - 1 + block_id
    kv_block_id = min(kv_block_id, stride_boff - 1)
    block_off = tl.load(BlockOffsets + batch_id * stride_boff + kv_block_id)

    cur_startloc = q_startloc + state_token_offset

    ks_ptr = KStates + cur_startloc * stride_kss
    vs_ptr = VStates + cur_startloc * stride_vss

    kc_ptr = KCaches + block_off * stride_kcn
    vc_ptr = VCaches + block_off * stride_vcn

    c_first_tokenloc = block0_first_tokenloc
    if block_id != 0:
        c_first_tokenloc *= 0
    c_last_tokenloc = tl.minimum(
        BLOCK, q_seqlen + block0_first_tokenloc - block_id * BLOCK)

    for bidx in range(c_first_tokenloc, c_last_tokenloc):
        sidx = bidx - c_first_tokenloc
        mask = (h_off[:, None] < num_heads) & (d_off[None, :] < head_dim)
        k = tl.load(ks_ptr + sidx * stride_kss + h_off[:, None] * stride_ksh +
                    d_off[None, :] * stride_ksd,
                    mask=mask)
        tl.store(kc_ptr + bidx * stride_kcb + h_off[:, None] * stride_kch +
                 d_off[None, :] * stride_kcd,
                 k,
                 mask=mask)
        v = tl.load(vs_ptr + sidx * stride_vss + h_off[:, None] * stride_vsh +
                    d_off[None, :] * stride_vsd,
                    mask=mask)
        tl.store(vc_ptr + bidx * stride_vcb + h_off[:, None] * stride_vch +
                 d_off[None, :] * stride_vcd,
                 v,
                 mask=mask)


@torch.inference_mode()
def fill_kv_cache(k_states: Tensor, v_states: Tensor, k_caches: Tensor,
                  v_caches: Tensor, q_start_loc: Tensor, q_seq_length: Tensor,
                  kv_seq_length: Tensor, max_q_seq_length: int,
                  block_offsets: Tensor):
    """fill key/value state to cache for paged attention."""

    def _kernel_meta():
        device = k_states.device
        device_idx = device.index
        device_type = device.type
        stream = get_cuda_stream(device_idx)
        return dict(device=device, device_type=device_type, stream=stream)

    block_offsets = block_offsets.contiguous()
    batch_size = block_offsets.size(0)
    block_size, num_heads, head_dim = k_caches.size()[1:]

    max_num_blocks = triton.cdiv(max_q_seq_length, block_size) + 1

    BLOCK = block_size
    BLOCK_H = triton.next_power_of_2(num_heads)
    BLOCK_D = triton.next_power_of_2(head_dim)
    grid = [batch_size, max_num_blocks]
    kernel_meta = _kernel_meta()
    _fill_kv_cache_kernel[grid](
        k_states, v_states, k_caches, v_caches,
        q_start_loc, q_seq_length, kv_seq_length, block_offsets,
        num_heads=num_heads, head_dim=head_dim,
        stride_kss=k_states.stride(-3), stride_ksh=k_states.stride(-2), stride_ksd=k_states.stride(-1),
        stride_vss=v_states.stride(-3), stride_vsh=v_states.stride(-2), stride_vsd=v_states.stride(-1),
        stride_kcn=k_caches.stride(0), stride_kcb=k_caches.stride(1),
        stride_kch=k_caches.stride(2), stride_kcd=k_caches.stride(3),
        stride_vcn=v_caches.stride(0), stride_vcb=v_caches.stride(1),
        stride_vch=v_caches.stride(2), stride_vcd=v_caches.stride(3),
        stride_boff=block_offsets.stride(0),
        BLOCK=BLOCK, BLOCK_D=BLOCK_D, BLOCK_H=BLOCK_H,
        **kernel_meta,
    )
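A sketch of driving fill_kv_cache, assuming k/v states packed as (total_tokens, heads, dim) and caches shaped (num_blocks, block_size, heads, dim); the sequence lengths, block table, and sizes below are hypothetical.

    import torch
    from lmdeploy.pytorch.kernels.fill_kv_cache import fill_kv_cache

    heads, dim, block_size, num_blocks = 8, 128, 64, 32        # hypothetical sizes
    q_lens = torch.tensor([5, 9], device='cuda')               # two fresh sequences
    kv_lens = q_lens.clone()                                   # no cached history yet
    start_loc = torch.tensor([0, 5], device='cuda')
    total = int(q_lens.sum())
    k_states = torch.rand(total, heads, dim, device='cuda', dtype=torch.float16)
    v_states = torch.rand_like(k_states)
    k_caches = torch.zeros(num_blocks, block_size, heads, dim,
                           device='cuda', dtype=torch.float16)
    v_caches = torch.zeros_like(k_caches)
    block_offsets = torch.tensor([[0], [1]], device='cuda')    # one cache block per sequence
    fill_kv_cache(k_states, v_states, k_caches, v_caches, start_loc, q_lens,
                  kv_lens, max_q_seq_length=int(q_lens.max()),
                  block_offsets=block_offsets)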
lmdeploy/pytorch/kernels/fused_rotary_emb.py
0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl
from torch import Tensor
from triton.runtime.jit import get_cuda_stream


@triton.jit
def _fused_rotary_emb_kernel(
    Q, K, PostionIds, InvFreq, scaling_factor, OutQ, OutK,
    stride_bq, stride_sq, stride_hq: tl.constexpr, stride_dq: tl.constexpr,
    stride_bk, stride_sk, stride_hk: tl.constexpr, stride_dk: tl.constexpr,
    stride_bp, stride_sp,
    max_seq_len,
    BLOCK: tl.constexpr,
    BLOCK_HQ: tl.constexpr,
    BLOCK_HK: tl.constexpr,
    BLOCK_F: tl.constexpr
):
    """fused rotary emb kernel."""
    batch_id = tl.program_id(0)
    seq_block_id = tl.program_id(1)

    s_off = seq_block_id * BLOCK + tl.arange(0, BLOCK)[:, None]
    f_off = tl.arange(0, BLOCK_F)[None, :]
    s_mask = s_off < max_seq_len

    bp_off = stride_bp * batch_id
    p_off = bp_off + stride_sp * s_off

    sq_off = batch_id * stride_bq + s_off * stride_sq
    q0_off = sq_off + f_off * stride_dq
    q1_off = q0_off + BLOCK_F * stride_dq

    sk_off = batch_id * stride_bk + s_off * stride_sk
    k0_off = sk_off + f_off * stride_dk
    k1_off = k0_off + BLOCK_F * stride_dk

    inv_freq = tl.load(InvFreq + f_off).to(tl.float32)

    position_ids = tl.load(PostionIds + p_off, mask=s_mask).to(tl.float32)
    position_ids = position_ids / scaling_factor
    # pos_freq = tl.dot(position_ids, inv_freq)
    pos_freq = position_ids * inv_freq
    cos = tl.cos(pos_freq).to(Q.dtype.element_ty)
    sin = tl.sin(pos_freq).to(Q.dtype.element_ty)

    for h in range(BLOCK_HQ):
        q0 = tl.load(Q + q0_off + h * stride_hq, mask=s_mask)
        q1 = tl.load(Q + q1_off + h * stride_hq, mask=s_mask)
        q0_out = q0 * cos - q1 * sin
        tl.store(OutQ + q0_off + h * stride_hq, q0_out, mask=s_mask)
        q1_out = q1 * cos + q0 * sin
        tl.store(OutQ + q1_off + h * stride_hq, q1_out, mask=s_mask)

    for h in range(BLOCK_HK):
        k0 = tl.load(K + k0_off + h * stride_hk, mask=s_mask)
        k1 = tl.load(K + k1_off + h * stride_hk, mask=s_mask)
        k0_out = k0 * cos - k1 * sin
        tl.store(OutK + k0_off + h * stride_hk, k0_out, mask=s_mask)
        k1_out = k1 * cos + k0 * sin
        tl.store(OutK + k1_off + h * stride_hk, k1_out, mask=s_mask)


def fused_rotary_emb(q: Tensor,
                     k: Tensor,
                     position_ids: torch.LongTensor,
                     inv_freq: Tensor,
                     scaling_factor: float,
                     out_q: Tensor = None,
                     out_k: Tensor = None):
    """Fuse `rotary_embedding` and `apply_rotary_pos_emb`."""

    def _kernel_meta():
        device = q.device
        device_idx = device.index
        device_type = device.type
        stream = get_cuda_stream(device_idx)
        return dict(device=device, device_type=device_type, stream=stream)

    if out_q is None:
        out_q = torch.empty_like(q)
    else:
        assert q.stride() == out_q.stride()
    if out_k is None:
        out_k = torch.empty_like(k)
    else:
        assert k.stride() == out_k.stride()
    assert q.dim() == 4
    assert k.dim() == 4
    assert q.size(0) == position_ids.size(0)

    BLOCK = 32
    BLOCK_HQ = q.size(-2)
    BLOCK_HK = k.size(-2)
    BLOCK_F = q.size(-1) // 2
    batch_size = q.size(0)
    max_seq_len = q.size(1)
    kernel_meta = _kernel_meta()
    num_warps = 4
    grid = (batch_size, triton.cdiv(max_seq_len, BLOCK))
    _fused_rotary_emb_kernel[grid](
        q, k, position_ids, inv_freq, scaling_factor, out_q, out_k,
        stride_bq=q.stride(0), stride_sq=q.stride(1),
        stride_hq=q.stride(2), stride_dq=q.stride(3),
        stride_bk=k.stride(0), stride_sk=k.stride(1),
        stride_hk=k.stride(2), stride_dk=k.stride(3),
        stride_bp=position_ids.stride(0), stride_sp=position_ids.stride(1),
        max_seq_len=max_seq_len,
        BLOCK=BLOCK, BLOCK_HQ=BLOCK_HQ, BLOCK_HK=BLOCK_HK, BLOCK_F=BLOCK_F,
        num_warps=num_warps, num_stages=1,
        **kernel_meta)
    return out_q, out_k
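A minimal sketch of fused_rotary_emb, with q/k shaped (batch, seq, heads, dim) as the asserts above require; the sizes and rope base below are hypothetical.

    import torch
    from lmdeploy.pytorch.kernels.fused_rotary_emb import fused_rotary_emb

    batch, seq_len, heads, dim = 2, 16, 8, 128                 # hypothetical sizes
    q = torch.rand(batch, seq_len, heads, dim, device='cuda', dtype=torch.float16)
    k = torch.rand_like(q)
    position_ids = torch.arange(seq_len, device='cuda').repeat(batch, 1)
    inv_freq = 1.0 / (10000**(torch.arange(0, dim, 2, device='cuda').float() / dim))
    out_q, out_k = fused_rotary_emb(q, k, position_ids, inv_freq, scaling_factor=1.0)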
lmdeploy/pytorch/kernels/mbgmm.py
0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl
from torch import Tensor
from triton.runtime.jit import get_cuda_stream


def _next_pow_of_2(x):
    """get next power of 2."""
    return 1 << (x - 1).bit_length()


@triton.jit
def _x_a_mm_kernel(
    X, LoRA_A, XA,
    B_start_loc, B_seq_lens, B_adapter_id,
    Rank_page_table, Rank_page_start, Ranks,
    stride_xs, stride_xh,
    stride_las, stride_lah,
    stride_xas, stride_xar,
    stride_ptb,
    rank_step,
    BLOCK_M: tl.constexpr,
    BLOCK_R: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
):
    """xa mm kernel."""
    cur_batch = tl.program_id(0)
    start_m = tl.program_id(1)

    r_off = tl.arange(0, BLOCK_R)

    seq_len = tl.load(B_seq_lens + cur_batch)
    if start_m * BLOCK_M >= seq_len:
        return

    start_loc = tl.load(B_start_loc + cur_batch)
    adapter_id = tl.load(B_adapter_id + cur_batch)
    rank = tl.load(Ranks + adapter_id) // rank_step
    page_start = tl.load(Rank_page_start + adapter_id)

    page_table_off = adapter_id * stride_ptb + r_off + page_start
    rank_mask = r_off < rank
    page_table = tl.load(Rank_page_table + page_table_off, mask=rank_mask)

    m_off = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    dm_off = tl.arange(0, BLOCK_DMODEL)
    x_off = (start_loc + m_off) * stride_xs
    xs_mask = m_off < seq_len
    la_page_off = page_table * stride_las
    acc = tl.zeros((BLOCK_M, BLOCK_R), dtype=tl.float32)

    # compute acc
    for start_h in range(0, BLOCK_H, BLOCK_DMODEL):
        cur_dm_off = start_h + dm_off
        h_mask = cur_dm_off < BLOCK_H

        # load x
        xh_off = cur_dm_off * stride_xh
        x_mask = xs_mask[:, None] and h_mask[None, :]
        x = tl.load(X + x_off[:, None] + xh_off[None, :],
                    mask=x_mask, other=0.0)

        # load lora a
        lah_off = cur_dm_off * stride_lah
        la_mask = rank_mask[None, :] and h_mask[:, None]
        la = tl.load(LoRA_A + la_page_off[None, :] + lah_off[:, None],
                     mask=la_mask, other=0.0)

        # compute
        acc += tl.dot(x, la)

    acc = acc.to(X.dtype.element_ty)
    xa_off = (start_loc + m_off) * stride_xas
    xas_mask = xs_mask
    xa_mask = xas_mask[:, None] and rank_mask[None, :]
    tl.store(XA + xa_off[:, None] + r_off[None, :] * stride_xar, acc,
             mask=xa_mask)


@triton.jit
def _acc_b_mm_kernel(
    XA, LoRA_B, Out,
    B_start_loc, B_seq_lens, B_adapter_id, B_scaling,
    Rank_page_table, Rank_page_start, Ranks,
    stride_xas, stride_xar,
    stride_os, stride_oh,
    stride_lbs, stride_lbh,
    stride_ptb,
    BLOCK_M: tl.constexpr,
    BLOCK_R: tl.constexpr,
    BLOCK_HO: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    start_m = tl.program_id(1)

    r_off = tl.arange(0, BLOCK_R)

    seq_len = tl.load(B_seq_lens + cur_batch)
    if start_m * BLOCK_M >= seq_len:
        return

    start_loc = tl.load(B_start_loc + cur_batch)
    adapter_id = tl.load(B_adapter_id + cur_batch)
    scaling = tl.load(B_scaling + cur_batch)
    rank = tl.load(Ranks + adapter_id)
    page_start = tl.load(Rank_page_start + adapter_id)

    page_table_off = adapter_id * stride_ptb + r_off + page_start
    rank_mask = r_off < rank
    page_table = tl.load(Rank_page_table + page_table_off, mask=rank_mask)

    m_off = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    dm_off = tl.arange(0, BLOCK_DMODEL)
    lb_page_off = page_table * stride_lbs

    xs_mask = m_off < seq_len
    o_off = (start_loc + m_off) * stride_os
    os_mask = xs_mask

    xa_off = (start_loc + m_off) * stride_xas
    xa_mask = xs_mask[:, None] and rank_mask[None, :]
    acc = tl.load(XA + xa_off[:, None] + r_off[None, :] * stride_xar,
                  mask=xa_mask, other=0.0)
    acc = acc.to(LoRA_B.dtype.element_ty)

    # compute output
    for start_h in range(0, BLOCK_HO, BLOCK_DMODEL):
        cur_dm_off = start_h + dm_off
        h_mask = cur_dm_off < BLOCK_HO

        # load lora b
        lbh_off = cur_dm_off * stride_lbh
        lb_mask = rank_mask[:, None] and h_mask[None, :]
        lb = tl.load(LoRA_B + lb_page_off[:, None] + lbh_off[None, :],
                     mask=lb_mask, other=0)

        # compute
        out = tl.dot(acc, lb)
        out = out.to(lb.dtype)
        out = out * scaling

        # store o
        oh_off = cur_dm_off * stride_oh
        o_mask = os_mask[:, None] and h_mask[None, :]
        tl.store(Out + o_off[:, None] + oh_off[None, :], out, mask=o_mask)


@torch.inference_mode()
def mbgmm_a(x: Tensor,
            lora_a: Tensor,
            q_start_loc: Tensor,
            q_seqlens: Tensor,
            adapter_ids: Tensor,
            rank_page_table: Tensor,
            ranks: Tensor,
            rank_page_start: Tensor,
            max_seq_len: int,
            max_rank: int,
            rank_step: int = 1):
    """mbgmm_a."""

    def _kernel_meta():
        device = x.device
        device_idx = device.index
        device_type = device.type
        stream = get_cuda_stream(device_idx)
        return dict(device=device, device_type=device_type, stream=stream)

    assert x.dim() == 2
    assert lora_a.dim() == 2
    assert rank_page_table.dim() == 2

    head_size = x.size(-1)
    batch_size = len(q_seqlens)
    max_rank = max_rank // rank_step

    BLOCK_M = 32
    BLOCK_R = _next_pow_of_2(max_rank)
    if BLOCK_R < 16:
        BLOCK_R = 16
    BLOCK_H = head_size
    BLOCK_DMODEL = 64

    num_warps = 4
    grid = [batch_size, triton.cdiv(max_seq_len, BLOCK_M)]
    xa = x.new_empty((x.size(0), max_rank))
    kernel_meta = _kernel_meta()
    _x_a_mm_kernel[grid](x, lora_a, xa, q_start_loc, q_seqlens, adapter_ids,
                         Rank_page_table=rank_page_table,
                         Rank_page_start=rank_page_start,
                         Ranks=ranks,
                         stride_xs=x.stride(0), stride_xh=x.stride(1),
                         stride_las=lora_a.stride(0), stride_lah=lora_a.stride(1),
                         stride_xas=xa.stride(0), stride_xar=xa.stride(1),
                         stride_ptb=rank_page_table.stride(0),
                         rank_step=rank_step,
                         BLOCK_M=BLOCK_M, BLOCK_R=BLOCK_R,
                         BLOCK_H=BLOCK_H, BLOCK_DMODEL=BLOCK_DMODEL,
                         num_warps=num_warps, num_stages=1,
                         **kernel_meta)
    return xa


@torch.inference_mode()
def mbgmm_b(xa: Tensor,
            lora_b: Tensor,
            q_start_loc: Tensor,
            q_seqlens: Tensor,
            adapter_ids: Tensor,
            scaling: Tensor,
            rank_page_table: Tensor,
            ranks: Tensor,
            rank_page_start: Tensor,
            max_seq_len: int,
            max_rank: int,
            out_size: int = None):
    """mbgmm_b."""

    def _kernel_meta():
        device = xa.device
        device_idx = device.index
        device_type = device.type
        stream = get_cuda_stream(device_idx)
        return dict(device=device, device_type=device_type, stream=stream)

    assert xa.dim() == 2
    assert lora_b.dim() == 2
    assert rank_page_table.dim() == 2

    if out_size is None:
        out_size = lora_b.size(-1)
    batch_size = len(q_seqlens)

    BLOCK_M = 32
    BLOCK_R = _next_pow_of_2(max_rank)
    if BLOCK_R < 16:
        BLOCK_R = 16
    BLOCK_HO = out_size
    BLOCK_DMODEL = 64

    num_warps = 4
    grid = [batch_size, triton.cdiv(max_seq_len, BLOCK_M)]
    output = xa.new_empty((xa.size(0), BLOCK_HO))
    kernel_meta = _kernel_meta()
    _acc_b_mm_kernel[grid](xa, lora_b, output, q_start_loc, q_seqlens,
                           adapter_ids, scaling,
                           Rank_page_table=rank_page_table,
                           Rank_page_start=rank_page_start,
                           Ranks=ranks,
                           stride_xas=xa.stride(0), stride_xar=xa.stride(1),
                           stride_os=output.stride(0), stride_oh=output.stride(1),
                           stride_lbs=lora_b.stride(0), stride_lbh=lora_b.stride(1),
                           stride_ptb=rank_page_table.stride(0),
                           BLOCK_M=BLOCK_M, BLOCK_R=BLOCK_R,
                           BLOCK_HO=BLOCK_HO, BLOCK_DMODEL=BLOCK_DMODEL,
                           num_warps=num_warps, num_stages=1,
                           **kernel_meta)
    return output
lmdeploy/pytorch/kernels/mbgmv.py
0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl
from torch import Tensor
from triton.runtime.jit import get_cuda_stream


def _next_pow_of_2(x):
    """get next power of 2."""
    return 1 << (x - 1).bit_length()


@triton.jit
def _x_a_mv_kernel(
    X, LoRA_A, XA,
    B_adapter_id,
    Rank_page_table, Rank_page_start, Ranks,
    stride_xs, stride_xh,
    stride_las, stride_lah,
    stride_xas, stride_xar,
    stride_ptb,
    rank_step,
    BLOCK_R: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
):
    """xa mv kernel."""
    cur_batch = tl.program_id(0)

    r_off = tl.arange(0, BLOCK_R)
    adapter_id = tl.load(B_adapter_id + cur_batch)
    rank = tl.load(Ranks + adapter_id) // rank_step
    page_start = tl.load(Rank_page_start + adapter_id)

    page_table_off = adapter_id * stride_ptb + r_off + page_start
    rank_mask = r_off < rank
    page_table = tl.load(Rank_page_table + page_table_off, mask=rank_mask)

    dm_off = tl.arange(0, BLOCK_DMODEL)
    x_off = cur_batch * stride_xs
    la_page_off = page_table * stride_las
    acc = tl.zeros((BLOCK_R, ), dtype=tl.float32)

    # compute acc
    for start_h in range(0, BLOCK_H, BLOCK_DMODEL):
        cur_dm_off = start_h + dm_off
        h_mask = cur_dm_off < BLOCK_H

        # load x
        xh_off = cur_dm_off * stride_xh
        x_mask = h_mask
        x = tl.load(X + x_off + xh_off, mask=x_mask, other=0.0)

        # load lora a
        lah_off = cur_dm_off * stride_lah
        la_mask = rank_mask[:, None] and h_mask[None, :]
        la = tl.load(LoRA_A + la_page_off[:, None] + lah_off[None, :],
                     mask=la_mask, other=0.0)

        # compute
        acc += tl.sum(x[None, :] * la, 1)

    acc = acc.to(X.dtype.element_ty)
    xa_off = cur_batch * stride_xas
    tl.store(XA + xa_off + r_off * stride_xar, acc, mask=rank_mask)


@triton.jit
def _acc_b_mv_kernel(
    XA, LoRA_B, Out,
    B_adapter_id, B_scaling,
    Rank_page_table, Rank_page_start, Ranks,
    stride_xas, stride_xar,
    stride_os, stride_oh,
    stride_lbs, stride_lbh,
    stride_ptb,
    BLOCK_R: tl.constexpr,
    BLOCK_HO: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
):
    """acc b mv kernel."""
    cur_batch = tl.program_id(0)

    r_off = tl.arange(0, BLOCK_R)
    adapter_id = tl.load(B_adapter_id + cur_batch)
    scaling = tl.load(B_scaling + cur_batch)
    rank = tl.load(Ranks + adapter_id)
    page_start = tl.load(Rank_page_start + adapter_id)

    page_table_off = adapter_id * stride_ptb + r_off + page_start
    rank_mask = r_off < rank
    page_table = tl.load(Rank_page_table + page_table_off, mask=rank_mask)

    dm_off = tl.arange(0, BLOCK_DMODEL)
    lb_page_off = page_table * stride_lbs

    o_off = cur_batch * stride_os

    xa_off = cur_batch * stride_xas
    acc = tl.load(XA + xa_off + r_off * stride_xar, mask=rank_mask, other=0.0)

    # compute output
    for start_h in range(0, BLOCK_HO, BLOCK_DMODEL):
        cur_dm_off = start_h + dm_off
        h_mask = cur_dm_off < BLOCK_HO

        # load lora b
        lbh_off = cur_dm_off * stride_lbh
        lb_mask = rank_mask[:, None] and h_mask[None, :]
        lb = tl.load(LoRA_B + lb_page_off[:, None] + lbh_off[None, :],
                     mask=lb_mask, other=0)

        # compute
        out = tl.sum(acc[:, None] * lb, 0)
        out = out.to(lb.dtype)
        out = out * scaling

        # store o
        oh_off = cur_dm_off * stride_oh
        tl.store(Out + o_off + oh_off, out, mask=h_mask)


@torch.inference_mode()
def mbgmv_a(x: Tensor,
            lora_a: Tensor,
            adapter_ids: Tensor,
            rank_page_table: Tensor,
            ranks: Tensor,
            rank_page_start: Tensor,
            max_rank: int,
            rank_step: int = 1):
    """mbgmv_a."""

    def _kernel_meta():
        device = x.device
        device_idx = device.index
        device_type = device.type
        stream = get_cuda_stream(device_idx)
        return dict(device=device, device_type=device_type, stream=stream)

    assert x.dim() == 2
    assert lora_a.dim() == 2
    assert rank_page_table.dim() == 2

    head_size = x.size(-1)
    batch_size = x.size(0)
    max_rank = max_rank // rank_step

    BLOCK_R = _next_pow_of_2(max_rank)
    BLOCK_H = head_size
    BLOCK_DMODEL = 512

    num_warps = 4
    grid = [batch_size]
    xa = x.new_empty((x.size(0), BLOCK_R))
    kernel_meta = _kernel_meta()
    _x_a_mv_kernel[grid](x, lora_a, xa, adapter_ids,
                         Rank_page_table=rank_page_table,
                         Rank_page_start=rank_page_start,
                         Ranks=ranks,
                         stride_xs=x.stride(0), stride_xh=x.stride(1),
                         stride_las=lora_a.stride(0), stride_lah=lora_a.stride(1),
                         stride_xas=xa.stride(0), stride_xar=xa.stride(1),
                         stride_ptb=rank_page_table.stride(0),
                         rank_step=rank_step,
                         BLOCK_R=BLOCK_R, BLOCK_H=BLOCK_H,
                         BLOCK_DMODEL=BLOCK_DMODEL,
                         num_warps=num_warps, num_stages=1,
                         **kernel_meta)
    return xa


@torch.inference_mode()
def mbgmv_b(xa: Tensor,
            lora_b: Tensor,
            adapter_ids: Tensor,
            scaling: Tensor,
            rank_page_table: Tensor,
            ranks: Tensor,
            rank_page_start: Tensor,
            max_rank: int,
            out_size: int = None):
    """mbgmv_b."""

    def _kernel_meta():
        device = xa.device
        device_idx = device.index
        device_type = device.type
        stream = get_cuda_stream(device_idx)
        return dict(device=device, device_type=device_type, stream=stream)

    assert xa.dim() == 2
    assert lora_b.dim() == 2
    assert rank_page_table.dim() == 2

    if out_size is None:
        out_size = lora_b.size(-1)
    batch_size = xa.size(0)

    BLOCK_R = _next_pow_of_2(max_rank)
    BLOCK_HO = out_size
    BLOCK_DMODEL = 512

    num_warps = 4
    grid = [batch_size]
    output = xa.new_empty((xa.size(0), BLOCK_HO))
    kernel_meta = _kernel_meta()
    _acc_b_mv_kernel[grid](xa, lora_b, output, adapter_ids, scaling,
                           Rank_page_table=rank_page_table,
                           Rank_page_start=rank_page_start,
                           Ranks=ranks,
                           stride_xas=xa.stride(0), stride_xar=xa.stride(1),
                           stride_lbs=lora_b.stride(0), stride_lbh=lora_b.stride(1),
                           stride_os=output.stride(0), stride_oh=output.stride(1),
                           stride_ptb=rank_page_table.stride(0),
                           BLOCK_R=BLOCK_R, BLOCK_HO=BLOCK_HO,
                           BLOCK_DMODEL=BLOCK_DMODEL,
                           num_warps=num_warps, num_stages=1,
                           **kernel_meta)
    return output
lmdeploy/pytorch/kernels/multinomial_sampling.py
0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl
from triton.runtime.jit import get_cuda_stream


@triton.jit
def _multinomial_sampling_kernel(Scores, Seeds, Offsets, Indices, Outputs,
                                 stride_sb, stride_st, stride_ib, stride_it,
                                 num_batchs, num_tokens,
                                 BLOCK: tl.constexpr,
                                 BLOCK_N: tl.constexpr):
    """Kernel."""
    batch_block_id = tl.program_id(0)

    off = batch_block_id * BLOCK + tl.arange(0, BLOCK)
    n_off = tl.arange(0, BLOCK_N)

    off_mask = off < num_batchs
    seed = tl.load(Seeds + off, mask=off_mask)
    offset = tl.load(Offsets + off, mask=off_mask).to(tl.int32)

    samp = tl.rand(seed, offset)[:, None]
    acc = tl.zeros((BLOCK, ), dtype=tl.float32)
    output = tl.load(Indices + off * stride_ib, mask=off_mask)

    for b_idx in range(0, num_tokens, BLOCK_N):
        s_off = b_idx + n_off
        s_mask = off_mask[:, None] & (s_off[None, :] < num_tokens)
        scores = tl.load(Scores + off[:, None] * stride_sb +
                         s_off[None, :] * stride_st,
                         mask=s_mask,
                         other=0.0).to(acc.dtype)
        cum_scores = acc[:, None] + tl.cumsum(scores, 1)
        acc += tl.sum(scores, 1)

        pre_cum_scores = cum_scores - scores
        valid_mask = (samp > pre_cum_scores) & (samp <= cum_scores)
        found_mask = tl.sum(valid_mask, 1) > 0

        valid_pos = b_idx + tl.argmax(valid_mask.to(tl.int32), 1)
        indices = tl.load(Indices + off * stride_ib + valid_pos * stride_it,
                          mask=found_mask & off_mask,
                          other=-1)
        output = tl.where(found_mask, indices, output)

    tl.store(Outputs + off, output, mask=off_mask)


def multinomial_sampling(scores: torch.Tensor,
                         seeds: torch.LongTensor,
                         offsets: torch.LongTensor,
                         indices: torch.Tensor = None):
    """multinomial sampling."""

    def __kernel_meta():
        """kernel meta."""
        device = scores.device
        device_idx = device.index
        device_type = device.type
        stream = get_cuda_stream(device_idx)
        return dict(device=device, device_type=device_type, stream=stream)

    assert scores.dim() == 2
    batch_size, num_tokens = scores.size()
    device = scores.device

    if num_tokens == 1:
        return torch.zeros_like(scores, dtype=torch.long)

    if indices is None:
        indices = torch.arange(num_tokens, device=device)
        indices = indices.expand_as(scores)

    assert indices.dim() == 2
    assert indices.size() == scores.size()

    outputs = indices[:, 0].clone()

    BLOCK = 32
    BLOCK_N = 64
    grid = [triton.cdiv(batch_size, BLOCK)]
    kernel_meta = __kernel_meta()
    _multinomial_sampling_kernel[grid](scores, seeds, offsets, indices,
                                       outputs,
                                       stride_sb=scores.stride(0),
                                       stride_st=scores.stride(1),
                                       stride_ib=indices.stride(0),
                                       stride_it=indices.stride(1),
                                       num_batchs=batch_size,
                                       num_tokens=num_tokens,
                                       BLOCK=BLOCK, BLOCK_N=BLOCK_N,
                                       **kernel_meta)
    return outputs
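A sketch of multinomial_sampling: `scores` hold per-token probabilities, `seeds`/`offsets` feed `tl.rand`, and the optional `indices` maps score columns back to vocabulary ids; the batch and vocabulary sizes below are hypothetical.

    import torch
    from lmdeploy.pytorch.kernels.multinomial_sampling import multinomial_sampling

    batch, vocab = 4, 32000                                    # hypothetical sizes
    scores = torch.softmax(torch.rand(batch, vocab, device='cuda'), dim=-1)
    seeds = torch.randint(0, 2**31 - 1, (batch, ), device='cuda')
    offsets = torch.zeros(batch, dtype=torch.long, device='cuda')
    token_ids = multinomial_sampling(scores, seeds, offsets)   # shape (batch,)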
lmdeploy/pytorch/kernels/pagedattention.py
0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
# modify from: https://github.com/ModelTC/lightllm
import torch
import triton
import triton.language as tl
from packaging import version
from torch import Tensor
from triton.runtime.jit import get_cuda_stream

TRITON_VERSION = version.parse(triton.__version__)

assert TRITON_VERSION >= version.parse('2.1.0')

if TRITON_VERSION >= version.parse('2.2.0'):

    @triton.jit
    def _load_block_offsets(offset_ptr, block_id, num_sub_blocks: tl.constexpr,
                            BLOCK: tl.constexpr):
        """load block offsets."""
        if num_sub_blocks > 1:
            offs_sub = tl.arange(0, num_sub_blocks)
            offs_n = tl.arange(0, BLOCK // num_sub_blocks)
            ret = tl.load(offset_ptr + block_id * num_sub_blocks +
                          offs_sub)[:, None] * BLOCK // num_sub_blocks + offs_n[None, :]
            return tl.ravel(ret)
        else:
            offs_n = tl.arange(0, BLOCK)
            return tl.load(offset_ptr + block_id) * BLOCK + offs_n
else:

    @triton.jit
    def _load_block_offsets(offset_ptr, block_id, num_sub_blocks: tl.constexpr,
                            BLOCK: tl.constexpr):
        """load block offsets triton<2.2.0."""
        if num_sub_blocks > 1:
            offs_sub = tl.arange(0, num_sub_blocks)
            offs_n = tl.arange(0, BLOCK // num_sub_blocks)
            ret = tl.load(offset_ptr + block_id * num_sub_blocks +
                          offs_sub)[None, :] * BLOCK // num_sub_blocks + offs_n[:, None]
            return tl.ravel(ret)
        else:
            offs_n = tl.arange(0, BLOCK)
            return tl.load(offset_ptr + block_id) * BLOCK + offs_n


@triton.jit
def _fwd_split_kernel(
    Q, K, V, sm_scale,
    KV_seqlens, Block_offsets, Acc_out,
    stride_qbs, stride_qh, stride_qd,
    stride_kbs, stride_kh, stride_kd,
    stride_vbs, stride_vh, stride_vd,
    stride_ok, stride_obs, stride_oh, stride_od,
    stride_boffb,
    kv_group_num,
    block_per_cta,
    window_size: tl.constexpr,
    num_sub_blocks: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """first step kernel of split k attention."""
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    split_k_id = tl.program_id(2)

    cur_kv_head = cur_head // kv_group_num

    q_seqlen = 1
    kv_seqlen = tl.load(KV_seqlens + cur_batch)
    history_len = kv_seqlen - q_seqlen

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    off_q = (cur_batch * stride_qbs + cur_head * stride_qh +
             offs_d * stride_qd)
    off_k = (cur_kv_head * stride_kh + offs_d[None, :] * stride_kd)
    off_v = (cur_kv_head * stride_vh + offs_d[None, :] * stride_vd)

    q = tl.load(Q + off_q).to(tl.float32)

    k_ptrs = K + off_k
    v_ptrs = V + off_v

    block_offset_ptrs = Block_offsets + cur_batch * stride_boffb

    # initialize pointer to m and l
    m_i = -float('inf')
    l_i = float(0)
    acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)

    kv_len_per_prog = block_per_cta * BLOCK_N
    loop_start = kv_len_per_prog * split_k_id
    loop_end = tl.minimum(loop_start + kv_len_per_prog, kv_seqlen)

    # load block offset
    # dirty
    start_block_id = loop_start // BLOCK_N
    if window_size > 0:
        start_block_id = tl.maximum(history_len - window_size,
                                    loop_start) // BLOCK_N
        kv_min_loc = tl.maximum(history_len - window_size, 0)
    b_offset = _load_block_offsets(block_offset_ptrs, start_block_id,
                                   num_sub_blocks, BLOCK_N)

    loop_start = start_block_id * BLOCK_N
    for start_n in range(loop_start, loop_end, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)

        mask = (start_n + offs_n[:, None]) < kv_seqlen

        # -- compute qk ----
        k = tl.load(
            k_ptrs + b_offset[:, None] * stride_kbs,
            mask=mask,
            other=0.0,
        )

        v = tl.load(
            v_ptrs + b_offset[:, None] * stride_vbs,
            mask=mask,
            other=0.0,
        )

        # prefetch b_offset
        if start_n + BLOCK_N < loop_end:
            start_block_id += 1
            b_offset = _load_block_offsets(block_offset_ptrs, start_block_id,
                                           num_sub_blocks, BLOCK_N)

        qk = tl.sum(q[None, :] * k, 1)
        qk *= sm_scale

        # NOTE: inf - inf = nan, and nan will leads to error
        qk_mask = history_len >= (start_n + offs_n)
        if window_size > 0:
            qk_mask = qk_mask and ((start_n + offs_n) >= kv_min_loc)
        qk = tl.where(
            qk_mask,
            qk,
            -float('inf'),
        )

        # -- compute p, m_i and l_i
        m_i_new = tl.maximum(m_i, tl.max(qk, 0))
        p = tl.exp(qk - m_i_new)
        alpha = tl.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + tl.sum(p, 0)

        # -- update output accumulator --
        # scale acc
        acc = acc * alpha

        # update acc
        p_new = p.to(v.dtype)
        acc += tl.sum(p_new[:, None] * v, 0)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    # initialize pointers to output
    off_acc = (cur_batch * stride_obs + split_k_id * stride_ok +
               cur_head * stride_oh + offs_d * stride_od)
    tl.store(Acc_out + off_acc, acc)

    off_meta = (cur_batch * stride_obs + split_k_id * stride_ok +
                cur_head * stride_oh + BLOCK_DMODEL)
    tl.store(Acc_out + off_meta + tl.arange(0, 1), m_i)
    tl.store(Acc_out + off_meta + 1 + tl.arange(0, 1), l_i)


@triton.jit
def _reduce_split_kernel(
    Acc, Out,
    stride_ak, stride_abs, stride_ah, stride_ad,
    stride_obs, stride_oh, stride_od,
    SPLIT_K: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
):
    """second step kernel of split k attention."""
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)

    # initialize offsets
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_k = tl.arange(0, SPLIT_K)

    offs_acc = (cur_batch * stride_abs + cur_head * stride_ah +
                offs_k[:, None] * stride_ak + offs_d[None, :] * stride_ad)
    offs_mi = (cur_batch * stride_abs + cur_head * stride_ah +
               stride_ak * offs_k + BLOCK_DMODEL)

    acc_k = tl.load(Acc + offs_acc)
    m_k = tl.load(Acc + offs_mi)
    l_k = tl.load(Acc + offs_mi + 1)

    m_max = tl.max(m_k, 0)
    alpha = tl.exp(m_k - m_max)
    acc_k = acc_k * alpha[:, None]
    l_k = l_k * alpha

    acc = tl.sum(acc_k, 0)
    l_sum = tl.sum(l_k, 0)
    acc = acc / l_sum

    out_offs = (cur_batch * stride_obs + cur_head * stride_oh +
                offs_d * stride_od)
    tl.store(Out + out_offs, acc)


def _get_convert_pv(nv_capability):
    """lazy load convert_pv."""
    if nv_capability[0] >= 8:

        @triton.jit
        def convert_pv(p, v):
            """convert pv."""
            p = p.to(v.dtype)
            return p, v
    else:

        @triton.jit
        def convert_pv(p, v):
            """convert pv."""
            v = v.to(p.dtype)
            return p, v

    return convert_pv


_convert_pv = None


@triton.jit
def _fwd_kernel(
    Q, K, V, sm_scale,
    Q_start_loc, Q_seqlens, KV_seqlens, Block_offsets,
    Out,
    stride_qbs, stride_qh, stride_qd,
    stride_kbs, stride_kh, stride_kd,
    stride_vbs, stride_vh, stride_vd,
    stride_obs, stride_oh, stride_od,
    stride_boffb,
    kv_group_num,
    window_size: tl.constexpr,
    num_sub_blocks: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """paged attention kernel."""
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    start_m = tl.program_id(2)

    cur_kv_head = cur_head // kv_group_num

    q_seqlen = tl.load(Q_seqlens + cur_batch)
    kv_seqlen = tl.load(KV_seqlens + cur_batch)
    q_start_loc = tl.load(Q_start_loc + cur_batch)
    history_len = kv_seqlen - q_seqlen

    block_start_loc = BLOCK_M * start_m

    # initialize offsets
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    off_q = ((q_start_loc + offs_m[:, None]) * stride_qbs +
             cur_head * stride_qh + offs_d[None, :] * stride_qd)
    off_k = (cur_kv_head * stride_kh + offs_d[:, None] * stride_kd)
    off_v = (cur_kv_head * stride_vh + offs_d[None, :] * stride_vd)

    q = tl.load(Q + off_q, mask=offs_m[:, None] < q_seqlen, other=0.0)

    k_ptrs = K + off_k
    v_ptrs = V + off_v

    block_offset_ptrs = Block_offsets + cur_batch * stride_boffb

    # initialize pointer to m and l
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

    block_mask = tl.where(block_start_loc < q_seqlen, 1, 0)

    # this is dirty
    start_block_id = kv_seqlen - kv_seqlen
    if window_size > 0:
        start_block_id = tl.maximum(history_len - window_size, 0) // BLOCK_N
        kv_min_loc = tl.maximum(history_len + offs_m - window_size, 0)
    b_offset = _load_block_offsets(block_offset_ptrs, start_block_id,
                                   num_sub_blocks, BLOCK_N)

    kv_start_loc = start_block_id * BLOCK_N
    for start_n in range(kv_start_loc, block_mask * kv_seqlen, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)

        # -- compute qk ----
        k = tl.load(
            k_ptrs + b_offset[None, :] * stride_kbs,
            mask=start_n + offs_n[None, :] < kv_seqlen,
            other=0.0,
        )

        v = tl.load(
            v_ptrs + b_offset[:, None] * stride_vbs,
            mask=start_n + offs_n[:, None] < kv_seqlen,
            other=0.0,
        )

        if start_n + BLOCK_N < kv_seqlen:
            start_block_id = start_n // BLOCK_N + 1
            b_offset = _load_block_offsets(block_offset_ptrs, start_block_id,
                                           num_sub_blocks, BLOCK_N)

        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k)
        qk *= sm_scale

        # NOTE: inf - inf = nan, and nan will leads to error
        qk_mask = (history_len + offs_m[:, None]) >= (start_n + offs_n[None, :])
        if window_size > 0:
            qk_mask = qk_mask and (
                (start_n + offs_n[None, :]) >= kv_min_loc[:, None])
        qk = tl.where(
            qk_mask,
            qk,
            float(-1e30),
        )

        # -- compute p, m_i and l_i
        m_i_new = tl.maximum(m_i, tl.max(qk, 1))
        p = tl.exp(qk - m_i_new[:, None])
        alpha = tl.exp(m_i - m_i_new)
        l_i_new = alpha * l_i + tl.sum(p, 1)
        # -- update output accumulator --
        # scale acc
        acc = acc * alpha[:, None]

        # update acc
        p, v = _convert_pv(p, v)
        acc += tl.dot(p, v)
        # update m_i and l_i
        l_i = l_i_new
        m_i = m_i_new

    acc = acc / l_i[:, None]
    # initialize pointers to output
    off_o = ((q_start_loc + offs_m[:, None]) * stride_obs +
             cur_head * stride_oh + offs_d[None, :] * stride_od)
    out_ptrs = Out + off_o
    tl.store(out_ptrs, acc, mask=offs_m[:, None] < q_seqlen)


@torch.inference_mode()
def paged_attention_fwd(
    q: Tensor,
    k: Tensor,
    v: Tensor,
    o: Tensor,
    block_offsets: Tensor,
    q_start_loc: Tensor,
    q_seqlens: Tensor,
    kv_seqlens: Tensor,
    max_seqlen: int,
    window_size: int = -1,
):
    """Paged Attention forward.

    Args:
        q (Tensor): Query state.
        k (Tensor): Key state caches.
        v (Tensor): Value state caches.
        o (Tensor): Output state.
        block_offsets (Tensor): The block offset of key and value.
        q_start_loc (Tensor): Start token location of each data in batch.
        q_seqlens (Tensor): Query length for each data in batch.
        kv_seqlens (Tensor): Key/Value length for each data in batch.
        max_seqlen (int): The max input length.
        BLOCK (int): The kernel block size.
    """
    global _convert_pv
    if _convert_pv is None:
        nv_cap = torch.cuda.get_device_capability()
        _convert_pv = _get_convert_pv(nv_cap)

    def _kernel_meta():
        """kernel meta."""
        device = q.device
        device_idx = device.index
        device_type = device.type
        stream = get_cuda_stream(device_idx)
        return dict(device=device, device_type=device_type, stream=stream)

    # shape constraints
    Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
    assert Lq == Lk and Lk == Lv
    assert Lk in {16, 32, 64, 128, 256}

    sm_scale = 1.0 / (Lq**0.5)  # compute the scale factor
    batch, head = q_seqlens.shape[0], q.shape[-2]
    kv_group_num = q.shape[-2] // k.shape[-2]

    num_warps = 4 if Lk <= 64 else 8
    BLOCK = 64 if k.size(1) < 16 else k.size(1)
    num_sub_blocks = BLOCK // k.size(1)

    kernel_meta = _kernel_meta()
    is_decoding = q.shape[-3] == q_seqlens.size(0)
    if not is_decoding:
        grid = (batch, head, triton.cdiv(max_seqlen, BLOCK))
        _fwd_kernel[grid](q, k, v, sm_scale, q_start_loc, q_seqlens,
                          kv_seqlens, block_offsets, o,
                          stride_qbs=q.stride(-3), stride_qh=q.stride(-2), stride_qd=q.stride(-1),
                          stride_kbs=k.stride(-3), stride_kh=k.stride(-2), stride_kd=k.stride(-1),
                          stride_vbs=v.stride(-3), stride_vh=v.stride(-2), stride_vd=v.stride(-1),
                          stride_obs=o.stride(-3), stride_oh=o.stride(-2), stride_od=o.stride(-1),
                          stride_boffb=block_offsets.stride(0),
                          kv_group_num=kv_group_num,
                          window_size=window_size,
                          num_sub_blocks=num_sub_blocks,
                          BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK,
                          num_warps=num_warps, num_stages=1,
                          **kernel_meta)
    else:
        SPLIT_K = 4
        grid = (batch, head, SPLIT_K)
        block_per_cta = triton.cdiv(block_offsets.size(-1), SPLIT_K)
        acc = q.new_empty(batch, head, SPLIT_K, Lq + 2, dtype=torch.float32)
        _fwd_split_kernel[grid](q, k, v, sm_scale, kv_seqlens, block_offsets,
                                acc,
                                stride_qbs=q.stride(-3), stride_qh=q.stride(-2), stride_qd=q.stride(-1),
                                stride_kbs=k.stride(-3), stride_kh=k.stride(-2), stride_kd=k.stride(-1),
                                stride_vbs=v.stride(-3), stride_vh=v.stride(-2), stride_vd=v.stride(-1),
                                stride_ok=acc.stride(-2), stride_obs=acc.stride(-4),
                                stride_oh=acc.stride(-3), stride_od=acc.stride(-1),
                                stride_boffb=block_offsets.stride(0),
                                kv_group_num=kv_group_num,
                                block_per_cta=block_per_cta,
                                window_size=window_size,
                                num_sub_blocks=num_sub_blocks,
                                BLOCK_DMODEL=Lk, BLOCK_N=BLOCK,
                                num_warps=4, num_stages=1,
                                **kernel_meta)

        grid = (batch, head)
        _reduce_split_kernel[grid](acc, o,
                                   stride_ak=acc.stride(-2), stride_abs=acc.stride(-4),
                                   stride_ah=acc.stride(-3), stride_ad=acc.stride(-1),
                                   stride_obs=o.stride(-3), stride_oh=o.stride(-2), stride_od=o.stride(-1),
                                   SPLIT_K=SPLIT_K, BLOCK_DMODEL=Lk,
                                   num_warps=num_warps, num_stages=1,
                                   **kernel_meta)
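A prefill-style sketch of paged_attention_fwd, assuming q/o packed as (total_q_tokens, heads, dim), k/v caches shaped (num_blocks, block_size, kv_heads, dim), and a block table with one cache block per sequence; all sizes are hypothetical.

    import torch
    from lmdeploy.pytorch.kernels.pagedattention import paged_attention_fwd

    heads, dim, block_size, num_blocks = 8, 128, 64, 16        # hypothetical sizes
    q_lens = torch.tensor([5, 9], device='cuda')
    kv_lens = q_lens.clone()
    start_loc = torch.tensor([0, 5], device='cuda')
    total_q = int(q_lens.sum())
    q = torch.rand(total_q, heads, dim, device='cuda', dtype=torch.float16)
    o = torch.empty_like(q)
    k_cache = torch.rand(num_blocks, block_size, heads, dim,
                         device='cuda', dtype=torch.float16)
    v_cache = torch.rand_like(k_cache)
    block_offsets = torch.tensor([[0], [1]], device='cuda')    # one block per sequence
    paged_attention_fwd(q, k_cache, v_cache, o, block_offsets, start_loc,
                        q_lens, kv_lens, max_seqlen=int(q_lens.max()))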
lmdeploy/pytorch/kernels/rearange_all_gather.py
0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl
from triton.runtime.jit import get_cuda_stream


@triton.jit
def _rearange_all_gather_kernel(X, StartLoc, SeqLen, AdapterIds, Ranks, Out,
                                stride_x, stride_o,
                                world_size,
                                BLOCK: tl.constexpr,
                                BLOCK_P: tl.constexpr):
    """rearange all gather kernel."""
    batch_id = tl.program_id(0)
    block_id = tl.program_id(1)

    start_loc = tl.load(StartLoc + batch_id) + block_id * BLOCK
    seq_len = tl.load(SeqLen + batch_id)

    if block_id * BLOCK >= seq_len:
        return

    block_off = start_loc + tl.arange(0, BLOCK)
    block_mask = block_id * BLOCK + tl.arange(0, BLOCK) < seq_len

    adapter_id = tl.load(AdapterIds + batch_id)
    rank = tl.load(Ranks + adapter_id)
    prank = rank // world_size

    p_off = tl.arange(0, BLOCK_P)
    for p_id in range(world_size):
        ip_off = p_id * BLOCK_P + p_off
        i_mask = block_mask[:, None] and (p_off < prank)[None, :]
        i_off = block_off[:, None] * stride_x + ip_off[None, :]
        x = tl.load(X + i_off, mask=i_mask)

        op_off = p_id * prank + p_off
        o_mask = i_mask
        o_off = block_off[:, None] * stride_o + op_off[None, :]
        tl.store(Out + o_off, x, mask=o_mask)


@triton.jit
def _rearange_all_gather_decoding_kernel(X, AdapterIds, Ranks, Out,
                                         stride_x, stride_o,
                                         world_size, seq_len,
                                         BLOCK: tl.constexpr,
                                         BLOCK_P: tl.constexpr):
    """rearange all gather kernel."""
    block_id = tl.program_id(0)

    block_off = block_id * BLOCK + tl.arange(0, BLOCK)
    block_mask = block_off < seq_len

    adapter_ids = tl.load(AdapterIds + block_off, mask=block_mask)
    ranks = tl.load(Ranks + adapter_ids)
    pranks = ranks // world_size

    p_off = tl.arange(0, BLOCK_P)
    for p_id in range(world_size):
        ip_off = p_id * BLOCK_P + p_off
        i_mask = block_mask[:, None] and (p_off[None, :] < pranks[:, None])
        i_off = block_off[:, None] * stride_x + ip_off[None, :]
        x = tl.load(X + i_off, mask=i_mask)

        op_off = p_id * pranks[:, None] + p_off[None, :]
        o_mask = i_mask
        o_off = block_off[:, None] * stride_o + op_off
        tl.store(Out + o_off, x, mask=o_mask)


def rearange_all_gather(x: torch.Tensor,
                        b_start_loc: torch.Tensor,
                        b_seq_lens: torch.Tensor,
                        adapter_ids: torch.LongTensor,
                        ranks: torch.Tensor,
                        world_size: int,
                        max_seq_len: int,
                        output: torch.Tensor = None):
    """rearange all gather."""

    def _kernel_meta():
        device = x.device
        device_idx = device.index
        device_type = device.type
        stream = get_cuda_stream(device_idx)
        return dict(device=device, device_type=device_type, stream=stream)

    max_rank = x.size(1)
    batch_size = len(b_seq_lens)
    partition_size = max_rank // world_size

    if output is None:
        output = torch.empty_like(x)

    num_warps = 4
    kernel_meta = _kernel_meta()

    is_decoding = batch_size == x.size(0)
    if not is_decoding:
        BLOCK = 128
        BLOCK_P = partition_size
        grid = (batch_size, triton.cdiv(max_seq_len, BLOCK))
        _rearange_all_gather_kernel[grid](x, b_start_loc, b_seq_lens,
                                          adapter_ids, ranks, output,
                                          stride_x=x.stride(0),
                                          stride_o=output.stride(0),
                                          world_size=world_size,
                                          BLOCK=BLOCK, BLOCK_P=BLOCK_P,
                                          num_warps=num_warps, num_stages=1,
                                          **kernel_meta)
    else:
        BLOCK = 64
        BLOCK_P = partition_size
        seq_len = x.size(0)
        grid = (triton.cdiv(seq_len, BLOCK), )
        _rearange_all_gather_decoding_kernel[grid](x, adapter_ids, ranks,
                                                   output,
                                                   stride_x=x.stride(0),
                                                   stride_o=output.stride(0),
                                                   world_size=world_size,
                                                   seq_len=seq_len,
                                                   BLOCK=BLOCK, BLOCK_P=BLOCK_P,
                                                   num_warps=num_warps,
                                                   num_stages=1,
                                                   **kernel_meta)
    return output
lmdeploy/pytorch/kernels/rerope_attention.py
0 → 100644
# Copyright (c) OpenMMLab. All rights reserved.
import
torch
import
triton
import
triton.language
as
tl
assert
triton
.
__version__
>=
'2.1.0'
# bugfix from https://gist.github.com/chu-tianxiang/4307937fd94b49c75b61a6967716bae9#file-rerope-py # noqa: E501
@
triton
.
jit
def
_rerope_fwd_kernel
(
Q1
,
Q2
,
K1
,
K2
,
V
,
sm_scale
,
# L,
Out
,
stride_qz
,
stride_qh
,
stride_qm
,
stride_qk
,
stride_kz
,
stride_kh
,
stride_kn
,
stride_kk
,
stride_vz
,
stride_vh
,
stride_vk
,
stride_vn
,
stride_oz
,
stride_oh
,
stride_om
,
stride_on
,
Z
,
H
,
N_CTX
,
BLOCK_M
:
tl
.
constexpr
,
BLOCK_DMODEL
:
tl
.
constexpr
,
BLOCK_N
:
tl
.
constexpr
,
IS_CAUSAL
:
tl
.
constexpr
,
WINDOW
:
tl
.
constexpr
,
):
"""rerope attention triton kernel."""
start_m
=
tl
.
program_id
(
0
)
off_hz
=
tl
.
program_id
(
1
)
q_offset
=
off_hz
*
stride_qh
kv_offset
=
off_hz
*
stride_kh
Q1_block_ptr
=
tl
.
make_block_ptr
(
base
=
Q1
+
q_offset
,
shape
=
(
N_CTX
,
BLOCK_DMODEL
),
strides
=
(
stride_qm
,
stride_qk
),
offsets
=
(
start_m
*
BLOCK_M
,
0
),
block_shape
=
(
BLOCK_M
,
BLOCK_DMODEL
),
order
=
(
1
,
0
))
Q2_block_ptr
=
tl
.
make_block_ptr
(
base
=
Q2
+
q_offset
,
shape
=
(
N_CTX
,
BLOCK_DMODEL
),
strides
=
(
stride_qm
,
stride_qk
),
offsets
=
(
start_m
*
BLOCK_M
,
0
),
block_shape
=
(
BLOCK_M
,
BLOCK_DMODEL
),
order
=
(
1
,
0
))
K1_block_ptr
=
tl
.
make_block_ptr
(
base
=
K1
+
kv_offset
,
shape
=
(
BLOCK_DMODEL
,
N_CTX
),
strides
=
(
stride_kk
,
stride_kn
),
offsets
=
(
0
,
0
),
block_shape
=
(
BLOCK_DMODEL
,
BLOCK_N
),
order
=
(
0
,
1
))
K2_block_ptr
=
tl
.
make_block_ptr
(
base
=
K2
+
kv_offset
,
shape
=
(
BLOCK_DMODEL
,
N_CTX
),
strides
=
(
stride_kk
,
stride_kn
),
offsets
=
(
0
,
0
),
block_shape
=
(
BLOCK_DMODEL
,
BLOCK_N
),
order
=
(
0
,
1
))
V_block_ptr
=
tl
.
make_block_ptr
(
base
=
V
+
kv_offset
,
shape
=
(
N_CTX
,
BLOCK_DMODEL
),
strides
=
(
stride_vk
,
stride_vn
),
offsets
=
(
0
,
0
),
block_shape
=
(
BLOCK_N
,
BLOCK_DMODEL
),
order
=
(
1
,
0
))
# initialize offsets
offs_m
=
start_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
offs_n
=
tl
.
arange
(
0
,
BLOCK_N
)
# initialize pointer to m and l
m_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
-
float
(
'inf'
)
l_i
=
tl
.
zeros
([
BLOCK_M
],
dtype
=
tl
.
float32
)
acc
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_DMODEL
],
dtype
=
tl
.
float32
)
# scale sm_scale by log_2(e) and use
# 2^x instead of exp in the loop because CSE and LICM
# don't work as expected with `exp` in the loop
qk_scale
=
sm_scale
*
1.44269504
# load q: it will stay in SRAM throughout
q1
=
tl
.
load
(
Q1_block_ptr
,
boundary_check
=
(
0
,
1
))
dtype
=
q1
.
dtype
q1
=
(
q1
*
qk_scale
).
to
(
dtype
)
q2
=
tl
.
load
(
Q2_block_ptr
,
boundary_check
=
(
0
,
1
))
q2
=
(
q2
*
qk_scale
).
to
(
dtype
)
# loop over k, v and update accumulator
lo
=
0
hi
=
(
start_m
+
1
)
*
BLOCK_M
if
IS_CAUSAL
else
N_CTX
for
start_n
in
range
(
lo
,
hi
,
BLOCK_N
):
# -- compute qk ---
qk
=
tl
.
zeros
([
BLOCK_M
,
BLOCK_N
],
dtype
=
tl
.
float32
)
if
IS_CAUSAL
:
qk
=
tl
.
where
(
offs_m
[:,
None
]
>=
(
start_n
+
offs_n
[
None
,
:]),
qk
,
float
(
'-inf'
))
if
start_n
<=
start_m
*
BLOCK_M
-
WINDOW
-
BLOCK_N
or
start_n
>=
(
start_m
+
1
)
*
BLOCK_M
+
WINDOW
:
k2
=
tl
.
load
(
K2_block_ptr
)
v
=
tl
.
load
(
V_block_ptr
)
qk
+=
tl
.
dot
(
q2
,
k2
,
out_dtype
=
tl
.
float32
)
elif
start_n
>
(
start_m
+
1
)
*
BLOCK_M
-
WINDOW
and
start_n
<
start_m
*
BLOCK_M
+
WINDOW
-
BLOCK_N
:
# noqa: E501
k1
=
tl
.
load
(
K1_block_ptr
)
v
=
tl
.
load
(
V_block_ptr
)
qk
+=
tl
.
dot
(
q1
,
k1
,
out_dtype
=
tl
.
float32
)
else
:
k1
=
tl
.
load
(
K1_block_ptr
)
k2
=
tl
.
load
(
K2_block_ptr
)
v
=
tl
.
load
(
V_block_ptr
)
qk1
=
tl
.
dot
(
q1
,
k1
,
out_dtype
=
tl
.
float32
)
qk2
=
tl
.
dot
(
q2
,
k2
,
out_dtype
=
tl
.
float32
)
qk
+=
tl
.
where
(
tl
.
abs
(
offs_m
[:,
None
]
-
(
start_n
+
offs_n
[
None
,
:]))
<
WINDOW
,
qk1
,
qk2
)
# -- compute scaling constant ---
m_i_new
=
tl
.
maximum
(
m_i
,
tl
.
max
(
qk
,
1
))
alpha
=
tl
.
math
.
exp2
(
m_i
-
m_i_new
)
p
=
tl
.
math
.
exp2
(
qk
-
m_i_new
[:,
None
])
# -- scale and update acc --
acc_scale
=
l_i
*
0
+
alpha
# workaround some compiler bug
acc
*=
acc_scale
[:,
None
]
acc
+=
tl
.
dot
(
p
,
v
.
to
(
tl
.
float32
))
# -- update m_i and l_i --
l_i
=
l_i
*
alpha
+
tl
.
sum
(
p
,
1
)
m_i
=
m_i_new
# update pointers
K1_block_ptr
=
tl
.
advance
(
K1_block_ptr
,
(
0
,
BLOCK_N
))
K2_block_ptr
=
tl
.
advance
(
K2_block_ptr
,
(
0
,
BLOCK_N
))
V_block_ptr
=
tl
.
advance
(
V_block_ptr
,
(
BLOCK_N
,
0
))
# write back l and m
acc
=
acc
/
l_i
[:,
None
]
# debug softmax output
# l_ptrs = L + off_hz * N_CTX + offs_m
# tl.store(l_ptrs, m_i + tl.math.log2(l_i))
# write back O
O_block_ptr
=
tl
.
make_block_ptr
(
base
=
Out
+
q_offset
,
shape
=
(
N_CTX
,
BLOCK_DMODEL
),
strides
=
(
stride_om
,
stride_on
),
offsets
=
(
start_m
*
BLOCK_M
,
0
),
block_shape
=
(
BLOCK_M
,
BLOCK_DMODEL
),
order
=
(
1
,
0
))
tl
.
store
(
O_block_ptr
,
acc
.
to
(
dtype
),
boundary_check
=
(
0
,
1
))
def
rerope_attention_fwd
(
q1
,
q2
,
k1
,
k2
,
v
,
causal
,
sm_scale
,
window
,
BLOCK_M
=
64
):
"""rerope attention forward."""
# shape constraints
Lq
,
Lk
,
Lv
=
q1
.
shape
[
-
1
],
k1
.
shape
[
-
1
],
v
.
shape
[
-
1
]
assert
Lq
==
Lk
and
Lk
==
Lv
assert
Lk
in
{
16
,
32
,
64
,
128}
    o = torch.empty_like(q1)
    BLOCK_N = 64 if Lk <= 64 else 32
    num_stages = 4 if Lk <= 64 else 3
    num_warps = 4
    grid = (triton.cdiv(q1.shape[2], BLOCK_M), q1.shape[0] * q1.shape[1], 1)
    # L = torch.empty((q1.shape[0] * q1.shape[1], q1.shape[2]),
    #                 device=q1.device,
    #                 dtype=torch.float32)
    _rerope_fwd_kernel[grid](
        q1,
        q2,
        k1,
        k2,
        v,
        sm_scale,
        # L,
        o,
        q1.stride(0), q1.stride(1), q1.stride(2), q1.stride(3),
        k1.stride(0), k1.stride(1), k1.stride(2), k1.stride(3),
        v.stride(0), v.stride(1), v.stride(2), v.stride(3),
        o.stride(0), o.stride(1), o.stride(2), o.stride(3),
        q1.shape[0],
        q1.shape[1],
        q1.shape[2],
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_DMODEL=Lk,
        IS_CAUSAL=causal,
        WINDOW=window,
        num_warps=num_warps,
        num_stages=num_stages)
    return o


if __name__ == '__main__':

    def test_rerope():
        import torch.utils.benchmark as benchmark
        Z = 1
        H = 40
        N_CTX = 2176
        D_HEAD = 128
        WINDOW = 512
        sm_scale = 0.0883883

        def torch_attention(q1, q2, k1, k2, v, causal, sm_scale, window):
            # reference implementation
            M = torch.tril(torch.ones((N_CTX, N_CTX), device='cuda'))
            p1 = torch.matmul(q1, k1.transpose(2, 3)) * sm_scale
            p2 = torch.matmul(q2, k2.transpose(2, 3)) * sm_scale
            if causal:
                p1[:, :, M == 0] = float('-inf')
                p2[:, :, M == 0] = float('-inf')
            x = torch.arange(N_CTX, dtype=torch.int, device='cuda')
            M2 = ((x[:, None] - x[None, :]).abs() < window)[None, None, :]
            p = torch.where(M2, p1, p2)
            p = torch.softmax(p.float(), dim=-1).half()
            ref_out = torch.matmul(p, v)
            return ref_out

        def torch_attention2(query_states1, query_states2, key_states1,
                             key_states2, value_states, causal, sm_scale,
                             window):
            query_states1 = query_states1.squeeze(0).contiguous()
            query_states2 = query_states2.squeeze(0).contiguous()
            key_states1 = key_states1.squeeze(0).contiguous()
            key_states2 = key_states2.squeeze(0).contiguous()
            value_states = value_states.squeeze(0).contiguous()
            attn_weights1 = torch.matmul(
                query_states1, key_states1.transpose(1, 2)) * sm_scale
            attn_weights2 = torch.matmul(
                query_states2, key_states2.transpose(1, 2)) * sm_scale
            position_ids = torch.arange(
                query_states1.shape[1],
                device=query_states1.device).unsqueeze(0)
            rectified_mask = (position_ids[:, -N_CTX:, None] -
                              position_ids[:, None]).abs() < window
            attn_weights = torch.where(rectified_mask, attn_weights1,
                                       attn_weights2)
            if causal:
                tgt_len = attn_weights.shape[-1]
                dtype = attn_weights.dtype
                device = attn_weights.device
                mask = torch.full((tgt_len, tgt_len),
                                  torch.finfo(dtype).min,
                                  device=device)
                mask_cond = torch.arange(mask.size(-1), device=device)
                mask.masked_fill_(
                    mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
                mask = mask.to(dtype)
                attn_weights = attn_weights + mask
            # upcast attention to fp32
            attn_weights = torch.nn.functional.softmax(
                attn_weights, dim=-1,
                dtype=torch.float32).to(query_states1.dtype)
            attn_output = torch.matmul(attn_weights, value_states)
            return attn_output

        q1 = torch.empty((Z, H, N_CTX, D_HEAD),
                         dtype=torch.float16,
                         device='cuda').normal_(mean=0., std=0.5).contiguous()
        q2 = torch.empty((Z, H, N_CTX, D_HEAD),
                         dtype=torch.float16,
                         device='cuda').normal_(mean=0., std=0.5).contiguous()
        k1 = torch.empty((Z, H, N_CTX, D_HEAD),
                         dtype=torch.float16,
                         device='cuda').normal_(mean=0., std=0.5).contiguous()
        k2 = torch.empty((Z, H, N_CTX, D_HEAD),
                         dtype=torch.float16,
                         device='cuda').normal_(mean=0., std=0.5).contiguous()
        v = torch.empty((Z, H, N_CTX, D_HEAD),
                        dtype=torch.float16,
                        device='cuda').normal_(mean=0., std=0.5).contiguous()
        # q1 = torch.load('/workspace/GitProjects/lmdeploy/q1.pt',
        #                 map_location='cuda').contiguous()
        # q2 = torch.load('/workspace/GitProjects/lmdeploy/q2.pt',
        #                 map_location='cuda').contiguous()
        # k1 = torch.load('/workspace/GitProjects/lmdeploy/k1.pt',
        #                 map_location='cuda').contiguous()
        # k2 = torch.load('/workspace/GitProjects/lmdeploy/k2.pt',
        #                 map_location='cuda').contiguous()
        # v = torch.load('/workspace/GitProjects/lmdeploy/v.pt',
        #                map_location='cuda').contiguous()
        torch_output = torch_attention(q1, q2, k1, k2, v, True, sm_scale,
                                       WINDOW)
        torch_output2 = torch_attention2(q1, q2, k1, k2, v, True, sm_scale,
                                         WINDOW)
        assert torch.allclose(torch_output, torch_output2, atol=1e-2, rtol=0)
        for _ in range(100):
            triton_output = rerope_attention_fwd(q1, q2, k1, k2, v, True,
                                                 sm_scale, WINDOW)
            assert torch.allclose(torch_output, triton_output, atol=2e-2,
                                  rtol=0) is True

        def f(fn, q1, q2, k1, k2, v, sm_scale, window):
            fn(q1, q2, k1, k2, v, True, sm_scale, window)

        t0 = benchmark.Timer(
            stmt='f(fn, q1, q2, k1, k2, v, sm_scale, window)',
            globals={
                'f': f,
                'fn': torch_attention2,
                'q1': q1,
                'q2': q2,
                'k1': k1,
                'k2': k2,
                'v': v,
                'sm_scale': sm_scale,
                'window': WINDOW
            },
            num_threads=torch.get_num_threads())
        print(t0.timeit(20))

        import time
        begin = time.time()
        LOOP = 100
        for _ in range(LOOP):
            rerope_attention_fwd(q1, q2, k1, k2, v, True, sm_scale, WINDOW)
        print(time.time() - begin)

    test_rerope()
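A note on what the reference implementations above are checking: ReRoPE keeps two score matrices and, for every query/key pair, picks one of them depending on whether the relative distance falls inside `window`. Below is a minimal, self-contained sketch of just that selection step, with toy sizes and random numbers standing in for the real attention logits; none of these variable names come from the file itself.

import torch

n_ctx, window = 8, 3
pos = torch.arange(n_ctx)
# |i - j| < window -> use the scores from the first (q1, k1) projection,
# otherwise fall back to the rectified (q2, k2) scores.
inner_mask = (pos[:, None] - pos[None, :]).abs() < window
scores_inner = torch.randn(n_ctx, n_ctx)  # stand-in for q1 @ k1^T * sm_scale
scores_outer = torch.randn(n_ctx, n_ctx)  # stand-in for q2 @ k2^T * sm_scale
scores = torch.where(inner_mask, scores_inner, scores_outer)
probs = torch.softmax(scores, dim=-1)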
lmdeploy/pytorch/kernels/rms_norm.py (new file, mode 100644)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import triton
import triton.language as tl
from torch import Tensor
from triton.runtime.jit import get_cuda_stream


@triton.jit
def rms_norm_kernel(input, weight, output, input_row_stride, n_cols, eps,
                    N_COLS: tl.constexpr, BLOCK_N: tl.constexpr):
    """rms norm kernel."""
    prog_id = tl.program_id(0)
    offsets = tl.arange(0, BLOCK_N)

    w = tl.load(weight + offsets, mask=offsets < n_cols)

    x_ptr = input + prog_id * input_row_stride
    x = tl.load(x_ptr + offsets, mask=offsets < n_cols)
    xf = x.to(tl.float32)

    var = tl.sum(xf * xf, 0) * float(1.0 / N_COLS)
    out = xf / tl.sqrt(var + eps)
    out = (w * out).to(x.dtype)

    out_ptr = output + prog_id * input_row_stride
    tl.store(out_ptr + offsets, out, mask=offsets < n_cols)


@torch.inference_mode()
def rms_norm(hidden_states: Tensor, weight: Tensor, eps: float = 1e-6):
    """rms norm."""

    def _kernel_meta():
        device = hidden_states.device
        device_idx = device.index
        device_type = device.type
        stream = get_cuda_stream(device_idx)
        return dict(device=device, device_type=device_type, stream=stream)

    feat_size = weight.shape[0]
    seq_len = hidden_states.numel() // hidden_states.size(-1)
    input_stride = hidden_states.stride(-2)
    BLOCK_N = triton.next_power_of_2(feat_size)
    out = torch.empty_like(hidden_states)
    kernel_meta = _kernel_meta()
    grid = (seq_len, )
    rms_norm_kernel[grid](hidden_states,
                          weight,
                          out,
                          input_stride,
                          feat_size,
                          eps,
                          feat_size,
                          BLOCK_N,
                          num_warps=4,
                          num_stages=2,
                          **kernel_meta)
    return out


if __name__ == '__main__':
    import time

    def torch_forward(hidden_states, weight, variance_epsilon=1e-6):
        """pytorch forward."""
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance +
                                                    variance_epsilon)
        return weight * hidden_states.to(input_dtype)

    def test_rms_norm(bsz, ctx_len, feat_len, dtype):
        """test rms norm."""
        input = torch.empty((bsz, ctx_len, feat_len),
                            dtype=dtype,
                            device='cuda').normal_(mean=0.,
                                                   std=0.5).contiguous()
        weight = torch.empty((feat_len),
                             dtype=dtype,
                             device='cuda').normal_(mean=0.,
                                                    std=0.5).contiguous()
        triton_output = rms_norm(hidden_states=input, weight=weight)
        torch_output = torch_forward(hidden_states=input, weight=weight)
        assert torch.allclose(torch_output, triton_output, atol=1e-2, rtol=0)

        N_REPEATS = 20
        t0 = time.time()
        for _ in range(N_REPEATS):
            torch_forward(hidden_states=input, weight=weight)
        t1 = time.time()
        for _ in range(N_REPEATS):
            rms_norm(hidden_states=input, weight=weight)
        t2 = time.time()
        torch_cost = (t1 - t0) / N_REPEATS * 1000
        triton_cost = (t2 - t1) / N_REPEATS * 1000
        print('input {} weight {} dtype {}\n'
              'torch {:.3f} triton {:.3f} (ms)\n'.format(
                  input.shape, weight.shape, dtype, torch_cost, triton_cost))

    test_rms_norm(1, 8128, 5120, torch.float16)
    test_rms_norm(1, 8128, 5120, torch.float32)
    test_rms_norm(1, 992, 128, torch.float16)
    test_rms_norm(1, 65537, 128, torch.float32)
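For orientation, `rms_norm` above launches one Triton program per row and pads the feature dimension up to the next power of two, so it can be used as a drop-in replacement for the eager-mode formula checked in `torch_forward`. A minimal usage sketch, assuming a CUDA device and illustrative shapes:

import torch

from lmdeploy.pytorch.kernels.rms_norm import rms_norm

hidden = torch.randn(2, 16, 4096, dtype=torch.float16, device='cuda')
weight = torch.ones(4096, dtype=torch.float16, device='cuda')
out = rms_norm(hidden, weight, eps=1e-6)  # same shape and dtype as `hidden`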
lmdeploy/pytorch/kernels/w8a8_triton_kernels.py (new file, mode 100644)
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
import triton
import triton.language as tl


def per_channel_quant(x, n_bits, dtype):
    """Quantize the input tensor 'x' channel-wise using the given number of
    bits.

    Args:
        x (torch.Tensor): The input tensor to be quantized. Must be a
            2-dimensional tensor.
        n_bits (int): The number of bits to use for quantization.
        dtype (torch.dtype): The data type to which the quantized tensor
            should be converted.

    Returns:
        tuple: A tuple containing two items -- the quantized tensor and
            the scale used for quantization.
    """
    assert x.ndim == 2
    x = x.to(torch.float32)
    x_absmax = x.view(x.shape[0], -1).abs().max(dim=1, keepdim=True)[0]
    q_max = 2**(n_bits - 1) - 1
    q_min = -2**(n_bits - 1)
    scale = x_absmax / (2**(n_bits - 1) - 1)
    x_q = torch.round(x / scale).clamp(q_min, q_max).to(dtype)
    return x_q, scale


@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K': 256},
                      num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'BLOCK_K': 128},
                      num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 128},
                      num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 128},
                      num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128},
                      num_stages=4, num_warps=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def _linear(A, B, C, M, N, K,
            stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
            BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
            BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
            rms_scale_ptr, linear_scale_ptr):
    """Triton-accelerated function used to perform linear operations (dot
    product) on input tensors `A` and `B`, and store the result in output
    tensor `C`.

    The function applies auto-tuning for optimal performance and uses
    Just-in-Time compilation.
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c = accumulator.to(tl.float32)

    rms_scale = tl.load(rms_scale_ptr + offs_am)[:, None]
    linear_scale = tl.load(linear_scale_ptr + offs_bn)[None, :]
    c = c * rms_scale * linear_scale

    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 16, 'BLOCK_N': 128, 'BLOCK_K': 256},
                      num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 32, 'BLOCK_N': 64, 'BLOCK_K': 128},
                      num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 128},
                      num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 128},
                      num_stages=4, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 128},
                      num_stages=4, num_warps=4),
    ],
    key=['M', 'N', 'K'],
)
@triton.jit
def _linear_add(A, B, C, residual_ptr, M, N, K,
                stride_am, stride_ak, stride_bk, stride_bn,
                stride_cm, stride_cn,
                BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
                rms_scale_ptr, linear_scale_ptr):
    """Triton-accelerated function used to perform a linear operation (dot
    product) on input tensors `A` and `B`, with addition of residual.

    The result is stored in tensor `C`. The function applies auto-tuning for
    optimal performance and uses Just-in-Time compilation.
    """
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
    offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0)
        accumulator += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c = accumulator.to(tl.float32)

    rms_scale = tl.load(rms_scale_ptr + offs_am)[:, None]
    linear_scale = tl.load(linear_scale_ptr + offs_bn)[None, :]
    c = c * rms_scale * linear_scale
    c = c.to(residual_ptr.dtype.element_ty)

    offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    residual_ptrs = (residual_ptr + stride_cm * offs_cm[:, None] +
                     stride_cn * offs_cn[None, :])
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    residual = tl.load(residual_ptrs, mask=c_mask, other=0.)
    c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    tl.store(c_ptrs, c + residual, mask=c_mask)


def matmul_kernel_dynamic_quant(a,
                                b,
                                rms_scale,
                                linear_scale,
                                residual=None,
                                bias=None,
                                output_dtype=torch.float16):
    """This function performs matrix multiplication with dynamic quantization.

    It takes two input tensors `a` and `b`, scales them with `rms_scale` and
    `linear_scale`, and optionally adds a `residual` tensor and a `bias`. The
    output is returned in the specified `output_dtype`.
    """
    assert a.shape[-1] == b.shape[-1]
    assert b.ndim == 2 and b.is_contiguous()
    b = b.t()  # (K, N)
    M = a.numel() // a.shape[-1]
    K, N = b.shape
    c_shape = a.shape[:-1] + (N, )
    if residual is not None:
        assert residual.shape == c_shape
        assert residual.is_contiguous()
    c = torch.empty(c_shape, device=a.device, dtype=output_dtype)

    def grid(META):
        return (triton.cdiv(M, META['BLOCK_M']) *
                triton.cdiv(N, META['BLOCK_N']), )

    if residual is not None:
        _linear_add[grid](a, b, c, residual, M, N, K,
                          a.stride(-2), a.stride(-1),
                          b.stride(0), b.stride(1),
                          c.stride(-2), c.stride(-1),
                          GROUP_SIZE_M=8,
                          rms_scale_ptr=rms_scale,
                          linear_scale_ptr=linear_scale)
    else:
        _linear[grid](a, b, c, M, N, K,
                      a.stride(-2), a.stride(-1),
                      b.stride(0), b.stride(1),
                      c.stride(-2), c.stride(-1),
                      GROUP_SIZE_M=8,
                      rms_scale_ptr=rms_scale,
                      linear_scale_ptr=linear_scale)
    if bias is not None:
        c += bias

    return c
@triton.jit
def _per_token_quant_int8(
        y_ptr,
        y_q_ptr,
        y_s_ptr,
        y_stride,
        N,  # number of columns in X
        eps,  # epsilon to avoid division by zero
        BLOCK: tl.constexpr):
    """A Triton-accelerated function to perform per-token quantization on a
    tensor.

    This function converts the tensor values into signed 8-bit integers.
    """
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    y_ptr += row * y_stride
    y_q_ptr += row * y_stride
    y_s_ptr += row

    cols = tl.arange(0, BLOCK)  # N <= BLOCK
    mask = cols < N

    y = tl.load(y_ptr + cols, mask=mask, other=0.).to(tl.float32)
    # Quant
    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
    y_s = _absmax / 127
    y_q = tl.maximum(tl.minimum(tl.math.round(y / y_s), 127),
                     -128).to(tl.int8)

    tl.store(y_q_ptr + cols, y_q, mask=mask)
    tl.store(y_s_ptr, y_s)


def per_token_quant_int8(x, eps):
    """Function to perform per-token quantization on an input tensor `x`.

    It converts the tensor values into signed 8-bit integers and returns the
    quantized tensor along with the scaling factor used for quantization.
    """
    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
    M = x.numel() // x.shape[-1]
    N = x.shape[-1]
    x_s = torch.empty(x.shape[:-1] + (1, ),
                      device=x.device,
                      dtype=torch.float32)
    BLOCK = triton.next_power_of_2(N)
    # heuristics for number of warps
    num_warps = min(max(BLOCK // 256, 1), 8)
    # enqueue kernel
    _per_token_quant_int8[(M, )](x,
                                 x_q,
                                 x_s,
                                 x.stride(-2),
                                 N,
                                 eps,
                                 BLOCK=BLOCK,
                                 num_warps=num_warps)
    return x_q, x_s


@triton.jit
def _rms_norm_fwd_fused_dynamic_symmetric(
        X,  # pointer to the input
        Y,  # pointer to the output
        W,  # pointer to the weights
        Scale,  # pointer to the scales of the output activation
        stride,  # how much to increase the pointer when moving by 1 row
        N,  # number of columns in X
        eps,  # epsilon to avoid division by zero
        BLOCK_SIZE: tl.constexpr):
    """A Triton kernel that calculates Root Mean Square (RMS) normalization
    with fused dynamic symmetric quantization."""
    row = tl.program_id(0)
    Y += row * stride
    X += row * stride

    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    cols = tl.arange(0, BLOCK_SIZE)
    x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)
    _var += x * x
    var = tl.sum(_var, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)

    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    w = tl.load(W + cols, mask=mask)
    x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
    x_hat = x * rstd
    y = x_hat * w

    scale = tl.max(tl.abs(y)).to(tl.float32) / 127
    tl.store(Scale + row, scale)
    y = tl.math.round(y / scale)
    y = tl.minimum(y, 127)
    y = tl.maximum(y, -128)
    tl.store(Y + cols, y, mask=mask)


def rms_norm_dynamic_quant(x, w, eps):
    """Performs RMS normalization with dynamic quantization.

    The function reshapes the input tensor `x`, creates an empty tensor `y`
    with the same shape as `x`, and calculates RMS normalization on the
    reshaped `x` using a Triton kernel
    `_rms_norm_fwd_fused_dynamic_symmetric`.
    """
    x_arg = x.reshape(-1, x.shape[-1])
    y = torch.empty_like(x, dtype=torch.int8)
    M, K = x_arg.shape
    MAX_FUSED_SIZE = 65536 // x.element_size()
    BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(K))
    if K > BLOCK_SIZE:
        raise RuntimeError(
            "This rms norm doesn't support feature dim >= 64KB.")
    num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
    scale = torch.empty(x.shape[:-1] + (1, ),
                        dtype=torch.float32,
                        device=x.device)
    _rms_norm_fwd_fused_dynamic_symmetric[(M, )](x_arg,
                                                 y,
                                                 w,
                                                 scale,
                                                 x_arg.stride(0),
                                                 K,
                                                 eps,
                                                 BLOCK_SIZE=BLOCK_SIZE,
                                                 num_warps=num_warps)
    return y, scale
def test_rms_and_linear(x,
                        rms_weight,
                        linear_weight,
                        dtype=torch.float16,
                        eps=1e-5):
    """Test quantized rms norm and quantized linear layer."""

    def rms_norm_torch(x, w, eps):
        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + eps)
        return w * x

    def linear_torch(x, b):
        return F.linear(x, b)

    linear_weight_quant, linear_scale = per_channel_quant(
        linear_weight, 8, torch.int8)

    rms_out, rms_scale = rms_norm_dynamic_quant(x, rms_weight, eps)
    assert rms_out.shape == x.shape and rms_scale.shape[:-1] == x.shape[:-1]
    linear_out = matmul_kernel_dynamic_quant(rms_out,
                                             linear_weight_quant,
                                             rms_scale,
                                             linear_scale,
                                             output_dtype=dtype)

    rms_out_torch = rms_norm_torch(x, rms_weight, eps).half()
    linear_out_torch = linear_torch(rms_out_torch, linear_weight)
    print(f'linear_out.abs().mean() = {linear_out.abs().mean()}')
    print(f'linear_out_torch.abs().mean() = {linear_out_torch.abs().mean()}')
    print('perchannel error: ', (linear_out - linear_out_torch).abs().mean())
    cos = torch.nn.CosineSimilarity(0)
    print(
        'Output cos',
        cos(linear_out.flatten().to(torch.float32),
            linear_out_torch.flatten().to(torch.float32)))


def test_per_token_quant(x, eps):
    """Test per-token quantization."""

    def per_token_quant_int8_torch(x, eps):
        _absmax = torch.clamp(x.abs().max(dim=-1, keepdim=True)[0], min=eps)
        x_s = _absmax / 127
        x_q = torch.clamp((x / x_s).round(), min=-128, max=127)
        return x_q, x_s

    x_q, x_s = per_token_quant_int8(x, eps)
    x_q_torch, x_s_torch = per_token_quant_int8_torch(x, eps)
    assert x_q.shape == x_q_torch.shape and x_s.shape == x_s_torch.shape
    cos = torch.nn.CosineSimilarity(0)
    print(
        'x_q cos',
        cos(x_q.flatten().to(torch.float32),
            x_q_torch.flatten().to(torch.float32)))
    print(
        'x_s cos',
        cos(x_s.flatten().to(torch.float32),
            x_s_torch.flatten().to(torch.float32)))


@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['M'],
        x_vals=[1, 16, 32, 64, 128, 256] +
        [512 * i * 2 for i in range(1, 17)],
        line_arg='provider',
        line_vals=['int8_dynamic_triton_op', 'float_torch'],
        line_names=['int8_dynamic_triton_op', 'float_torch'],
        styles=[('blue', '-'), ('green', '-'), ('orange', '-'),
                ('yellow', '-'), ('yellow', '-')],
        ylabel='GB/s',
        plot_name='forward',
        args={
            'dtype': torch.float16,
        }))
def bench_rms_and_linear(M, dtype, provider, eps=1e-5, device='cuda'):

    def rms_norm_torch(x, w, eps):
        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + eps)
        return w * x

    def linear_torch(x, b):
        return F.linear(x, b)

    N = 4096
    K = 4096
    x_shape = (M, K)
    rms_w_shape = (x_shape[-1], )
    rms_weight = torch.randn(rms_w_shape,
                             dtype=dtype,
                             device='cuda',
                             requires_grad=True)
    x = torch.randn(x_shape, dtype=dtype, device='cuda')
    linear_weight = torch.randn((N, K),
                                dtype=dtype,
                                device='cuda',
                                requires_grad=True)
    linear_weight_quant, linear_scale = per_channel_quant(
        linear_weight, 8, torch.int8)

    alpha = max(x.max().abs(), x.min().abs())
    rms_scale = alpha / 127

    if provider == 'int8_dynamic_triton_op':
        rms_out, rms_scale = rms_norm_dynamic_quant(x, rms_weight, eps)

        def y_fwd():
            matmul_kernel_dynamic_quant(rms_out,
                                        linear_weight_quant,
                                        rms_scale,
                                        linear_scale,
                                        output_dtype=dtype)
    elif provider == 'float_torch':
        rms_out_torch = rms_norm_torch(x, rms_weight, eps).half()

        def y_fwd():
            linear_torch(rms_out_torch, linear_weight)

    quantiles = [0.5, 0.2, 0.8]
    ms, min_ms, max_ms = triton.testing.do_bench(y_fwd,
                                                 quantiles=quantiles,
                                                 rep=500)
    return ms, max_ms, min_ms
if __name__ == '__main__':
    torch.manual_seed(0)
    dtype = torch.float16

    # test (bs, seq_len, dim) x (dim, out_dim)
    x = torch.randn((2, 2048, 4096), dtype=dtype, device='cuda')
    rms_weight = torch.randn((4096, ),
                             dtype=dtype,
                             device='cuda',
                             requires_grad=True)
    linear_weight = torch.randn((11008, 4096),
                                dtype=dtype,
                                device='cuda',
                                requires_grad=True)
    test_rms_and_linear(x, rms_weight, linear_weight)

    # test (M, K) x (K, N)
    x = torch.randn((4, 4096), dtype=dtype, device='cuda')
    rms_weight = torch.randn((4096, ),
                             dtype=dtype,
                             device='cuda',
                             requires_grad=True)
    linear_weight = torch.randn((2048, 4096),
                                dtype=dtype,
                                device='cuda',
                                requires_grad=True)
    test_rms_and_linear(x, rms_weight, linear_weight)

    # test per-token quant
    x = torch.randn((4, 2048, 4096), dtype=dtype, device='cuda')
    eps = 1e-7
    test_per_token_quant(x, eps)

    bench_rms_and_linear.run(print_data=True)
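Taken together, the functions in this file implement the w8a8 path that `test_rms_and_linear` exercises: weights are quantized once per output channel, activations are quantized on the fly inside the fused RMSNorm kernel, and the int8 GEMM rescales its accumulator with both sets of scales. A condensed sketch of that pipeline, with illustrative shapes and a CUDA device assumed:

import torch

from lmdeploy.pytorch.kernels.w8a8_triton_kernels import (
    matmul_kernel_dynamic_quant, per_channel_quant, rms_norm_dynamic_quant)

x = torch.randn(4, 4096, dtype=torch.float16, device='cuda')
rms_w = torch.ones(4096, dtype=torch.float16, device='cuda')
linear_w = torch.randn(11008, 4096, dtype=torch.float16, device='cuda')

w_q, w_scale = per_channel_quant(linear_w, 8, torch.int8)  # static weight quant
x_q, x_scale = rms_norm_dynamic_quant(x, rms_w, 1e-5)      # fused norm + activation quant
y = matmul_kernel_dynamic_quant(x_q, w_q, x_scale, w_scale,
                                output_dtype=torch.float16)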
lmdeploy/pytorch/messages.py (new file, mode 100644)
# Copyright (c) OpenMMLab. All rights reserved.
import enum
import time
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Dict, List

import torch
from torch import Tensor

from lmdeploy.messages import EngineGenerationConfig
from lmdeploy.utils import get_logger

from .block import LogicalTokenBlocks

logger = get_logger('lmdeploy')


@dataclass
class SamplingParam:
    """Sampling parameter."""
    top_p: float = 1.0
    top_k: int = 1
    temperature: float = 0.8
    repetition_penalty: float = 1.0
    ignore_eos: bool = False
    random_seed: int = None
    stop_words: List[int] = field(default_factory=list)
    bad_words: List[int] = field(default_factory=list)
    max_new_tokens: int = 512
    min_new_tokens: int = 0

    def logical_sampling_param(self):
        """create a SamplingParam for logical sampling."""
        return SamplingParam(top_p=self.top_p,
                             top_k=self.top_k,
                             temperature=self.temperature,
                             repetition_penalty=self.repetition_penalty,
                             ignore_eos=self.ignore_eos,
                             random_seed=self.random_seed,
                             bad_words=self.bad_words)

    @classmethod
    def from_gen_config(self, gen_config: EngineGenerationConfig):
        """from gen config."""
        min_new_tokens = gen_config.min_new_tokens or 0
        stop_words = gen_config.stop_words or []
        bad_words = gen_config.bad_words or []
        if gen_config.ignore_eos:
            bad_words += stop_words

        top_k = gen_config.top_k
        top_p = gen_config.top_p
        temperature = gen_config.temperature
        repetition_penalty = gen_config.repetition_penalty
        max_new_tokens = gen_config.max_new_tokens

        if top_k <= 0:
            logger.warning('`top_k` has to be a strictly'
                           f' positive value, but is {top_k}')
            top_k = 1
        if top_p < 0 or top_p > 1.0:
            logger.warning('`top_p` has to be a float > 0 and < 1'
                           f' but is {top_p}')
            top_p = 1.0
        if temperature <= 0:
            logger.warning('`temperature` has to be a strictly'
                           f' positive value, but is {temperature}')
            temperature = 1.0
        if repetition_penalty <= 0:
            logger.warning('`repetition_penalty` has to be a strictly'
                           f' positive value, but is {repetition_penalty}')
            repetition_penalty = 1.0
        if max_new_tokens < 0:
            logger.warning('`max_new_tokens` has to be a strictly'
                           f' positive value, but is {max_new_tokens}')
            max_new_tokens = 512
        if min_new_tokens < 0 or min_new_tokens > max_new_tokens:
            logger.warning('`min_new_tokens` has to be '
                           'a int >=0 and <= `max_new_tokens`,'
                           f' but is {min_new_tokens}')
            min_new_tokens = 0

        return SamplingParam(top_p=top_p,
                             top_k=top_k,
                             temperature=temperature,
                             repetition_penalty=repetition_penalty,
                             ignore_eos=gen_config.ignore_eos,
                             random_seed=gen_config.random_seed,
                             stop_words=stop_words,
                             bad_words=bad_words,
                             max_new_tokens=max_new_tokens,
                             min_new_tokens=min_new_tokens)


class MessageStatus(enum.Enum):
    """Status of a sequence."""
    WAITING = enum.auto()
    RUNNING = enum.auto()
    STOPPED = enum.auto()
    ENDED = enum.auto()
    ABORTED = enum.auto()


_SEQ_COUNT = 0


def _new_msg_id():
    """get a new message id."""
    global _SEQ_COUNT
    seq_id = _SEQ_COUNT
    _SEQ_COUNT += 1
    return seq_id


class SchedulerSession:
    """Scheduler session."""

    def __init__(self, session_id: int, block_size: int) -> None:
        self.session_id = session_id
        self.block_size = block_size
        self.status: MessageStatus = MessageStatus.RUNNING
        self.sequences: Dict[int, SchedulerSequence] = dict()

    def add_sequence(self,
                     token_ids: Tensor,
                     sampling_param: SamplingParam = None,
                     adapter_name: str = None,
                     return_logits: bool = False) -> 'SchedulerSequence':
        """Add a new message."""
        if not isinstance(token_ids, Tensor):
            token_ids = torch.tensor(token_ids)
        if token_ids.dim() == 0:
            token_ids = token_ids.unsqueeze(0)
        if sampling_param is None:
            sampling_param = SamplingParam()

        seq = SchedulerSequence(seq_id=_new_msg_id(),
                                token_ids=token_ids,
                                session=self,
                                block_size=self.block_size,
                                status=MessageStatus.WAITING,
                                num_new_tokens=0,
                                sampling_param=sampling_param,
                                adapter_name=adapter_name,
                                arrive_time=time.time(),
                                return_logits=return_logits)
        self.sequences[seq.seq_id] = seq
        return seq

    def fork_sequence(
            self,
            token_ids: Tensor,
            seq: 'SchedulerSequence',
            sampling_param: SamplingParam = None) -> 'SchedulerSequence':
        """Fork a new message from exist message."""
        if sampling_param is None:
            sampling_param = deepcopy(seq.sampling_param)
        if not isinstance(token_ids, Tensor):
            token_ids = torch.tensor(token_ids)
        if token_ids.dim() == 0:
            token_ids = token_ids.unsqueeze(0)

        assert seq.session == self

        new_msg = SchedulerSequence(
            seq_id=_new_msg_id(),
            token_ids=token_ids,
            session=self,
            block_size=self.block_size,
            history_token_ids=seq.history_token_ids.copy(),
            num_new_tokens=0,
            sampling_param=sampling_param,
            status=seq.status,
            logical_blocks=seq.logical_blocks.clone(),
            adapter_name=seq.adapter_name,
            arrive_time=time.time(),
            meta=deepcopy(seq.meta),
            return_logits=seq.return_logits,
            random_offsets=seq.random_offsets + 1)
        self.sequences[new_msg.seq_id] = new_msg
        return new_msg


@dataclass
class SchedulerSequence:
    """Scheduler message."""
    seq_id: int
    token_ids: Tensor
    session: SchedulerSession
    block_size: int
    history_token_ids: list = field(default_factory=list)
    num_new_tokens: int = 0
    sampling_param: SamplingParam = field(default_factory=SamplingParam)
    status: MessageStatus = MessageStatus.WAITING
    logical_blocks: LogicalTokenBlocks = field(
        default_factory=LogicalTokenBlocks)
    sender_id: int = -1
    req_id: int = -1
    adapter_name: str = None
    arrive_time: float = 0.0
    meta: Any = None
    return_logits: bool = False
    random_offsets: int = 0

    @property
    def history_len(self) -> int:
        """get history length."""
        return len(self.history_token_ids)

    @property
    def session_id(self) -> int:
        """get session id."""
        return self.session.session_id

    def num_all_tokens(self) -> int:
        """num all tokens."""
        return len(self.token_ids) + self.history_len

    def update_token_ids(self, token_ids: Tensor,
                         update_history: bool = True):
        """Update token ids, old token ids will be added to history."""
        if update_history:
            self.history_token_ids += self.token_ids.tolist()
        if not isinstance(token_ids, Tensor):
            token_ids = self.token_ids.new_tensor(token_ids)
        if token_ids.dim() == 0:
            token_ids = token_ids.unsqueeze(0)
        self.token_ids = token_ids
        self.random_offsets += 1
        self.arrive_time = time.time()

    def set_step(self, step: int):
        """set step."""
        assert step <= self.history_len
        history_token_ids = torch.tensor(self.history_token_ids,
                                         dtype=torch.long)
        new_history_ids = self.history_token_ids[:step]
        new_token_ids = torch.cat([history_token_ids[step:], self.token_ids])
        self.history_token_ids = new_history_ids
        self.token_ids = new_token_ids
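For orientation, a `SchedulerSession` owns its sequences and assigns monotonically increasing sequence ids, while sampling settings travel with each sequence as a `SamplingParam`. A minimal sketch of how engine-side code is expected to drive these classes; the ids and token values below are made up for illustration:

import torch

from lmdeploy.pytorch.messages import (MessageStatus, SamplingParam,
                                       SchedulerSession)

session = SchedulerSession(session_id=0, block_size=64)
seq = session.add_sequence(torch.tensor([1, 2, 3, 4]),
                           sampling_param=SamplingParam(top_k=40, top_p=0.8))
assert seq.status == MessageStatus.WAITING
seq.update_token_ids(torch.tensor([5]))  # previous ids move into history
assert seq.history_len == 4 and seq.num_all_tokens() == 5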
lmdeploy/pytorch/modeling/__init__.py (new file, mode 100644)
# Copyright (c) OpenMMLab. All rights reserved.
lmdeploy/pytorch/modeling/convert_to_qmodules.py (new file, mode 100644)
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn

from lmdeploy.pytorch.models import QLinear, QRMSNorm

LAYER_TYPE_MAP = {
    'InternLMForCausalLM': 'InternLMDecoderLayer',
    'InternLM2ForCausalLM': 'InternLM2DecoderLayer',
    'QWenLMHeadModel': 'QWenBlock',
    'BaiChuanForCausalLM': 'DecoderLayer',
    'LlamaForCausalLM': 'LlamaDecoderLayer',
}
NORM_TYPE_MAP = {
    'InternLMForCausalLM': 'InternLMRMSNorm',
    'InternLM2ForCausalLM': 'InternLM2RMSNorm',
    'QWenLMHeadModel': 'RMSNorm',
    'BaiChuanForCausalLM': 'RMSNorm',
    'LlamaForCausalLM': 'LlamaRMSNorm',
}


def convert_decoder_layer(module, norm_type):
    """Converts a given module's child layers from regular Linear or RMSNorm
    to their Quantized versions (QLinear, QRMSNorm).

    The conversion is done in place.
    """
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            new_child = QLinear.from_float(child, initialization=False)
            setattr(module, name, new_child)
        elif type(child).__name__ == norm_type:
            new_child = QRMSNorm.from_float(child, initialization=False)
            setattr(module, name, new_child)
        else:
            convert_decoder_layer(child, norm_type)


def convert(module, layer_type, norm_type):
    """Recursively traverses through given PyTorch module and identifies
    child layers that match the specified layer_type and norm_type for
    conversion to their Quantized counterparts.

    The conversion is done using the `convert_decoder_layer` function.
    """
    for child in module.children():
        if type(child).__name__ == layer_type:
            convert_decoder_layer(child, norm_type)
        else:
            convert(child, layer_type, norm_type)


def convert_to_qmodules(model):
    """Convert all Linear and RMSNorm in the decoder layers of the model into
    their Quantized versions (QLinear, QRMSNorm)."""
    layer_type = LAYER_TYPE_MAP[type(model).__name__]
    norm_type = NORM_TYPE_MAP[type(model).__name__]
    convert(model, layer_type, norm_type)
    return
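The conversion is driven purely by class names, so it only applies to architectures listed in `LAYER_TYPE_MAP` and `NORM_TYPE_MAP`. A hedged usage sketch; the checkpoint path is a placeholder and loading the model with `transformers` is an assumption, not something this file prescribes:

from transformers import AutoModelForCausalLM

from lmdeploy.pytorch.modeling.convert_to_qmodules import convert_to_qmodules

# Any architecture whose class name appears in LAYER_TYPE_MAP works here;
# the checkpoint path is a placeholder.
model = AutoModelForCausalLM.from_pretrained('path/to/llama-checkpoint')
convert_to_qmodules(model)  # in place: Linear -> QLinear, RMSNorm -> QRMSNorm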