jerrrrry / infinilm · Commits

Commit 97870d3e
Authored Jan 14, 2026 by MaYuhang
Parent: de3e6b95

    issue/189: add inference server support to InfiniLM

Showing 9 changed files with 1981 additions and 1 deletion (+1981, −1)
- README.md (+22, −0)
- python/infinilm/__init__.py (+21, −1)
- python/infinilm/llm/__init__.py (+43, −0)
- python/infinilm/llm/cache_manager.py (+268, −0)
- python/infinilm/llm/llm.py (+646, −0)
- python/infinilm/llm/request.py (+231, −0)
- python/infinilm/llm/sampling_params.py (+35, −0)
- python/infinilm/llm/scheduler.py (+248, −0)
- python/infinilm/server/inference_server.py (+467, −0)
README.md

@@ -88,6 +88,28 @@ python scripts/test_ppl.py --model-path MODEL_PATH [--ndev NDEV] [--max-batch MA
```bash
python examples/jiuge.py --nvidia --model_path=/models/9G7B_MHA/ --backend=cpp --tp=4 --batch_size=16
```
- Inference server testing
- Start the inference server
```bash
python python/infinilm/server/inference_server.py [--cpu | --nvidia | --metax | --moore | --iluvatar | --cambricon] --model_path=<path/to/model_dir> --max_tokens=MAX_TOKENS --max_batch_size=MAX_BATCH --tp=NDEV --temperature=TEMP --top_p=TOP_P --top_k=TOP_K --host=HOST --port=PORT
```
- Single-GPU example:
```bash
CUDA_VISIBLE_DEVICES=0 python python/infinilm/server/inference_server.py --nvidia --model_path=/models/9G7B_MHA/ --max_tokens=100 --max_batch_size=32 --tp=1 --temperature=1.0 --top_p=0.8 --top_k=1
```
- Multi-GPU distributed example:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python python/infinilm/server/inference_server.py --nvidia --model_path=/models/9G7B_MHA/ --max_tokens=100 --max_batch_size=32 --tp=4 --temperature=1.0 --top_p=0.8 --top_k=1
```
- Benchmark the inference server:
```bash
python scripts/test_perf.py --verbose
```
- Run inference benchmarks (C-Eval/MMLU)
```bash
...
```
python/infinilm/__init__.py (+21, −1)

```python
from .models import AutoLlamaModel
from . import distributed
from . import cache
from . import llm
from .llm import (
    LLM,
    AsyncLLMEngine,
    SamplingParams,
    RequestOutput,
    TokenOutput,
)

__all__ = [
    "AutoLlamaModel",
    "distributed",
    "cache",
    "llm",
    # LLM classes
    "LLM",
    "AsyncLLMEngine",
    "SamplingParams",
    "RequestOutput",
    "TokenOutput",
]
```
python/infinilm/llm/__init__.py (new file, mode 100644)

```python
"""
InfiniLM Engine - High-performance LLM inference engine with batch generation and streaming support.
"""

from infinilm.llm.sampling_params import SamplingParams
from infinilm.llm.request import (
    RequestStatus,
    FinishReason,
    RequestOutput,
    CompletionOutput,
    TokenOutput,
    InferenceRequest,
)
from infinilm.llm.llm import (
    LLM,
    LLMEngine,
    AsyncLLMEngine,
    EngineConfig,
)
from infinilm.llm.scheduler import Scheduler, SchedulerOutput
from infinilm.llm.cache_manager import BlockManager, Block

__all__ = [
    # Main classes
    "LLM",
    "AsyncLLMEngine",
    "LLMEngine",
    "EngineConfig",
    # Parameters
    "SamplingParams",
    # Request and Output
    "InferenceRequest",
    "RequestOutput",
    "CompletionOutput",
    "TokenOutput",
    "RequestStatus",
    "FinishReason",
    # Internal (for advanced use)
    "Scheduler",
    "SchedulerOutput",
    "BlockManager",
    "Block",
]
```
python/infinilm/llm/cache_manager.py (new file, mode 100644)

```python
"""
KV Cache Manager - Paged Attention block-based cache allocation and management.
"""

from collections import deque
from typing import List, Dict, Set

import xxhash
import numpy as np


class Block:
    """KV Cache Block with reference counting and hash-based reuse support."""

    def __init__(self, block_id: int):
        self.block_id = block_id
        self.ref_count = 0
        self.hash = -1
        self.token_ids: List[int] = []

    def update(self, hash_value: int, token_ids: List[int]) -> None:
        self.hash = hash_value
        self.token_ids = token_ids.copy()

    def reset(self) -> None:
        self.ref_count = 1
        self.hash = -1
        self.token_ids = []

    def free(self) -> None:
        self.ref_count = 0
        self.hash = -1
        self.token_ids = []

    def __repr__(self) -> str:
        return f"Block(id={self.block_id}, ref={self.ref_count}, hash={self.hash})"


class BlockManager:
    """Manages Paged KV Cache allocation with prefix caching support.

    Features:
    - Block allocation/deallocation with reference counting
    - Hash-based prefix caching for token sequence reuse
    - Slot mapping generation for physical-to-logical position mapping
    """

    def __init__(self, num_blocks: int, block_size: int):
        assert (
            num_blocks > 0 and block_size > 0
        ), "num_blocks and block_size must be positive"
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.blocks: List[Block] = [Block(i) for i in range(num_blocks)]
        self.hash_to_block_id: Dict[int, int] = {}
        self.free_block_ids: deque = deque(range(num_blocks))
        self.used_block_ids: Set[int] = set()
        self.req_block_ids: Set[int] = set()

    def reset_req_blocks(self) -> None:
        """Move blocks from prefill stage to used blocks and update hash mappings."""
        for block_id in self.req_block_ids:
            self.used_block_ids.add(block_id)
            block = self.blocks[block_id]
            prefix_hash = block.hash
            self.hash_to_block_id[prefix_hash] = block_id
        self.req_block_ids.clear()

    @classmethod
    def compute_hash(cls, token_ids: List[int], prefix_hash: int = -1) -> int:
        """Compute hash for token sequence with optional prefix chaining."""
        h = xxhash.xxh64()
        if prefix_hash != -1:
            h.update(prefix_hash.to_bytes(8, "little"))
        h.update(np.array(token_ids, dtype=np.int32).tobytes())
        return h.intdigest()

    def _allocate_partial_block(self, block_id: int) -> Block:
        """Allocate an incomplete block and add to used blocks."""
        assert block_id in self.free_block_ids, f"Block {block_id} not in free list"
        block = self.blocks[block_id]
        assert block.ref_count == 0, f"Block {block_id} ref_count not zero"
        block.reset()
        self.free_block_ids.remove(block_id)
        self.used_block_ids.add(block_id)
        return block

    def _allocate_full_block(self, block_id: int) -> Block:
        """Allocate a complete block and add to request blocks."""
        assert block_id in self.free_block_ids, f"Block {block_id} not in free list"
        block = self.blocks[block_id]
        assert block.ref_count == 0, f"Block {block_id} ref_count not zero"
        block.reset()
        self.free_block_ids.remove(block_id)
        self.req_block_ids.add(block_id)
        return block

    def _deallocate_block(self, block_id: int):
        """Deallocate a block and return it to free list."""
        block = self.blocks[block_id]
        assert (
            block.ref_count == 0
        ), f"Block {block_id} ref_count not zero, cannot deallocate"
        if block.hash != -1 and self.hash_to_block_id.get(block.hash) == block_id:
            del self.hash_to_block_id[block.hash]
        block.free()
        self.used_block_ids.remove(block_id)
        self.free_block_ids.append(block_id)

    def can_allocate(self, num_required_blocks: int) -> bool:
        return len(self.free_block_ids) >= num_required_blocks

    def allocate_blocks(
        self, token_ids: List[int], block_table: List[int] = None
    ) -> tuple[List[int], List[int], int]:
        """Allocate cache blocks for new request with prefix caching support.

        Args:
            token_ids: Input token sequence
            block_table: Existing block_table (for decode phase)

        Returns:
            Tuple of (block_table, slot_mapping, num_cached_tokens)
        """
        if block_table is None:
            block_table = []
        num_tokens = len(token_ids)
        num_blocks = (num_tokens + self.block_size - 1) // self.block_size
        slot_mapping = []
        num_cached_tokens = 0
        prefix_hash = -1
        cache_miss = False
        for block_idx in range(num_blocks):
            start_idx = block_idx * self.block_size
            end_idx = min(start_idx + self.block_size, num_tokens)
            block_tokens = token_ids[start_idx:end_idx]
            # Only full blocks can be hashed for reuse
            if len(block_tokens) == self.block_size:
                prefix_hash = self.compute_hash(block_tokens, prefix_hash)
                # Try to reuse existing block
                if not cache_miss:
                    cached_block_id = self.hash_to_block_id.get(prefix_hash, -1)
                    if (
                        cached_block_id != -1
                        and self.blocks[cached_block_id].token_ids == block_tokens
                    ):
                        # Check if all tokens are cached
                        if num_cached_tokens + self.block_size == len(token_ids):
                            cache_miss = True
                        else:
                            # Reuse successful
                            block = self.blocks[cached_block_id]
                            block.ref_count += 1
                            block_table.append(cached_block_id)
                            num_cached_tokens += self.block_size
                            continue
                    else:
                        cache_miss = True
            else:
                prefix_hash = -1
            # Cannot reuse, allocate new block
            if not self.free_block_ids:
                raise RuntimeError("No available cache blocks")
            new_block_id = self.free_block_ids[0]
            if prefix_hash != -1:
                block = self._allocate_full_block(new_block_id)
                block.update(prefix_hash, block_tokens)
            else:
                block = self._allocate_partial_block(new_block_id)
            block_table.append(new_block_id)
            # Generate slot_mapping
            for i in range(len(block_tokens)):
                slot_mapping.append(new_block_id * self.block_size + i)
        return block_table, slot_mapping, num_cached_tokens

    def append_slot(
        self,
        block_table: List[int],
        num_tokens: int,
        total_token_ids: List[int] = None,
    ) -> tuple[List[int], int]:
        """Append slot for decode phase (generate one new token).

        Args:
            block_table: Current block_table
            num_tokens: Current total token count (including newly generated token)
            total_token_ids: All token sequence (for updating block hash)

        Returns:
            Tuple of (block_table, slot_id)
        """
        assert len(block_table) > 0, "block_table cannot be empty"
        assert num_tokens > 0, "num_tokens must be greater than 0"
        if num_tokens % self.block_size == 1:
            # Previous block is full, update its hash for future prefix caching
            last_block_id = block_table[-1]
            last_block = self.blocks[last_block_id]
            # Only update if block's token_ids is empty (avoid duplicate updates)
            if len(last_block.token_ids) == 0:
                block_start_idx = num_tokens - self.block_size - 1
                block_end_idx = num_tokens - 1
                block_tokens = total_token_ids[block_start_idx:block_end_idx]
                # Compute prefix_hash using previous block's hash if available
                if len(block_table) > 1:
                    prev_block = self.blocks[block_table[-2]]
                    prefix_hash = prev_block.hash
                else:
                    prefix_hash = -1
                current_hash = self.compute_hash(block_tokens, prefix_hash)
                last_block.update(current_hash, block_tokens)
                self.hash_to_block_id[current_hash] = last_block_id
            # Need new block
            if not self.free_block_ids:
                if not self.try_free_blocks(1):
                    raise RuntimeError("No available cache blocks")
            new_block_id = self.free_block_ids[0]
            self._allocate_partial_block(new_block_id)
            block_table.append(new_block_id)
        # Calculate slot
        last_block_id = block_table[-1]
        offset = (num_tokens - 1) % self.block_size
        slot_id = last_block_id * self.block_size + offset
        return block_table, slot_id

    def free_blocks(self, block_table: List[int]):
        """Decrease reference count for all blocks. Blocks with ref_count=0 are not
        immediately freed to allow reuse."""
        for block_id in reversed(block_table):
            block = self.blocks[block_id]
            block.ref_count -= 1

    def try_free_blocks(self, num_required: int) -> bool:
        """Try to free blocks with ref_count=0."""
        to_free = [
            bid for bid in self.used_block_ids if self.blocks[bid].ref_count == 0
        ]
        for block_id in to_free:
            self._deallocate_block(block_id)
        return self.can_allocate(num_required)

    def get_num_free_blocks(self) -> int:
        return len(self.free_block_ids)

    def __repr__(self):
        return (
            f"BlockManager(blocks={self.num_blocks}, block_size={self.block_size}, "
            f"free={len(self.free_block_ids)}, used={len(self.used_block_ids)})"
        )
```
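To make the prefix-caching flow above concrete, here is a minimal sketch (not part of the commit) that allocates blocks for one prompt, publishes the full-block hashes via `reset_req_blocks()`, and then observes block reuse for a second request sharing the same 8-token prefix. The token values and sizes are illustrative; it assumes `xxhash` and `numpy` are installed.

```python
from infinilm.llm.cache_manager import BlockManager

mgr = BlockManager(num_blocks=8, block_size=4)

prompt = list(range(10))  # 10 tokens -> 2 full blocks + 1 partial block
table, slots, cached = mgr.allocate_blocks(prompt)
assert cached == 0        # first request: nothing is cached yet
mgr.reset_req_blocks()    # publish full-block hashes after prefill completes

# A second request sharing the first 8 tokens reuses the two full blocks.
table2, slots2, cached2 = mgr.allocate_blocks(prompt[:8] + [99, 100])
assert table2[:2] == table[:2]  # same physical blocks
assert cached2 == 8             # two full blocks served from the prefix cache
```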
python/infinilm/llm/llm.py (new file, mode 100644)

```python
"""
LLM Engine - Main interface for LLM inference.

This module provides:
- LLM class for batch generation (offline use)
- AsyncLLMEngine class for asynchronous streaming (server use)
"""

import time
import uuid
import logging
import threading
from typing import Any, List, Optional, Union, AsyncIterator
from dataclasses import dataclass

import infinicore
from infinilm.llm.request import (
    InferenceRequest,
    RequestOutput,
    TokenOutput,
    FinishReason,
)
from infinilm.llm.sampling_params import SamplingParams
from infinilm.llm.scheduler import Scheduler
from infinilm.distributed import DistConfig
from infinilm.infer_engine import InferEngine
from infinilm.cache.cache import PagedKVCacheConfig
from infinilm.modeling_utils import load_model_state_dict_by_file
from transformers import AutoTokenizer
from tokenizers import decoders as _dec

logger = logging.getLogger(__name__)


@dataclass
class EngineConfig:
    """Configuration for LLM Engine.

    Attributes:
        model_path: Path to the model directory.
        device: Device type string ('cpu', 'cuda', 'mlu', etc.).
        dtype: Data type string ('float16', 'bfloat16', 'float32').
        tensor_parallel_size: Number of devices for tensor parallelism.
        max_batch_size: Maximum batch size for inference.
        max_tokens: Default maximum tokens to generate.
        num_blocks: Number of KV cache blocks.
        block_size: Size of each KV cache block.
        temperature: Default sampling temperature.
        top_p: Default top-p sampling parameter.
        top_k: Default top-k sampling parameter.
    """

    model_path: str
    device: str = "cuda"
    dtype: str = "float16"
    tensor_parallel_size: int = 1
    max_batch_size: int = 16
    max_tokens: int = 4096
    num_blocks: int = 8 * 1024
    block_size: int = 16
    temperature: float = 1.0
    top_p: float = 0.8
    top_k: int = 1


class LLMEngine:
    """Low-level LLM engine that handles inference execution."""

    def __init__(self, config: EngineConfig):
        self.config = config
        # Initialize device and dtype
        self._init_device()
        # Initialize model engine
        self.model_engine = InferEngine(
            model_path=config.model_path,
            device=self.device,
            distributed_config=DistConfig(config.tensor_parallel_size),
        )
        # Load model weights
        load_model_state_dict_by_file(
            self.model_engine, config.model_path, dtype=self.model_engine.config.dtype
        )
        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            config.model_path, trust_remote_code=True
        )
        self._fix_tokenizer_decoder()
        # Initialize KV cache
        cache_config = PagedKVCacheConfig(
            num_blocks=config.num_blocks, block_size=config.block_size
        )
        self.model_engine.reset_cache(cache_config)
        # Initialize scheduler
        self.scheduler = Scheduler(
            max_batch_size=config.max_batch_size,
            num_blocks=config.num_blocks,
            block_size=config.block_size,
        )
        # Get EOS token IDs from model config
        self.eos_token_ids = self.model_engine.config.eos_token_id or []
        if isinstance(self.eos_token_ids, int):
            self.eos_token_ids = [self.eos_token_ids]
        logger.info(
            f"LLMEngine initialized with model at {config.model_path} "
            f"on device {config.device}"
        )

    def _init_device(self):
        """Initialize infinicore device and dtype."""
        supported_devices = ["cpu", "cuda", "mlu", "moore"]
        device_str = self.config.device
        if device_str not in supported_devices:
            raise ValueError(
                f"Unsupported device: '{device_str}'. "
                f"Supported devices: {supported_devices}"
            )
        self.device = infinicore.device(device_str, 0)
        dtype_map = {
            "float32": infinicore.float32,
            "float16": infinicore.float16,
            "bfloat16": infinicore.bfloat16,
        }
        if self.config.dtype not in dtype_map:
            raise ValueError(
                f"Unsupported dtype: '{self.config.dtype}'. "
                f"Supported dtypes: {list(dtype_map.keys())}"
            )
        self.dtype = dtype_map[self.config.dtype]

    def _fix_tokenizer_decoder(self):
        """Fix tokenizer decoder for llama models."""
        if "llama" in self.model_engine.config.model_type.lower():
            backend = getattr(self.tokenizer, "backend_tokenizer", None)
            target = getattr(backend, "_tokenizer", backend)
            norm = getattr(target, "normalizer", None)
            dec = getattr(target, "decoder", None)
            sn = repr(norm)[:800] if norm is not None else ""
            sd = repr(dec)[:800] if dec is not None else ""
            has_prepend = "Prepend" in sn
            has_strip = "Strip" in sd
            if has_prepend and has_strip:
                target.decoder = _dec.Sequence(
                    [
                        _dec.Replace("▁", " "),
                        _dec.ByteFallback(),
                        _dec.Fuse(),
                    ]
                )

    def add_request(self, request: InferenceRequest):
        """Add a request to the scheduler."""
        self.scheduler.add_request(request)

    def step(self) -> List[InferenceRequest]:
        """Run one inference step.

        Returns:
            List of requests that were processed in this step.
        """
        # Schedule requests
        scheduler_output = self.scheduler.schedule()
        if scheduler_output is None or not scheduler_output.scheduled_requests:
            return []
        # Build model inputs
        model_input_dict = scheduler_output.build_model_inputs(
            self.config.temperature, self.config.top_p, self.config.top_k
        )
        model_input = self._prepare_model_input(model_input_dict)
        # Run inference
        sampled_tokens = self.model_engine.forward(**model_input)
        sampled_tokens_list = sampled_tokens.to_numpy().tolist()
        # Update request status
        self._update_requests(
            scheduler_output.is_prefill,
            scheduler_output.scheduled_requests,
            sampled_tokens_list,
        )
        return scheduler_output.scheduled_requests

    def _prepare_model_input(self, model_input_dict: dict) -> dict:
        """Convert model input dict to infinicore tensors."""
        model_input = {}
        for key, value in model_input_dict.items():
            if key == "input_ids":
                model_input[key] = infinicore.from_list(
                    [value], dtype=infinicore.int64
                )
            elif key in [
                "position_ids",
                "past_kv_lengths",
                "total_kv_lengths",
                "input_offsets",
                "slot_mapping",
            ]:
                model_input[key] = infinicore.from_list(value, dtype=infinicore.int64)
            elif key == "block_tables":
                model_input[key] = infinicore.from_list(value, dtype=infinicore.int64)
            else:
                model_input[key] = value
        return model_input

    def _update_requests(
        self,
        is_prefill: bool,
        requests: List[InferenceRequest],
        sampled_tokens: List[int],
    ):
        """Update request status after inference step."""
        if is_prefill:
            self.scheduler.cache_manager.reset_req_blocks()
        for req, token_id in zip(requests, sampled_tokens):
            req.generated_token_ids.append(token_id)
            if req.is_prefill:
                req.is_prefill = False
            token_text = self.tokenizer.decode(token_id)
            req.generated_text += token_text
            if self._check_request_finished(req, token_id):
                req.mark_finished(req.finish_reason)
            # Put output in queue if it exists (for async streaming)
            if req._output_queue is not None:
                output = TokenOutput(
                    request_id=req.request_id,
                    token_id=token_id,
                    token_text=token_text,
                    finished=req.is_finished(),
                    finish_reason=req.finish_reason,
                    generated_text=req.generated_text,
                )
                req.output_queue.sync_q.put(output)
        self.scheduler.complete_requests(requests)

    def _check_request_finished(self, req: InferenceRequest, token_id: int) -> bool:
        """Check if request generation is finished."""
        max_tokens = req.sampling_params.max_tokens
        if max_tokens and req.get_num_generated_tokens() >= max_tokens:
            req.finish_reason = FinishReason.LENGTH
            return True
        # Check EOS token
        eos_ids = req.eos_token_ids or self.eos_token_ids
        if eos_ids and token_id in eos_ids:
            req.finish_reason = FinishReason.EOS_TOKEN
            return True
        # Check stop strings
        stop_strings = req.sampling_params.stop or []
        for stop_str in stop_strings:
            if req.generated_text.endswith(stop_str):
                req.finish_reason = FinishReason.STOP_STRING
                return True
        return False

    def tokenize(self, text: str) -> List[int]:
        """Tokenize text to token IDs."""
        return self.tokenizer.encode(text)

    def detokenize(self, token_ids: List[int]) -> str:
        """Detokenize token IDs to text."""
        return self.tokenizer.decode(token_ids)

    def apply_chat_template(
        self,
        messages: List[dict],
        add_generation_prompt: bool = True,
    ) -> str:
        """Apply chat template to messages."""
        return self.tokenizer.apply_chat_template(
            conversation=messages,
            add_generation_prompt=add_generation_prompt,
            tokenize=False,
        )


class LLM:
    """High-level LLM interface for batch generation."""

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: str = "float16",
        tensor_parallel_size: int = 1,
        max_batch_size: int = 16,
        max_tokens: int = 4096,
        num_blocks: int = 8 * 1024,
        block_size: int = 16,
        temperature: float = 1.0,
        top_p: float = 0.8,
        top_k: int = 1,
    ):
        """Initialize LLM.

        Args:
            model_path: Path to the model directory.
            device: Device type ('cpu', 'cuda', 'mlu', 'moore').
            dtype: Data type ('float16', 'bfloat16', 'float32').
            tensor_parallel_size: Number of devices for tensor parallelism.
            max_batch_size: Maximum batch size for inference.
            max_tokens: Default maximum tokens to generate.
            num_blocks: Number of KV cache blocks.
            block_size: Size of each KV cache block.
            temperature: Default sampling temperature.
            top_p: Default top-p sampling parameter.
            top_k: Default top-k sampling parameter.
        """
        config = EngineConfig(
            model_path=model_path,
            device=device,
            dtype=dtype,
            tensor_parallel_size=tensor_parallel_size,
            max_batch_size=max_batch_size,
            max_tokens=max_tokens,
            num_blocks=num_blocks,
            block_size=block_size,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )
        self.engine = LLMEngine(config)
        self.config = config

    def generate(
        self,
        prompts: Union[str, List[str]],
        sampling_params: Optional[SamplingParams] = None,
        use_tqdm: bool = True,
    ) -> List[RequestOutput]:
        """Generate completions for the given prompts.

        Args:
            prompts: A single prompt string or list of prompt strings.
            sampling_params: Sampling parameters for generation.
            use_tqdm: Whether to show progress bar.

        Returns:
            List of RequestOutput objects containing generated text.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        if sampling_params is None:
            sampling_params = SamplingParams(max_tokens=self.config.max_tokens)
        elif sampling_params.max_tokens is None:
            sampling_params = sampling_params.clone()
            sampling_params.max_tokens = self.config.max_tokens
        requests = []
        for prompt in prompts:
            request_id = f"cmpl-{uuid.uuid4().hex}"
            token_ids = self.engine.tokenize(prompt)
            req = InferenceRequest(
                request_id=request_id,
                prompt=prompt,
                prompt_token_ids=token_ids,
                sampling_params=sampling_params,
                eos_token_ids=self.engine.eos_token_ids,
            )
            requests.append(req)
            self.engine.add_request(req)
        # Run inference until all requests are finished
        if use_tqdm:
            try:
                from tqdm import tqdm

                pbar = tqdm(total=len(requests), desc="Generating")
            except ImportError:
                pbar = None
                use_tqdm = False
        else:
            pbar = None
        finished_count = 0
        while finished_count < len(requests):
            self.engine.step()
            new_finished = sum(1 for req in requests if req.is_finished())
            if use_tqdm and pbar and new_finished > finished_count:
                pbar.update(new_finished - finished_count)
            finished_count = new_finished
        if pbar:
            pbar.close()
        outputs = [req.to_request_output() for req in requests]
        return outputs

    def chat(
        self,
        messages: Union[List[dict], List[List[dict]]],
        sampling_params: Optional[SamplingParams] = None,
        use_tqdm: bool = True,
    ) -> List[RequestOutput]:
        """Generate chat completions for the given messages.

        Args:
            messages: A single conversation (list of message dicts) or
                a list of conversations.
            sampling_params: Sampling parameters for generation.
            use_tqdm: Whether to show progress bar.

        Returns:
            List of RequestOutput objects containing generated responses.
        """
        if messages and isinstance(messages[0], dict):
            messages = [messages]
        prompts = []
        for conversation in messages:
            prompt = self.engine.apply_chat_template(
                conversation, add_generation_prompt=True
            )
            prompts.append(prompt)
        return self.generate(prompts, sampling_params, use_tqdm)


class AsyncLLMEngine:
    """Asynchronous LLM engine for server use with streaming support."""

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: str = "float16",
        tensor_parallel_size: int = 1,
        max_batch_size: int = 16,
        max_tokens: int = 512,
        num_blocks: int = 8 * 1024,
        block_size: int = 16,
        temperature: float = 1.0,
        top_p: float = 0.8,
        top_k: int = 1,
    ):
        """Initialize AsyncLLMEngine.

        Args:
            model_path: Path to the model directory.
            device: Device type ('cpu', 'cuda', 'mlu', 'moore').
            dtype: Data type ('float16', 'bfloat16', 'float32').
            tensor_parallel_size: Number of devices for tensor parallelism.
            max_batch_size: Maximum batch size for inference.
            max_tokens: Default maximum tokens to generate.
            num_blocks: Number of KV cache blocks.
            block_size: Size of each KV cache block.
            temperature: Default sampling temperature.
            top_p: Default top-p sampling parameter.
            top_k: Default top-k sampling parameter.
        """
        config = EngineConfig(
            model_path=model_path,
            device=device,
            dtype=dtype,
            tensor_parallel_size=tensor_parallel_size,
            max_batch_size=max_batch_size,
            max_tokens=max_tokens,
            num_blocks=num_blocks,
            block_size=block_size,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
        )
        self.engine = LLMEngine(config)
        self.config = config
        self._running = False
        self._step_thread: Optional[threading.Thread] = None

    def start(self):
        """Start the background inference loop."""
        if self._running:
            logger.warning("AsyncLLMEngine is already running")
            return
        self._running = True
        self._step_thread = threading.Thread(
            target=self._step_loop, daemon=True, name="AsyncLLMEngineStepThread"
        )
        self._step_thread.start()
        logger.info("AsyncLLMEngine started")

    def stop(self):
        """Stop the background inference loop."""
        if not self._running:
            logger.warning("AsyncLLMEngine is not running")
            return
        self._running = False
        if self._step_thread:
            self._step_thread.join(timeout=5)
        logger.info("AsyncLLMEngine stopped")

    def _step_loop(self):
        """Background loop that runs inference steps."""
        while self._running:
            try:
                requests = self.engine.step()
                if not requests:
                    time.sleep(0.01)
            except Exception as e:
                logger.error(f"Error in step loop: {e}", exc_info=True)
                self._running = False
                break

    def add_request(
        self,
        prompt: Optional[str] = None,
        prompt_token_ids: Optional[List[int]] = None,
        sampling_params: Optional[SamplingParams] = None,
        request_id: Optional[str] = None,
        # For server use
        request_data: Optional[dict] = None,
        http_request: Optional[Any] = None,
    ) -> InferenceRequest:
        """Add a request to the engine.

        Args:
            prompt: Text prompt for generation.
            prompt_token_ids: Pre-tokenized prompt.
            sampling_params: Sampling parameters.
            request_id: Optional request ID.
            request_data: Optional request data dict (for server use).
            http_request: Optional HTTP request object (for server use).

        Returns:
            The created InferenceRequest object.
        """
        if request_id is None:
            request_id = f"cmpl-{uuid.uuid4().hex}"
        if prompt_token_ids is None and prompt is not None:
            prompt_token_ids = self.engine.tokenize(prompt)
        if sampling_params is None:
            sampling_params = SamplingParams(max_tokens=self.config.max_tokens)
        elif sampling_params.max_tokens is None:
            sampling_params = sampling_params.clone()
            sampling_params.max_tokens = self.config.max_tokens
        request = InferenceRequest(
            request_id=request_id,
            prompt=prompt,
            prompt_token_ids=prompt_token_ids,
            sampling_params=sampling_params,
            eos_token_ids=self.engine.eos_token_ids,
            request_data=request_data,
            http_request=http_request,
        )
        # Initialize output queue for streaming
        _ = request.output_queue
        self.engine.add_request(request)
        return request

    def add_chat_request(
        self,
        messages: List[dict],
        sampling_params: Optional[SamplingParams] = None,
        request_id: Optional[str] = None,
        request_data: Optional[dict] = None,
        http_request: Optional[Any] = None,
    ) -> InferenceRequest:
        """Add a chat request to the engine.

        Args:
            messages: List of message dicts (chat conversation).
            sampling_params: Sampling parameters.
            request_id: Optional request ID.
            request_data: Optional request data dict.
            http_request: Optional HTTP request object.

        Returns:
            The created InferenceRequest object.
        """
        prompt = self.engine.apply_chat_template(messages, add_generation_prompt=True)
        return self.add_request(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=request_id,
            request_data=request_data,
            http_request=http_request,
        )

    async def stream_request(
        self,
        request: InferenceRequest,
        timeout: float = 100.0,
    ) -> AsyncIterator[TokenOutput]:
        """Stream tokens from a request.

        Args:
            request: The inference request to stream from.
            timeout: Timeout for waiting on each token.

        Yields:
            TokenOutput objects for each generated token.
        """
        import asyncio

        while True:
            if request.is_finished() and request.output_queue.async_q.empty():
                break
            try:
                token_output = await asyncio.wait_for(
                    request.output_queue.async_q.get(), timeout=timeout
                )
                request.output_queue.async_q.task_done()
                yield token_output
                if token_output.finished:
                    break
            except asyncio.TimeoutError:
                if request.is_finished():
                    break
                continue
            except asyncio.CancelledError:
                request.mark_canceled()
                break
            except Exception as e:
                logger.error(f"Error streaming request {request.request_id}: {e}")
                await asyncio.sleep(0.01)
```
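A minimal offline-generation sketch using the `LLM` class defined above. The model path is a placeholder, and a supported device with a compatible model layout is assumed; `AsyncLLMEngine` usage is covered by the server example further down.

```python
from infinilm.llm import LLM, SamplingParams

# Model path and device are placeholders; any supported device works.
llm = LLM(model_path="/models/9G7B_MHA/", device="cuda", tensor_parallel_size=1)
params = SamplingParams(temperature=1.0, top_p=0.8, top_k=1, max_tokens=64)

outputs = llm.generate(["What is paged attention?"], sampling_params=params)
for out in outputs:
    # Each RequestOutput carries one CompletionOutput at index 0.
    print(out.outputs[0].text, out.finish_reason)
```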
python/infinilm/llm/request.py (new file, mode 100644)

```python
"""
Request and Output - Data structures for inference requests and outputs.
"""

from enum import Enum
from dataclasses import dataclass, field
from typing import List, Optional, Any
import time

import janus

from infinilm.llm.sampling_params import SamplingParams


class RequestStatus(Enum):
    """Status of an inference request."""

    WAITING = "waiting"
    RUNNING = "running"
    FINISHED = "finished"
    CANCELED = "canceled"
    FAILED = "failed"
    TIMEOUT = "timeout"


class FinishReason(Enum):
    """Reason for finishing generation."""

    STOP = "stop"
    LENGTH = "length"
    EOS_TOKEN = "eos_token"
    STOP_STRING = "stop_string"
    TIMEOUT = "timeout"
    CANCELED = "canceled"
    ERROR = "error"


@dataclass
class RequestOutput:
    """Output from a single generation request.

    Attributes:
        request_id: Unique identifier for the request.
        prompt: Original prompt text.
        prompt_token_ids: Token IDs of the prompt.
        outputs: List of generated outputs (for beam search, multiple outputs possible).
        finished: Whether generation is complete.
        finish_reason: Reason for finishing.
    """

    request_id: str
    prompt: Optional[str] = None
    prompt_token_ids: Optional[List[int]] = None
    outputs: List["CompletionOutput"] = field(default_factory=list)
    finished: bool = False
    finish_reason: Optional[FinishReason] = None


@dataclass
class CompletionOutput:
    """Single completion output.

    Attributes:
        index: Index of this output (for beam search).
        text: Generated text.
        token_ids: Generated token IDs.
        finish_reason: Reason for finishing.
    """

    index: int = 0
    text: str = ""
    token_ids: List[int] = field(default_factory=list)
    finish_reason: Optional[FinishReason] = None


@dataclass
class TokenOutput:
    """Output for a single generated token.

    Attributes:
        request_id: Unique identifier for the request.
        token_id: Generated token ID.
        token_text: Decoded text of the token.
        finished: Whether generation is complete.
        finish_reason: Reason for finishing.
        generated_text: Full generated text so far.
    """

    request_id: str
    token_id: int
    token_text: str
    finished: bool = False
    finish_reason: Optional[FinishReason] = None
    generated_text: str = ""


class InferenceRequest:
    """Internal inference request object for managing generation state and resources."""

    def __init__(
        self,
        request_id: str,
        prompt: Optional[str] = None,
        prompt_token_ids: Optional[List[int]] = None,
        sampling_params: Optional[SamplingParams] = None,
        eos_token_ids: Optional[List[int]] = None,
        arrival_time: Optional[float] = None,
        # For server use
        request_data: Optional[dict] = None,
        http_request: Optional[Any] = None,
    ):
        # Request metadata
        self.request_id: str = request_id
        self.prompt: Optional[str] = prompt
        self.prompt_token_ids: List[int] = prompt_token_ids or []
        self.prompt_length: int = len(self.prompt_token_ids)
        self.arrival_time: float = arrival_time or time.time()
        self.finished_time: Optional[float] = None
        # Sampling parameters
        self.sampling_params: SamplingParams = sampling_params or SamplingParams()
        # EOS token IDs (from model config)
        self.eos_token_ids: List[int] = eos_token_ids or []
        # Generation state
        self.generated_token_ids: List[int] = []
        self.generated_text: str = ""
        self.is_prefill: bool = True
        self.status: RequestStatus = RequestStatus.WAITING
        self.finish_reason: Optional[FinishReason] = None
        self.priority: int = 0
        # KV cache management
        self.cache_id: Optional[int] = None
        self.block_table: List[int] = []
        self.slot_mapping: List[int] = []
        self.num_cached_tokens: int = 0
        self.num_blocks: int = 0
        # For server use
        self.request_data: Optional[dict] = request_data
        self.http_request: Optional[Any] = http_request
        # Output management (for async streaming)
        self._output_queue: Optional[janus.Queue] = None

    @property
    def output_queue(self) -> janus.Queue:
        """Lazy initialization of output queue."""
        if self._output_queue is None:
            self._output_queue = janus.Queue()
        return self._output_queue

    def get_prompt_length(self) -> int:
        return self.prompt_length

    def get_input_tokens(self) -> List[int]:
        return self.prompt_token_ids

    def get_num_generated_tokens(self) -> int:
        return len(self.generated_token_ids)

    def get_total_length(self) -> int:
        return self.prompt_length + len(self.generated_token_ids)

    def get_all_token_ids(self) -> List[int]:
        return self.prompt_token_ids + self.generated_token_ids

    def get_num_blocks_required(self, block_size: int) -> int:
        total_tokens = self.get_total_length()
        return (total_tokens + block_size - 1) // block_size

    def get_max_tokens(self) -> Optional[int]:
        return self.sampling_params.max_tokens

    def is_finished(self) -> bool:
        return self.status in [
            RequestStatus.FINISHED,
            RequestStatus.CANCELED,
            RequestStatus.FAILED,
            RequestStatus.TIMEOUT,
        ]

    def mark_finished(self, reason: FinishReason):
        """Mark the request as finished with the given reason."""
        self.status = RequestStatus.FINISHED
        self.finish_reason = reason
        self.finished_time = time.time()

    def mark_failed(self, reason: FinishReason = FinishReason.ERROR):
        """Mark the request as failed."""
        self.status = RequestStatus.FAILED
        self.finish_reason = reason
        self.finished_time = time.time()

    def mark_canceled(self):
        """Mark the request as canceled."""
        self.status = RequestStatus.CANCELED
        self.finish_reason = FinishReason.CANCELED
        self.finished_time = time.time()

    def mark_timeout(self):
        """Mark the request as timed out."""
        self.status = RequestStatus.TIMEOUT
        self.finish_reason = FinishReason.TIMEOUT
        self.finished_time = time.time()

    async def close(self):
        """Close the output queue and clean up resources."""
        if self._output_queue is not None:
            await self._output_queue.async_q.join()
            self._output_queue.close()
            await self._output_queue.wait_closed()

    def to_request_output(self) -> RequestOutput:
        """Convert to RequestOutput for external use."""
        return RequestOutput(
            request_id=self.request_id,
            prompt=self.prompt,
            prompt_token_ids=self.prompt_token_ids,
            outputs=[
                CompletionOutput(
                    index=0,
                    text=self.generated_text,
                    token_ids=self.generated_token_ids.copy(),
                    finish_reason=self.finish_reason,
                )
            ],
            finished=self.is_finished(),
            finish_reason=self.finish_reason,
        )
```
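A small sketch of the request lifecycle: the engine appends tokens and marks the request finished, and `to_request_output()` packages the state for callers. Token IDs and text below are made up for illustration; `janus` must be installed since the module imports it.

```python
from infinilm.llm.request import InferenceRequest, FinishReason

req = InferenceRequest(request_id="demo-1", prompt="Hi", prompt_token_ids=[101, 102])
req.generated_token_ids.extend([7, 8, 9])  # normally appended by the engine
req.generated_text = "three tokens"
req.mark_finished(FinishReason.LENGTH)

out = req.to_request_output()
print(out.finished)              # True
print(out.outputs[0].token_ids)  # [7, 8, 9]
print(req.get_total_length())    # 5 = 2 prompt + 3 generated
```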
python/infinilm/llm/sampling_params.py (new file, mode 100644)

```python
"""
Sampling Parameters - Configuration for text generation sampling.
"""

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SamplingParams:
    """Sampling parameters for text generation."""

    temperature: float = 1.0
    top_p: float = 0.8
    top_k: int = 1
    max_tokens: Optional[int] = None
    stop: Optional[List[str]] = None
    stop_token_ids: Optional[List[int]] = None

    def __post_init__(self):
        if self.stop is None:
            self.stop = []
        if self.stop_token_ids is None:
            self.stop_token_ids = []

    def clone(self) -> "SamplingParams":
        """Create a copy of this SamplingParams instance."""
        return SamplingParams(
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            max_tokens=self.max_tokens,
            stop=self.stop.copy() if self.stop else None,
            stop_token_ids=self.stop_token_ids.copy() if self.stop_token_ids else None,
        )
```
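A quick sketch of `clone()` semantics, which the engines above rely on when filling in a default `max_tokens` without mutating a caller's shared instance (values are illustrative):

```python
from infinilm.llm.sampling_params import SamplingParams

base = SamplingParams(temperature=0.7, stop=["</s>"])
per_request = base.clone()
per_request.max_tokens = 128  # override without touching `base`

assert base.max_tokens is None
assert per_request.stop == ["</s>"] and per_request.stop is not base.stop
```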
python/infinilm/llm/scheduler.py (new file, mode 100644)

```python
"""
Scheduler - Request scheduling and batch management with Paged Attention KV Cache.
"""

import queue
import janus
import logging
from typing import List, Optional

from infinilm.llm.request import RequestStatus, InferenceRequest
from infinilm.llm.cache_manager import BlockManager

logger = logging.getLogger(__name__)


class SchedulerOutput:
    """Scheduler output containing scheduled requests and execution phase info."""

    def __init__(
        self,
        scheduled_requests: List[InferenceRequest],
        is_prefill: bool = False,
    ):
        self.scheduled_requests = scheduled_requests
        self.num_requests = len(scheduled_requests)
        self.is_prefill = is_prefill

    def build_model_inputs(
        self, temperature: float = 1.0, top_p: float = 0.8, top_k: int = 1
    ):
        """Construct model inputs for prefill or decode phase.

        Prefill phase:
        - input_ids: Flattened token list (excluding cached tokens)
        - position_ids: Position IDs for new tokens in complete sequence
        - past_kv_lengths: Number of cached tokens per request
        - total_kv_lengths: Total tokens (cached + new) per request
        - input_offsets: Start position of each request in flattened array
        - block_tables: Padded block_table for each request
        - slot_mapping: Token to slot mappings

        Decode phase:
        - input_ids: Only last generated token per request
        - position_ids: Position of last token in complete sequence
        - past_kv_lengths: Number of cached tokens per request
        - total_kv_lengths: Total sequence length per request
        - input_offsets: Offsets for each request
        - block_tables: Padded block_table for each request
        - slot_mapping: Single slot per request
        """
        if not self.scheduled_requests:
            raise RuntimeError(
                "build_model_inputs called with empty scheduled_requests"
            )
        tokens = []
        seq_lens = []
        seq_offsets = [0]
        block_tables = []
        slot_mapping = []
        cached_lens = []
        position_ids = []
        max_block_table_len = max(
            len(req.block_table) for req in self.scheduled_requests
        )
        current_offset = 0
        for req in self.scheduled_requests:
            num_cached = req.num_cached_tokens
            if self.is_prefill:
                # Prefill phase
                req_tokens = req.get_input_tokens()
                tokens_to_compute = req_tokens[num_cached:]
                tokens.extend(tokens_to_compute)
                seq_len = len(tokens_to_compute)
                seq_lens.append(len(req_tokens))
                current_offset += seq_len
                seq_offsets.append(current_offset)
                slot_mapping.extend(req.slot_mapping)
                cached_lens.append(num_cached)
                position_ids.extend(range(num_cached, num_cached + seq_len))
            else:
                # Decode phase
                last_token = req.generated_token_ids[-1]
                tokens.append(last_token)
                seq_lens.append(req.get_total_length())
                current_offset += 1
                seq_offsets.append(current_offset)
                slot_mapping.extend(req.slot_mapping)
                cached_lens.append(num_cached)
                position_ids.append(req.get_total_length() - 1)
            # Pad block_table to same length
            padded_block_table = req.block_table + [-1] * (
                max_block_table_len - len(req.block_table)
            )
            block_tables.append(padded_block_table)
        return {
            "input_ids": tokens,
            "position_ids": position_ids,
            "past_kv_lengths": cached_lens,
            "total_kv_lengths": seq_lens,
            "input_offsets": seq_offsets,
            "block_tables": block_tables,
            "slot_mapping": slot_mapping,
            "temperature": temperature,
            "top_k": top_k,
            "top_p": top_p,
        }


class Scheduler:
    """Request scheduler with integrated BlockManager for KV cache management.

    Scheduling logic:
    1. Running queue: Check for new blocks needed, update slot_mapping
    2. Waiting queue: Try block reuse (prefix caching), allocate new blocks
    3. Reference counting: Free blocks when requests complete
    """

    def __init__(
        self,
        max_batch_size: int = 16,
        num_blocks: int = 8 * 1024,
        block_size: int = 16,
    ):
        self.waiting_queue = janus.Queue()
        self.running_queue = janus.Queue()
        self.max_batch_size = max_batch_size
        self.cache_manager = BlockManager(num_blocks=num_blocks, block_size=block_size)
        self.block_size = block_size

    def add_request(self, request: InferenceRequest):
        if request is not None:
            request.status = RequestStatus.WAITING
            self.waiting_queue.sync_q.put(request)

    def schedule(self) -> Optional[SchedulerOutput]:
        """Schedule and return batch of requests to execute."""
        scheduled_requests = []
        is_prefill = False
        # Process waiting queue (prefill phase)
        while len(scheduled_requests) < self.max_batch_size:
            try:
                req = self.waiting_queue.sync_q.get_nowait()
            except queue.Empty:
                break
            req_tokens = req.get_input_tokens()
            num_required_blocks = req.get_num_blocks_required(self.block_size)
            if not self.cache_manager.can_allocate(num_required_blocks):
                if not self.cache_manager.try_free_blocks(num_required_blocks):
                    raise RuntimeError("No available cache blocks")
            # Allocate blocks with automatic prefix caching support
            req.block_table, req.slot_mapping, req.num_cached_tokens = (
                self.cache_manager.allocate_blocks(req_tokens, req.block_table)
            )
            req.num_blocks = len(req.block_table)
            req.status = RequestStatus.RUNNING
            scheduled_requests.append(req)
        # Return prefill batch if any waiting requests were scheduled
        if scheduled_requests:
            is_prefill = True
            return SchedulerOutput(
                scheduled_requests=scheduled_requests,
                is_prefill=is_prefill,
            )
        # Process running queue (decode phase)
        while len(scheduled_requests) < self.max_batch_size:
            try:
                req = self.running_queue.sync_q.get_nowait()
            except queue.Empty:
                break
            # Decode phase: allocate slot for newly generated token
            try:
                req.block_table, new_slot = self.cache_manager.append_slot(
                    req.block_table, req.get_total_length(), req.get_all_token_ids()
                )
                req.slot_mapping = [new_slot]
                req.num_blocks = len(req.block_table)
                req.num_cached_tokens = req.get_total_length() - 1
                scheduled_requests.append(req)
            except RuntimeError as e:
                raise RuntimeError("No available cache blocks") from e
        # Return decode batch if any running requests were scheduled
        if scheduled_requests:
            is_prefill = False
            return SchedulerOutput(
                scheduled_requests=scheduled_requests,
                is_prefill=is_prefill,
            )
        return None

    def complete_requests(self, requests: List[InferenceRequest]):
        """Handle completed requests and free their blocks."""
        for req in requests:
            if req.status in [
                RequestStatus.FINISHED,
                RequestStatus.CANCELED,
                RequestStatus.FAILED,
                RequestStatus.TIMEOUT,
            ]:
                if req.block_table:
                    self.cache_manager.free_blocks(req.block_table)
                if req.status == RequestStatus.CANCELED:
                    logger.info(
                        f"Request {req.request_id[:8]}... canceled: {req.finish_reason}"
                    )
                elif req.status == RequestStatus.FAILED:
                    logger.error(
                        f"Request {req.request_id[:8]}... failed: {req.finish_reason}"
                    )
                elif req.status == RequestStatus.TIMEOUT:
                    logger.error(
                        f"Request {req.request_id[:8]}... timed out: {req.finish_reason}"
                    )
            else:
                # Still running, put back in running queue
                self.running_queue.sync_q.put(req)

    def get_cache_stats(self) -> dict:
        """Get cache statistics."""
        return {
            "num_blocks": self.cache_manager.num_blocks,
            "block_size": self.cache_manager.block_size,
            "num_free_blocks": self.cache_manager.get_num_free_blocks(),
            "num_req_blocks": len(self.cache_manager.req_block_ids),
            "num_used_blocks": len(self.cache_manager.used_block_ids),
        }
```
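To illustrate how a prefill batch is flattened into model inputs, here is a minimal sketch of `SchedulerOutput.build_model_inputs` with one hand-filled request (block table, slot mapping, and token IDs are made up; no `Scheduler` or event loop is needed):

```python
from infinilm.llm.request import InferenceRequest
from infinilm.llm.scheduler import SchedulerOutput

# A 5-token prompt spanning two blocks of size 4, with no cached prefix.
req = InferenceRequest(request_id="demo", prompt_token_ids=[1, 2, 3, 4, 5])
req.block_table = [0, 1]
req.slot_mapping = [0, 1, 2, 3, 4]  # token i -> physical slot i here
req.num_cached_tokens = 0

batch = SchedulerOutput([req], is_prefill=True)
inputs = batch.build_model_inputs(temperature=1.0, top_p=0.8, top_k=1)
print(inputs["input_ids"])      # [1, 2, 3, 4, 5]
print(inputs["position_ids"])   # [0, 1, 2, 3, 4]
print(inputs["input_offsets"])  # [0, 5]
print(inputs["block_tables"])   # [[0, 1]] (padded to the batch maximum)
```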
python/infinilm/server/inference_server.py (new file, mode 100644)

```python
"""
Inference Server - HTTP API server for LLM inference.
"""

from contextlib import asynccontextmanager
import sys
import time
import json
import uuid
import argparse
import uvicorn
import logging

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse

from infinilm.llm import AsyncLLMEngine, SamplingParams, FinishReason

logger = logging.getLogger(__name__)

DEFAULT_STREAM_TIMEOUT = 100.0
DEFAULT_REQUEST_TIMEOUT = 1000.0


def chunk_json(id_, content=None, role=None, finish_reason=None):
    """Generate JSON chunk for streaming response."""
    delta = {}
    if content:
        delta["content"] = content
    if role:
        delta["role"] = role
    return {
        "id": id_,
        "object": "chat.completion.chunk",
        "created": int(time.time()),
        "model": "jiuge",
        "system_fingerprint": None,
        "choices": [
            {
                "index": 0,
                "text": content,
                "delta": delta,
                "logprobs": None,
                "finish_reason": finish_reason,
            }
        ],
    }


class InferenceServer:
    """HTTP server for LLM inference."""

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: str = "float16",
        tensor_parallel_size: int = 1,
        max_tokens: int = 4096,
        max_batch_size: int = 16,
        num_blocks: int = 8 * 1024,
        block_size: int = 16,
        temperature: float = 1.0,
        top_p: float = 0.8,
        top_k: int = 1,
        host: str = "0.0.0.0",
        port: int = 8000,
    ):
        """Initialize inference server.

        Args:
            model_path: Path to the model directory.
            device: Device type ('cpu', 'cuda', 'mlu', 'moore').
            dtype: Data type ('float16', 'bfloat16', 'float32').
            tensor_parallel_size: Number of devices for tensor parallelism.
            max_tokens: Default maximum tokens to generate.
            max_batch_size: Maximum batch size for inference.
            num_blocks: Number of KV cache blocks.
            block_size: Size of each KV cache block.
            temperature: Default sampling temperature.
            top_p: Default top-p sampling parameter.
            top_k: Default top-k sampling parameter.
            host: Server host address.
            port: Server port number.
        """
        self.model_path = model_path
        self.device = device
        self.dtype = dtype
        self.tensor_parallel_size = tensor_parallel_size
        self.max_tokens = max_tokens
        self.max_batch_size = max_batch_size
        self.num_blocks = num_blocks
        self.block_size = block_size
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.host = host
        self.port = port
        self.engine: AsyncLLMEngine = None

    def start(self):
        """Start the HTTP server."""
        app = self._create_app()
        logger.info(f"Starting API Server at {self.host}:{self.port}...")
        uvicorn.run(app, host=self.host, port=self.port)
        logger.info("Inference Server stopped")

    def _create_app(self):
        """Create FastAPI application."""

        @asynccontextmanager
        async def lifespan(app: FastAPI):
            self.engine = AsyncLLMEngine(
                model_path=self.model_path,
                device=self.device,
                dtype=self.dtype,
                tensor_parallel_size=self.tensor_parallel_size,
                max_batch_size=self.max_batch_size,
                max_tokens=self.max_tokens,
                num_blocks=self.num_blocks,
                block_size=self.block_size,
                temperature=self.temperature,
                top_p=self.top_p,
                top_k=self.top_k,
            )
            self.engine.start()
            logger.info(f"Engine initialized with model at {self.model_path}")
            yield
            self.engine.stop()

        app = FastAPI(lifespan=lifespan)
        self._register_routes(app)
        return app

    def _register_routes(self, app: FastAPI):
        """Register API routes."""

        @app.post("/chat/completions")
        async def chat_completions(request: Request):
            try:
                data = await request.json()
                logger.debug(f"Received request data: {data}")
            except Exception as e:
                logger.error(f"Failed to parse request JSON: {e}")
                return JSONResponse(content={"error": "Invalid JSON"}, status_code=400)
            if not data.get("messages"):
                if not data.get("prompt"):
                    return JSONResponse(
                        content={"error": "No message provided"}, status_code=400
                    )
                else:
                    data["messages"] = [
                        {"role": "user", "content": data.get("prompt")}
                    ]
            stream = data.get("stream", False)
            request_id = f"cmpl-{uuid.uuid4().hex}"
            if stream:
                return StreamingResponse(
                    self._stream_chat(request_id, data, request),
                    media_type="text/event-stream",
                )
            else:
                response = await self._chat(request_id, data, request)
                if isinstance(response, JSONResponse):
                    return response
                return JSONResponse(content=response)

        @app.get("/health")
        async def health():
            return {"status": "healthy"}

        @app.get("/v1/models")
        async def list_models():
            return {
                "object": "list",
                "data": [
                    {
                        "id": "jiuge",
                        "object": "model",
                        "created": int(time.time()),
                        "owned_by": "infinilm",
                    }
                ],
            }

    def _build_sampling_params(self, data: dict) -> SamplingParams:
        """Build SamplingParams from request data."""
        return SamplingParams(
            temperature=data.get("temperature", self.temperature),
            top_p=data.get("top_p", self.top_p),
            top_k=data.get("top_k", self.top_k),
            max_tokens=data.get("max_tokens", self.max_tokens),
            stop=data.get("stop"),
        )

    async def _stream_chat(self, request_id: str, data: dict, http_request: Request):
        """Handle streaming chat request."""
        req = None
        start_time = time.time()
        try:
            messages = data.get("messages", [])
            sampling_params = self._build_sampling_params(data)
            req = self.engine.add_chat_request(
                messages=messages,
                sampling_params=sampling_params,
                request_id=request_id,
                request_data=data,
                http_request=http_request,
            )
            async for token_output in self.engine.stream_request(
                req, timeout=DEFAULT_STREAM_TIMEOUT
            ):
                # Check timeout
                if time.time() - start_time > DEFAULT_REQUEST_TIMEOUT:
                    logger.warning(
                        f"Request {request_id} timed out after "
                        f"{DEFAULT_REQUEST_TIMEOUT}s"
                    )
                    req.mark_timeout()
                    error_chunk = json.dumps(
                        chunk_json(
                            request_id,
                            content="[Request timeout]",
                            finish_reason="timeout",
                        ),
                        ensure_ascii=False,
                    )
                    yield f"data: {error_chunk}\n\n"
                    break
                # Check client disconnect
                if await http_request.is_disconnected():
                    logger.info(f"Client disconnected for request {request_id}")
                    req.mark_canceled()
                    break
                # Send token
                chunk = json.dumps(
                    chunk_json(request_id, content=token_output.token_text),
                    ensure_ascii=False,
                )
                yield f"data: {chunk}\n\n"
                if token_output.finished:
                    finish_reason = self._convert_finish_reason(
                        token_output.finish_reason
                    )
                    chunk = json.dumps(
                        chunk_json(request_id, finish_reason=finish_reason),
                        ensure_ascii=False,
                    )
                    yield f"data: {chunk}\n\n"
                    break
        except Exception as e:
            logger.error(f"Stream error for {request_id}: {e}", exc_info=True)
            if req:
                req.mark_failed()
            error_chunk = json.dumps(
                chunk_json(
                    request_id, content=f"[Error: {str(e)}]", finish_reason="error"
                ),
                ensure_ascii=False,
            )
            yield f"data: {error_chunk}\n\n"
        finally:
            if req and not req.is_finished():
                req.mark_canceled()
            if req:
                await req.close()
            yield "data: [DONE]\n\n"

    async def _chat(self, request_id: str, data: dict, http_request: Request):
        """Handle non-streaming chat request."""
        req = None
        start_time = time.time()
        try:
            messages = data.get("messages", [])
            sampling_params = self._build_sampling_params(data)
            req = self.engine.add_chat_request(
                messages=messages,
                sampling_params=sampling_params,
                request_id=request_id,
                request_data=data,
                http_request=http_request,
            )
            # Collect all generated tokens
            output_text = ""
            async for token_output in self.engine.stream_request(
                req, timeout=DEFAULT_STREAM_TIMEOUT
            ):
                # Check timeout
                if time.time() - start_time > DEFAULT_REQUEST_TIMEOUT:
                    logger.warning(f"Request {request_id} timed out")
                    req.mark_timeout()
                    break
                # Check client disconnect
                if await http_request.is_disconnected():
                    logger.info(f"Client disconnected for request {request_id}")
                    req.mark_canceled()
                    break
                output_text += token_output.token_text
                if token_output.finished:
                    break
            output_text = output_text.strip()
            finish_reason = self._convert_finish_reason(req.finish_reason)
            response = chunk_json(
                request_id,
                content=output_text,
                role="assistant",
                finish_reason=finish_reason or "stop",
            )
            return response
        except Exception as e:
            logger.error(f"Chat error for {request_id}: {e}", exc_info=True)
            if req:
                req.mark_failed()
            return JSONResponse(content={"error": str(e)}, status_code=500)
        finally:
            if req and not req.is_finished():
                req.mark_canceled()
            if req:
                await req.close()

    def _convert_finish_reason(self, reason: FinishReason) -> str:
        """Convert FinishReason enum to string."""
        if reason is None:
            return None
        if reason in (FinishReason.EOS_TOKEN, FinishReason.STOP_STRING):
            return "stop"
        return reason.value


def setup_logging(log_level: str = "INFO"):
    """Configure logging system with proper formatting and handlers."""
    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    date_format = "%Y-%m-%d %H:%M:%S"
    logging.basicConfig(
        level=getattr(logging, log_level.upper(), logging.INFO),
        format=log_format,
        datefmt=date_format,
        handlers=[
            logging.StreamHandler(sys.stdout),
        ],
        force=True,
    )


def parse_args():
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(description="InfiniLM Inference Server")
    parser.add_argument(
        "--model_path", type=str, required=True, help="Path to model directory"
    )
    parser.add_argument("--tp", type=int, default=1, help="Tensor parallelism degree")
    parser.add_argument(
        "--max_tokens",
        type=int,
        default=512,
        help="Maximum number of tokens to generate",
    )
    parser.add_argument(
        "--max_batch_size", type=int, default=8, help="Maximum batch size"
    )
    parser.add_argument(
        "--num_blocks", type=int, default=8 * 1024, help="Number of blocks for KV cache"
    )
    parser.add_argument(
        "--block_size", type=int, default=16, help="Block size for KV cache"
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float16",
        choices=["float32", "float16", "bfloat16"],
        help="Data type",
    )
    parser.add_argument(
        "--temperature", type=float, default=1.0, help="Sampling temperature"
    )
    parser.add_argument(
        "--top_p", type=float, default=0.8, help="Top-p sampling parameter"
    )
    parser.add_argument(
        "--top_k", type=int, default=1, help="Top-k sampling parameter"
    )
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Server host")
    parser.add_argument("--port", type=int, default=8000, help="Server port")
    parser.add_argument("--cpu", action="store_true", help="Use CPU")
    parser.add_argument("--nvidia", action="store_true", help="Use NVIDIA GPU")
    parser.add_argument("--metax", action="store_true", help="Use MetaX device")
    parser.add_argument("--moore", action="store_true", help="Use Moore device")
    parser.add_argument("--iluvatar", action="store_true", help="Use Iluvatar device")
    parser.add_argument("--cambricon", action="store_true", help="Use Cambricon device")
    parser.add_argument(
        "--log_level",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Logging level",
    )
    return parser.parse_args()


def main():
    args = parse_args()
    setup_logging(args.log_level)
    if args.cpu:
        device = "cpu"
    elif args.nvidia:
        device = "cuda"
    elif args.metax:
        device = "cuda"
    elif args.moore:
        device = "moore"
    elif args.iluvatar:
        device = "cuda"
    elif args.cambricon:
        device = "mlu"
    else:
        print(
            "Usage: python infinilm.server.inference_server [--cpu | --nvidia | --metax | --moore | --iluvatar | --cambricon] "
            "--model_path=<path/to/model_dir> --max_tokens=MAX_TOKENS --max_batch_size=MAX_BATCH_SIZE"
            "\n"
            "Example: python infinilm.server.inference_server --nvidia --model_path=/data/shared/models/9G7B_MHA/ "
            "--max_tokens=100 --max_batch_size=32 --tp=1 --temperature=1.0 --top_p=0.8 --top_k=1"
        )
        sys.exit(1)
    server = InferenceServer(
        model_path=args.model_path,
        device=device,
        dtype=args.dtype,
        tensor_parallel_size=args.tp,
        max_tokens=args.max_tokens,
        max_batch_size=args.max_batch_size,
        num_blocks=args.num_blocks,
        block_size=args.block_size,
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        host=args.host,
        port=args.port,
    )
    server.start()


if __name__ == "__main__":
    main()
```
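A minimal client sketch for the streaming endpoint above, assuming a server running on the default host and port and the `requests` library installed. It parses the server-sent-event framing (`data: {...}\n\n`, terminated by `data: [DONE]`) that `_stream_chat` emits:

```python
import json
import requests

resp = requests.post(
    "http://localhost:8000/chat/completions",  # default host/port from parse_args
    json={
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": True,
        "max_tokens": 32,
    },
    stream=True,
)
for line in resp.iter_lines():
    if not line or not line.startswith(b"data: "):
        continue
    payload = line[len(b"data: "):]
    if payload == b"[DONE]":
        break
    chunk = json.loads(payload)
    # The final chunk carries only a finish_reason, so content may be absent.
    print(chunk["choices"][0]["delta"].get("content") or "", end="", flush=True)
```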