Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
c73ff203
Unverified
Commit
c73ff203
authored
Jan 20, 2026
by
PanZezhong1725
Committed by
GitHub
Jan 20, 2026
Browse files
issue/189: add inference server support to InfiniLM (#190)
parents
de3e6b95
97870d3e
Changes
9
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
1981 additions
and
1 deletion
+1981
-1
README.md
README.md
+22
-0
python/infinilm/__init__.py
python/infinilm/__init__.py
+21
-1
python/infinilm/llm/__init__.py
python/infinilm/llm/__init__.py
+43
-0
python/infinilm/llm/cache_manager.py
python/infinilm/llm/cache_manager.py
+268
-0
python/infinilm/llm/llm.py
python/infinilm/llm/llm.py
+646
-0
python/infinilm/llm/request.py
python/infinilm/llm/request.py
+231
-0
python/infinilm/llm/sampling_params.py
python/infinilm/llm/sampling_params.py
+35
-0
python/infinilm/llm/scheduler.py
python/infinilm/llm/scheduler.py
+248
-0
python/infinilm/server/inference_server.py
python/infinilm/server/inference_server.py
+467
-0
No files found.
README.md
View file @
c73ff203
...
@@ -88,6 +88,28 @@ python scripts/test_ppl.py --model-path MODEL_PATH [--ndev NDEV] [--max-batch MA
...
@@ -88,6 +88,28 @@ python scripts/test_ppl.py --model-path MODEL_PATH [--ndev NDEV] [--max-batch MA
python examples/jiuge.py
--nvidia
--model_path
=
/models/9G7B_MHA/
--backend
=
cpp
--tp
=
4
--batch_size
=
16
python examples/jiuge.py
--nvidia
--model_path
=
/models/9G7B_MHA/
--backend
=
cpp
--tp
=
4
--batch_size
=
16
```
```
- 推理服务测试
- 启动推理服务
```
bash
python python/infinilm/server/inference_server.py [--cpu | --nvidia | --metax | --moore | --iluvatar | --cambricon] --model_path=
<path
/
to
/
model_dir
>
--max_tokens=MAX_TOKENS --max_batch_size=MAX_BATCH --tp=NDEV --temperature=TEMP --top_p=TOP_P --top_k=TOP_K --host=HOST --port=PORT
```
- 单卡示例:
```
bash
CUDA_VISIBLE_DEVICES=0 python python/infinilm/server/inference_server.py --nvidia --model_path=/models/9G7B_MHA/ --max_tokens=100 --max_batch_size=32 --tp=1 --temperature=1.0 --top_p=0.8 --top_k=1
```
- 多卡分布式示例:
```
bash
CUDA_VISIBLE_DEVICES=0,1,2,3 python python/infinilm/server/inference_server.py --nvidia --model_path=/models/9G7B_MHA/ --max_tokens=100 --max_batch_size=32 --tp=4 --temperature=1.0 --top_p=0.8 --top_k=1
```
- 测试推理服务性能:
```
bash
python scripts/test_perf.py --verbose
```
- 运行推理基准测试(C-Eval/MMLU)
- 运行推理基准测试(C-Eval/MMLU)
```
bash
```
bash
...
...
python/infinilm/__init__.py
View file @
c73ff203
from
.models
import
AutoLlamaModel
from
.models
import
AutoLlamaModel
from
.
import
distributed
from
.
import
distributed
from
.
import
cache
from
.
import
cache
from
.
import
llm
__all__
=
[
"AutoLlamaModel"
,
"distributed"
,
"cache"
]
from
.llm
import
(
LLM
,
AsyncLLMEngine
,
SamplingParams
,
RequestOutput
,
TokenOutput
,
)
__all__
=
[
"AutoLlamaModel"
,
"distributed"
,
"cache"
,
"llm"
,
# LLM classes
"LLM"
,
"AsyncLLMEngine"
,
"SamplingParams"
,
"RequestOutput"
,
"TokenOutput"
,
]
python/infinilm/llm/__init__.py
0 → 100644
View file @
c73ff203
"""
InfiniLM Engine - High-performance llm inference engine with batch generation and streaming support.
"""
from
infinilm.llm.sampling_params
import
SamplingParams
from
infinilm.llm.request
import
(
RequestStatus
,
FinishReason
,
RequestOutput
,
CompletionOutput
,
TokenOutput
,
InferenceRequest
,
)
from
infinilm.llm.llm
import
(
LLM
,
LLMEngine
,
AsyncLLMEngine
,
EngineConfig
,
)
from
infinilm.llm.scheduler
import
Scheduler
,
SchedulerOutput
from
infinilm.llm.cache_manager
import
BlockManager
,
Block
__all__
=
[
# Main classes
"LLM"
,
"AsyncLLMEngine"
,
"LLMEngine"
,
"EngineConfig"
,
# Parameters
"SamplingParams"
,
# Request and Output
"InferenceRequest"
,
"RequestOutput"
,
"CompletionOutput"
,
"TokenOutput"
,
"RequestStatus"
,
"FinishReason"
,
# Internal (for advanced use)
"Scheduler"
,
"SchedulerOutput"
,
"BlockManager"
,
"Block"
,
]
python/infinilm/llm/cache_manager.py
0 → 100644
View file @
c73ff203
"""
KV Cache Manager - Paged Attention block-based cache allocation and management.
"""
from
collections
import
deque
from
typing
import
List
,
Dict
,
Set
import
xxhash
import
numpy
as
np
class
Block
:
"""KV Cache Block with reference counting and hash-based reuse support."""
def
__init__
(
self
,
block_id
:
int
):
self
.
block_id
=
block_id
self
.
ref_count
=
0
self
.
hash
=
-
1
self
.
token_ids
:
List
[
int
]
=
[]
def
update
(
self
,
hash_value
:
int
,
token_ids
:
List
[
int
])
->
None
:
self
.
hash
=
hash_value
self
.
token_ids
=
token_ids
.
copy
()
def
reset
(
self
)
->
None
:
self
.
ref_count
=
1
self
.
hash
=
-
1
self
.
token_ids
=
[]
def
free
(
self
)
->
None
:
self
.
ref_count
=
0
self
.
hash
=
-
1
self
.
token_ids
=
[]
def
__repr__
(
self
)
->
str
:
return
f
"Block(id=
{
self
.
block_id
}
, ref=
{
self
.
ref_count
}
, hash=
{
self
.
hash
}
)"
class
BlockManager
:
"""Manages Paged KV Cache allocation with prefix caching support.
Features:
- Block allocation/deallocation with reference counting
- Hash-based prefix caching for token sequence reuse
- Slot mapping generation for physical-to-logical position mapping
"""
def
__init__
(
self
,
num_blocks
:
int
,
block_size
:
int
):
assert
(
num_blocks
>
0
and
block_size
>
0
),
"num_blocks and block_size must be positive"
self
.
num_blocks
=
num_blocks
self
.
block_size
=
block_size
self
.
blocks
:
List
[
Block
]
=
[
Block
(
i
)
for
i
in
range
(
num_blocks
)]
self
.
hash_to_block_id
:
Dict
[
int
,
int
]
=
{}
self
.
free_block_ids
:
deque
=
deque
(
range
(
num_blocks
))
self
.
used_block_ids
:
Set
[
int
]
=
set
()
self
.
req_block_ids
:
Set
[
int
]
=
set
()
def
reset_req_blocks
(
self
)
->
None
:
"""Move blocks from prefill stage to used blocks and update hash mappings."""
for
block_id
in
self
.
req_block_ids
:
self
.
used_block_ids
.
add
(
block_id
)
block
=
self
.
blocks
[
block_id
]
prefix_hash
=
block
.
hash
self
.
hash_to_block_id
[
prefix_hash
]
=
block_id
self
.
req_block_ids
.
clear
()
@
classmethod
def
compute_hash
(
cls
,
token_ids
:
List
[
int
],
prefix_hash
:
int
=
-
1
)
->
int
:
"""Compute hash for token sequence with optional prefix chaining."""
h
=
xxhash
.
xxh64
()
if
prefix_hash
!=
-
1
:
h
.
update
(
prefix_hash
.
to_bytes
(
8
,
"little"
))
h
.
update
(
np
.
array
(
token_ids
,
dtype
=
np
.
int32
).
tobytes
())
return
h
.
intdigest
()
def
_allocate_partial_block
(
self
,
block_id
:
int
)
->
Block
:
"""Allocate an incomplete block and add to used blocks."""
assert
block_id
in
self
.
free_block_ids
,
f
"Block
{
block_id
}
not in free list"
block
=
self
.
blocks
[
block_id
]
assert
block
.
ref_count
==
0
,
f
"Block
{
block_id
}
ref_count not zero"
block
.
reset
()
self
.
free_block_ids
.
remove
(
block_id
)
self
.
used_block_ids
.
add
(
block_id
)
return
block
def
_allocate_full_block
(
self
,
block_id
:
int
)
->
Block
:
"""Allocate a complete block and add to request blocks."""
assert
block_id
in
self
.
free_block_ids
,
f
"Block
{
block_id
}
not in free list"
block
=
self
.
blocks
[
block_id
]
assert
block
.
ref_count
==
0
,
f
"Block
{
block_id
}
ref_count not zero"
block
.
reset
()
self
.
free_block_ids
.
remove
(
block_id
)
self
.
req_block_ids
.
add
(
block_id
)
return
block
def
_deallocate_block
(
self
,
block_id
:
int
):
"""Deallocate a block and return it to free list."""
block
=
self
.
blocks
[
block_id
]
assert
(
block
.
ref_count
==
0
),
f
"Block
{
block_id
}
ref_count not zero, cannot deallocate"
if
block
.
hash
!=
-
1
and
self
.
hash_to_block_id
.
get
(
block
.
hash
)
==
block_id
:
del
self
.
hash_to_block_id
[
block
.
hash
]
block
.
free
()
self
.
used_block_ids
.
remove
(
block_id
)
self
.
free_block_ids
.
append
(
block_id
)
def
can_allocate
(
self
,
num_required_blocks
:
int
)
->
bool
:
return
len
(
self
.
free_block_ids
)
>=
num_required_blocks
def
allocate_blocks
(
self
,
token_ids
:
List
[
int
],
block_table
:
List
[
int
]
=
None
)
->
tuple
[
List
[
int
],
List
[
int
],
int
]:
"""Allocate cache blocks for new request with prefix caching support.
Args:
token_ids: Input token sequence
block_table: Existing block_table (for decode phase)
Returns:
Tuple of (block_table, slot_mapping, num_cached_tokens)
"""
if
block_table
is
None
:
block_table
=
[]
num_tokens
=
len
(
token_ids
)
num_blocks
=
(
num_tokens
+
self
.
block_size
-
1
)
//
self
.
block_size
slot_mapping
=
[]
num_cached_tokens
=
0
prefix_hash
=
-
1
cache_miss
=
False
for
block_idx
in
range
(
num_blocks
):
start_idx
=
block_idx
*
self
.
block_size
end_idx
=
min
(
start_idx
+
self
.
block_size
,
num_tokens
)
block_tokens
=
token_ids
[
start_idx
:
end_idx
]
# Only full blocks can be hashed for reuse
if
len
(
block_tokens
)
==
self
.
block_size
:
prefix_hash
=
self
.
compute_hash
(
block_tokens
,
prefix_hash
)
# Try to reuse existing block
if
not
cache_miss
:
cached_block_id
=
self
.
hash_to_block_id
.
get
(
prefix_hash
,
-
1
)
if
(
cached_block_id
!=
-
1
and
self
.
blocks
[
cached_block_id
].
token_ids
==
block_tokens
):
# Check if all tokens are cached
if
num_cached_tokens
+
self
.
block_size
==
len
(
token_ids
):
cache_miss
=
True
else
:
# Reuse successful
block
=
self
.
blocks
[
cached_block_id
]
block
.
ref_count
+=
1
block_table
.
append
(
cached_block_id
)
num_cached_tokens
+=
self
.
block_size
continue
else
:
cache_miss
=
True
else
:
prefix_hash
=
-
1
# Cannot reuse, allocate new block
if
not
self
.
free_block_ids
:
raise
RuntimeError
(
"No available cache blocks"
)
new_block_id
=
self
.
free_block_ids
[
0
]
if
prefix_hash
!=
-
1
:
block
=
self
.
_allocate_full_block
(
new_block_id
)
block
.
update
(
prefix_hash
,
block_tokens
)
else
:
block
=
self
.
_allocate_partial_block
(
new_block_id
)
block_table
.
append
(
new_block_id
)
# Generate slot_mapping
for
i
in
range
(
len
(
block_tokens
)):
slot_mapping
.
append
(
new_block_id
*
self
.
block_size
+
i
)
return
block_table
,
slot_mapping
,
num_cached_tokens
def
append_slot
(
self
,
block_table
:
List
[
int
],
num_tokens
:
int
,
total_token_ids
:
List
[
int
]
=
None
)
->
tuple
[
List
[
int
],
int
]:
"""Append slot for decode phase (generate one new token).
Args:
block_table: Current block_table
num_tokens: Current total token count (including newly generated token)
total_token_ids: All token sequence (for updating block hash)
Returns:
Tuple of (block_table, slot_id)
"""
assert
len
(
block_table
)
>
0
,
"block_table cannot be empty"
assert
num_tokens
>
0
,
"num_tokens must be greater than 0"
if
num_tokens
%
self
.
block_size
==
1
:
# Previous block is full, update its hash for future prefix caching
last_block_id
=
block_table
[
-
1
]
last_block
=
self
.
blocks
[
last_block_id
]
# Only update if block's token_ids is empty (avoid duplicate updates)
if
len
(
last_block
.
token_ids
)
==
0
:
block_start_idx
=
num_tokens
-
self
.
block_size
-
1
block_end_idx
=
num_tokens
-
1
block_tokens
=
total_token_ids
[
block_start_idx
:
block_end_idx
]
# Compute prefix_hash using previous block's hash if available
if
len
(
block_table
)
>
1
:
prev_block
=
self
.
blocks
[
block_table
[
-
2
]]
prefix_hash
=
prev_block
.
hash
else
:
prefix_hash
=
-
1
current_hash
=
self
.
compute_hash
(
block_tokens
,
prefix_hash
)
last_block
.
update
(
current_hash
,
block_tokens
)
self
.
hash_to_block_id
[
current_hash
]
=
last_block_id
# Need new block
if
not
self
.
free_block_ids
:
if
not
self
.
try_free_blocks
(
1
):
raise
RuntimeError
(
"No available cache blocks"
)
new_block_id
=
self
.
free_block_ids
[
0
]
self
.
_allocate_partial_block
(
new_block_id
)
block_table
.
append
(
new_block_id
)
# Calculate slot
last_block_id
=
block_table
[
-
1
]
offset
=
(
num_tokens
-
1
)
%
self
.
block_size
slot_id
=
last_block_id
*
self
.
block_size
+
offset
return
block_table
,
slot_id
def
free_blocks
(
self
,
block_table
:
List
[
int
]):
"""Decrease reference count for all blocks. Blocks with ref_count=0 are not
immediately freed to allow reuse."""
for
block_id
in
reversed
(
block_table
):
block
=
self
.
blocks
[
block_id
]
block
.
ref_count
-=
1
def
try_free_blocks
(
self
,
num_required
:
int
)
->
bool
:
"""Try to free blocks with ref_count=0."""
to_free
=
[
bid
for
bid
in
self
.
used_block_ids
if
self
.
blocks
[
bid
].
ref_count
==
0
]
for
block_id
in
to_free
:
self
.
_deallocate_block
(
block_id
)
if
self
.
can_allocate
(
num_required
):
return
True
return
self
.
can_allocate
(
num_required
)
def
get_num_free_blocks
(
self
)
->
int
:
return
len
(
self
.
free_block_ids
)
def
__repr__
(
self
):
return
(
f
"BlockManager(blocks=
{
self
.
num_blocks
}
, block_size=
{
self
.
block_size
}
, "
f
"free=
{
len
(
self
.
free_block_ids
)
}
, used=
{
len
(
self
.
used_block_ids
)
}
)"
)
python/infinilm/llm/llm.py
0 → 100644
View file @
c73ff203
This diff is collapsed.
Click to expand it.
python/infinilm/llm/request.py
0 → 100644
View file @
c73ff203
"""
Request and Output - Data structures for inference requests and outputs.
"""
from
enum
import
Enum
from
dataclasses
import
dataclass
,
field
from
typing
import
List
,
Optional
,
Any
import
time
import
janus
from
infinilm.llm.sampling_params
import
SamplingParams
class
RequestStatus
(
Enum
):
"""Status of an inference request."""
WAITING
=
"waiting"
RUNNING
=
"running"
FINISHED
=
"finished"
CANCELED
=
"canceled"
FAILED
=
"failed"
TIMEOUT
=
"timeout"
class
FinishReason
(
Enum
):
"""Reason for finishing generation."""
STOP
=
"stop"
LENGTH
=
"length"
EOS_TOKEN
=
"eos_token"
STOP_STRING
=
"stop_string"
TIMEOUT
=
"timeout"
CANCELED
=
"canceled"
ERROR
=
"error"
@
dataclass
class
RequestOutput
:
"""Output from a single generation request.
Attributes:
request_id: Unique identifier for the request.
prompt: Original prompt text.
prompt_token_ids: Token IDs of the prompt.
outputs: List of generated outputs (for beam search, multiple outputs possible).
finished: Whether generation is complete.
finish_reason: Reason for finishing.
"""
request_id
:
str
prompt
:
Optional
[
str
]
=
None
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
outputs
:
List
[
"CompletionOutput"
]
=
field
(
default_factory
=
list
)
finished
:
bool
=
False
finish_reason
:
Optional
[
FinishReason
]
=
None
@
dataclass
class
CompletionOutput
:
"""Single completion output.
Attributes:
index: Index of this output (for beam search).
text: Generated text.
token_ids: Generated token IDs.
finish_reason: Reason for finishing.
"""
index
:
int
=
0
text
:
str
=
""
token_ids
:
List
[
int
]
=
field
(
default_factory
=
list
)
finish_reason
:
Optional
[
FinishReason
]
=
None
@
dataclass
class
TokenOutput
:
"""Output for a single generated token.
Attributes:
request_id: Unique identifier for the request.
token_id: Generated token ID.
token_text: Decoded text of the token.
finished: Whether generation is complete.
finish_reason: Reason for finishing.
generated_text: Full generated text so far.
"""
request_id
:
str
token_id
:
int
token_text
:
str
finished
:
bool
=
False
finish_reason
:
Optional
[
FinishReason
]
=
None
generated_text
:
str
=
""
class
InferenceRequest
:
"""Internal inference request object for managing generation state and resources."""
def
__init__
(
self
,
request_id
:
str
,
prompt
:
Optional
[
str
]
=
None
,
prompt_token_ids
:
Optional
[
List
[
int
]]
=
None
,
sampling_params
:
Optional
[
SamplingParams
]
=
None
,
eos_token_ids
:
Optional
[
List
[
int
]]
=
None
,
arrival_time
:
Optional
[
float
]
=
None
,
# For server use
request_data
:
Optional
[
dict
]
=
None
,
http_request
:
Optional
[
Any
]
=
None
,
):
# Request metadata
self
.
request_id
:
str
=
request_id
self
.
prompt
:
Optional
[
str
]
=
prompt
self
.
prompt_token_ids
:
List
[
int
]
=
prompt_token_ids
or
[]
self
.
prompt_length
:
int
=
len
(
self
.
prompt_token_ids
)
self
.
arrival_time
:
float
=
arrival_time
or
time
.
time
()
self
.
finished_time
:
Optional
[
float
]
=
None
# Sampling parameters
self
.
sampling_params
:
SamplingParams
=
sampling_params
or
SamplingParams
()
# EOS token IDs (from model config)
self
.
eos_token_ids
:
List
[
int
]
=
eos_token_ids
or
[]
# Generation state
self
.
generated_token_ids
:
List
[
int
]
=
[]
self
.
generated_text
:
str
=
""
self
.
is_prefill
:
bool
=
True
self
.
status
:
RequestStatus
=
RequestStatus
.
WAITING
self
.
finish_reason
:
Optional
[
FinishReason
]
=
None
self
.
priority
:
int
=
0
# KV cache management
self
.
cache_id
:
Optional
[
int
]
=
None
self
.
block_table
:
List
[
int
]
=
[]
self
.
slot_mapping
:
List
[
int
]
=
[]
self
.
num_cached_tokens
:
int
=
0
self
.
num_blocks
:
int
=
0
# For server use
self
.
request_data
:
Optional
[
dict
]
=
request_data
self
.
http_request
:
Optional
[
Any
]
=
http_request
# Output management (for async streaming)
self
.
_output_queue
:
Optional
[
janus
.
Queue
]
=
None
@
property
def
output_queue
(
self
)
->
janus
.
Queue
:
"""Lazy initialization of output queue."""
if
self
.
_output_queue
is
None
:
self
.
_output_queue
=
janus
.
Queue
()
return
self
.
_output_queue
def
get_prompt_length
(
self
)
->
int
:
return
self
.
prompt_length
def
get_input_tokens
(
self
)
->
List
[
int
]:
return
self
.
prompt_token_ids
def
get_num_generated_tokens
(
self
)
->
int
:
return
len
(
self
.
generated_token_ids
)
def
get_total_length
(
self
)
->
int
:
return
self
.
prompt_length
+
len
(
self
.
generated_token_ids
)
def
get_all_token_ids
(
self
)
->
List
[
int
]:
return
self
.
prompt_token_ids
+
self
.
generated_token_ids
def
get_num_blocks_required
(
self
,
block_size
:
int
)
->
int
:
total_tokens
=
self
.
get_total_length
()
return
(
total_tokens
+
block_size
-
1
)
//
block_size
def
get_max_tokens
(
self
)
->
Optional
[
int
]:
return
self
.
sampling_params
.
max_tokens
def
is_finished
(
self
)
->
bool
:
return
self
.
status
in
[
RequestStatus
.
FINISHED
,
RequestStatus
.
CANCELED
,
RequestStatus
.
FAILED
,
RequestStatus
.
TIMEOUT
,
]
def
mark_finished
(
self
,
reason
:
FinishReason
):
"""Mark the request as finished with the given reason."""
self
.
status
=
RequestStatus
.
FINISHED
self
.
finish_reason
=
reason
self
.
finished_time
=
time
.
time
()
def
mark_failed
(
self
,
reason
:
FinishReason
=
FinishReason
.
ERROR
):
"""Mark the request as failed."""
self
.
status
=
RequestStatus
.
FAILED
self
.
finish_reason
=
reason
self
.
finished_time
=
time
.
time
()
def
mark_canceled
(
self
):
"""Mark the request as canceled."""
self
.
status
=
RequestStatus
.
CANCELED
self
.
finish_reason
=
FinishReason
.
CANCELED
self
.
finished_time
=
time
.
time
()
def
mark_timeout
(
self
):
"""Mark the request as timed out."""
self
.
status
=
RequestStatus
.
TIMEOUT
self
.
finish_reason
=
FinishReason
.
TIMEOUT
self
.
finished_time
=
time
.
time
()
async
def
close
(
self
):
"""Close the output queue and clean up resources."""
if
self
.
_output_queue
is
not
None
:
await
self
.
_output_queue
.
async_q
.
join
()
self
.
_output_queue
.
close
()
await
self
.
_output_queue
.
wait_closed
()
def
to_request_output
(
self
)
->
RequestOutput
:
"""Convert to RequestOutput for external use."""
return
RequestOutput
(
request_id
=
self
.
request_id
,
prompt
=
self
.
prompt
,
prompt_token_ids
=
self
.
prompt_token_ids
,
outputs
=
[
CompletionOutput
(
index
=
0
,
text
=
self
.
generated_text
,
token_ids
=
self
.
generated_token_ids
.
copy
(),
finish_reason
=
self
.
finish_reason
,
)
],
finished
=
self
.
is_finished
(),
finish_reason
=
self
.
finish_reason
,
)
python/infinilm/llm/sampling_params.py
0 → 100644
View file @
c73ff203
"""
Sampling Parameters - Configuration for text generation sampling.
"""
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
@
dataclass
class
SamplingParams
:
"""Sampling parameters for text generation."""
temperature
:
float
=
1.0
top_p
:
float
=
0.8
top_k
:
int
=
1
max_tokens
:
Optional
[
int
]
=
None
stop
:
Optional
[
List
[
str
]]
=
None
stop_token_ids
:
Optional
[
List
[
int
]]
=
None
def
__post_init__
(
self
):
if
self
.
stop
is
None
:
self
.
stop
=
[]
if
self
.
stop_token_ids
is
None
:
self
.
stop_token_ids
=
[]
def
clone
(
self
)
->
"SamplingParams"
:
"""Create a copy of this SamplingParams instance."""
return
SamplingParams
(
temperature
=
self
.
temperature
,
top_p
=
self
.
top_p
,
top_k
=
self
.
top_k
,
max_tokens
=
self
.
max_tokens
,
stop
=
self
.
stop
.
copy
()
if
self
.
stop
else
None
,
stop_token_ids
=
self
.
stop_token_ids
.
copy
()
if
self
.
stop_token_ids
else
None
,
)
python/infinilm/llm/scheduler.py
0 → 100644
View file @
c73ff203
"""
Scheduler - Request scheduling and batch management with Paged Attention KV Cache.
"""
import
queue
import
janus
import
logging
from
typing
import
List
,
Optional
from
infinilm.llm.request
import
RequestStatus
,
InferenceRequest
from
infinilm.llm.cache_manager
import
BlockManager
logger
=
logging
.
getLogger
(
__name__
)
class
SchedulerOutput
:
"""Scheduler output containing scheduled requests and execution phase info."""
def
__init__
(
self
,
scheduled_requests
:
List
[
InferenceRequest
],
is_prefill
:
bool
=
False
,
):
self
.
scheduled_requests
=
scheduled_requests
self
.
num_requests
=
len
(
scheduled_requests
)
self
.
is_prefill
=
is_prefill
def
build_model_inputs
(
self
,
temperature
:
float
=
1.0
,
top_p
:
float
=
0.8
,
top_k
:
int
=
1
):
"""Construct model inputs for prefill or decode phase.
Prefill phase:
- input_ids: Flattened token list (excluding cached tokens)
- position_ids: Position IDs for new tokens in complete sequence
- past_kv_lengths: Number of cached tokens per request
- total_kv_lengths: Total tokens (cached + new) per request
- input_offsets: Start position of each request in flattened array
- block_tables: Padded block_table for each request
- slot_mapping: Token to slot mappings
Decode phase:
- input_ids: Only last generated token per request
- position_ids: Position of last token in complete sequence
- past_kv_lengths: Number of cached tokens per request
- total_kv_lengths: Total sequence length per request
- input_offsets: Offsets for each request
- block_tables: Padded block_table for each request
- slot_mapping: Single slot per request
"""
if
not
self
.
scheduled_requests
:
raise
RuntimeError
(
"build_model_inputs called with empty scheduled_requests"
)
tokens
=
[]
seq_lens
=
[]
seq_offsets
=
[
0
]
block_tables
=
[]
slot_mapping
=
[]
cached_lens
=
[]
position_ids
=
[]
max_block_table_len
=
max
(
len
(
req
.
block_table
)
for
req
in
self
.
scheduled_requests
)
current_offset
=
0
for
req
in
self
.
scheduled_requests
:
num_cached
=
req
.
num_cached_tokens
if
self
.
is_prefill
:
# Prefill phase
req_tokens
=
req
.
get_input_tokens
()
tokens_to_compute
=
req_tokens
[
num_cached
:]
tokens
.
extend
(
tokens_to_compute
)
seq_len
=
len
(
tokens_to_compute
)
seq_lens
.
append
(
len
(
req_tokens
))
current_offset
+=
seq_len
seq_offsets
.
append
(
current_offset
)
slot_mapping
.
extend
(
req
.
slot_mapping
)
cached_lens
.
append
(
num_cached
)
position_ids
.
extend
(
range
(
num_cached
,
num_cached
+
seq_len
))
else
:
# Decode phase
last_token
=
req
.
generated_token_ids
[
-
1
]
tokens
.
append
(
last_token
)
seq_lens
.
append
(
req
.
get_total_length
())
current_offset
+=
1
seq_offsets
.
append
(
current_offset
)
slot_mapping
.
extend
(
req
.
slot_mapping
)
cached_lens
.
append
(
num_cached
)
position_ids
.
append
(
req
.
get_total_length
()
-
1
)
# Pad block_table to same length
padded_block_table
=
req
.
block_table
+
[
-
1
]
*
(
max_block_table_len
-
len
(
req
.
block_table
)
)
block_tables
.
append
(
padded_block_table
)
return
{
"input_ids"
:
tokens
,
"position_ids"
:
position_ids
,
"past_kv_lengths"
:
cached_lens
,
"total_kv_lengths"
:
seq_lens
,
"input_offsets"
:
seq_offsets
,
"block_tables"
:
block_tables
,
"slot_mapping"
:
slot_mapping
,
"temperature"
:
temperature
,
"top_k"
:
top_k
,
"top_p"
:
top_p
,
}
class
Scheduler
:
"""Request scheduler with integrated BlockManager for KV cache management.
Scheduling logic:
1. Running queue: Check for new blocks needed, update slot_mapping
2. Waiting queue: Try block reuse (prefix caching), allocate new blocks
3. Reference counting: Free blocks when requests complete
"""
def
__init__
(
self
,
max_batch_size
:
int
=
16
,
num_blocks
:
int
=
8
*
1024
,
block_size
:
int
=
16
,
):
self
.
waiting_queue
=
janus
.
Queue
()
self
.
running_queue
=
janus
.
Queue
()
self
.
max_batch_size
=
max_batch_size
self
.
cache_manager
=
BlockManager
(
num_blocks
=
num_blocks
,
block_size
=
block_size
)
self
.
block_size
=
block_size
def
add_request
(
self
,
request
:
InferenceRequest
):
if
request
is
not
None
:
request
.
status
=
RequestStatus
.
WAITING
self
.
waiting_queue
.
sync_q
.
put
(
request
)
def
schedule
(
self
)
->
Optional
[
SchedulerOutput
]:
"""Schedule and return batch of requests to execute."""
scheduled_requests
=
[]
is_prefill
=
False
# Process Waiting queue (prefill phase)
while
len
(
scheduled_requests
)
<
self
.
max_batch_size
:
try
:
req
=
self
.
waiting_queue
.
sync_q
.
get_nowait
()
except
queue
.
Empty
:
break
req_tokens
=
req
.
get_input_tokens
()
num_required_blocks
=
req
.
get_num_blocks_required
(
self
.
block_size
)
if
not
self
.
cache_manager
.
can_allocate
(
num_required_blocks
):
if
not
self
.
cache_manager
.
try_free_blocks
(
num_required_blocks
):
raise
RuntimeError
(
"No available cache blocks"
)
# Allocate blocks with automatic prefix caching support
req
.
block_table
,
req
.
slot_mapping
,
req
.
num_cached_tokens
=
(
self
.
cache_manager
.
allocate_blocks
(
req_tokens
,
req
.
block_table
)
)
req
.
num_blocks
=
len
(
req
.
block_table
)
req
.
status
=
RequestStatus
.
RUNNING
scheduled_requests
.
append
(
req
)
# Return prefill batch if any waiting requests were scheduled
if
scheduled_requests
:
is_prefill
=
True
return
SchedulerOutput
(
scheduled_requests
=
scheduled_requests
,
is_prefill
=
is_prefill
,
)
# Process Running queue (decode phase)
while
len
(
scheduled_requests
)
<
self
.
max_batch_size
:
try
:
req
=
self
.
running_queue
.
sync_q
.
get_nowait
()
except
queue
.
Empty
:
break
# Decode phase: allocate slot for newly generated token
try
:
req
.
block_table
,
new_slot
=
self
.
cache_manager
.
append_slot
(
req
.
block_table
,
req
.
get_total_length
(),
req
.
get_all_token_ids
()
)
req
.
slot_mapping
=
[
new_slot
]
req
.
num_blocks
=
len
(
req
.
block_table
)
req
.
num_cached_tokens
=
req
.
get_total_length
()
-
1
scheduled_requests
.
append
(
req
)
except
RuntimeError
as
e
:
raise
RuntimeError
(
"No available cache blocks"
)
from
e
# Return decode batch if any running requests were scheduled
if
scheduled_requests
:
is_prefill
=
False
return
SchedulerOutput
(
scheduled_requests
=
scheduled_requests
,
is_prefill
=
is_prefill
,
)
return
None
def
complete_requests
(
self
,
requests
:
List
[
InferenceRequest
]):
"""Handle completed requests and free their blocks."""
for
req
in
requests
:
if
req
.
status
in
[
RequestStatus
.
FINISHED
,
RequestStatus
.
CANCELED
,
RequestStatus
.
FAILED
,
RequestStatus
.
TIMEOUT
,
]:
if
req
.
block_table
:
self
.
cache_manager
.
free_blocks
(
req
.
block_table
)
if
req
.
status
==
RequestStatus
.
CANCELED
:
logger
.
info
(
f
"Request
{
req
.
request_id
[:
8
]
}
... canceled:
{
req
.
finish_reason
}
"
)
elif
req
.
status
==
RequestStatus
.
FAILED
:
logger
.
error
(
f
"Request
{
req
.
request_id
[:
8
]
}
... failed:
{
req
.
finish_reason
}
"
)
elif
req
.
status
==
RequestStatus
.
TIMEOUT
:
logger
.
error
(
f
"Request
{
req
.
request_id
[:
8
]
}
... timed out:
{
req
.
finish_reason
}
"
)
else
:
# Still running, put back in running queue
self
.
running_queue
.
sync_q
.
put
(
req
)
def
get_cache_stats
(
self
)
->
dict
:
"""Get cache statistics."""
return
{
"num_blocks"
:
self
.
cache_manager
.
num_blocks
,
"block_size"
:
self
.
cache_manager
.
block_size
,
"num_free_blocks"
:
self
.
cache_manager
.
get_num_free_blocks
(),
"num_req_blocks"
:
len
(
self
.
cache_manager
.
req_block_ids
),
"num_used_blocks"
:
len
(
self
.
cache_manager
.
used_block_ids
),
}
python/infinilm/server/inference_server.py
0 → 100644
View file @
c73ff203
"""
Inference Server - HTTP API server for LLM inference.
"""
from
contextlib
import
asynccontextmanager
import
sys
import
time
import
json
import
uuid
import
argparse
import
uvicorn
import
logging
from
fastapi
import
FastAPI
,
Request
from
fastapi.responses
import
JSONResponse
,
StreamingResponse
from
infinilm.llm
import
AsyncLLMEngine
,
SamplingParams
,
FinishReason
logger
=
logging
.
getLogger
(
__name__
)
DEFAULT_STREAM_TIMEOUT
=
100.0
DEFAULT_REQUEST_TIMEOUT
=
1000.0
def
chunk_json
(
id_
,
content
=
None
,
role
=
None
,
finish_reason
=
None
):
"""Generate JSON chunk for streaming response."""
delta
=
{}
if
content
:
delta
[
"content"
]
=
content
if
role
:
delta
[
"role"
]
=
role
return
{
"id"
:
id_
,
"object"
:
"chat.completion.chunk"
,
"created"
:
int
(
time
.
time
()),
"model"
:
"jiuge"
,
"system_fingerprint"
:
None
,
"choices"
:
[
{
"index"
:
0
,
"text"
:
content
,
"delta"
:
delta
,
"logprobs"
:
None
,
"finish_reason"
:
finish_reason
,
}
],
}
class
InferenceServer
:
"""HTTP server for LLM inference."""
def
__init__
(
self
,
model_path
:
str
,
device
:
str
=
"cuda"
,
dtype
:
str
=
"float16"
,
tensor_parallel_size
:
int
=
1
,
max_tokens
:
int
=
4096
,
max_batch_size
:
int
=
16
,
num_blocks
:
int
=
8
*
1024
,
block_size
:
int
=
16
,
temperature
:
float
=
1.0
,
top_p
:
float
=
0.8
,
top_k
:
int
=
1
,
host
:
str
=
"0.0.0.0"
,
port
:
int
=
8000
,
):
"""Initialize inference server.
Args:
model_path: Path to the model directory.
device: Device type ('cpu', 'cuda', 'mlu', 'moore').
dtype: Data type ('float16', 'bfloat16', 'float32').
tensor_parallel_size: Number of devices for tensor parallelism.
max_tokens: Default maximum tokens to generate.
max_batch_size: Maximum batch size for inference.
num_blocks: Number of KV cache blocks.
block_size: Size of each KV cache block.
temperature: Default sampling temperature.
top_p: Default top-p sampling parameter.
top_k: Default top-k sampling parameter.
host: Server host address.
port: Server port number.
"""
self
.
model_path
=
model_path
self
.
device
=
device
self
.
dtype
=
dtype
self
.
tensor_parallel_size
=
tensor_parallel_size
self
.
max_tokens
=
max_tokens
self
.
max_batch_size
=
max_batch_size
self
.
num_blocks
=
num_blocks
self
.
block_size
=
block_size
self
.
temperature
=
temperature
self
.
top_p
=
top_p
self
.
top_k
=
top_k
self
.
host
=
host
self
.
port
=
port
self
.
engine
:
AsyncLLMEngine
=
None
def
start
(
self
):
"""Start the HTTP server."""
app
=
self
.
_create_app
()
logger
.
info
(
f
"Starting API Server at
{
self
.
host
}
:
{
self
.
port
}
..."
)
uvicorn
.
run
(
app
,
host
=
self
.
host
,
port
=
self
.
port
)
logger
.
info
(
"Inference Server stopped"
)
def
_create_app
(
self
):
"""Create FastAPI application."""
@
asynccontextmanager
async
def
lifespan
(
app
:
FastAPI
):
self
.
engine
=
AsyncLLMEngine
(
model_path
=
self
.
model_path
,
device
=
self
.
device
,
dtype
=
self
.
dtype
,
tensor_parallel_size
=
self
.
tensor_parallel_size
,
max_batch_size
=
self
.
max_batch_size
,
max_tokens
=
self
.
max_tokens
,
num_blocks
=
self
.
num_blocks
,
block_size
=
self
.
block_size
,
temperature
=
self
.
temperature
,
top_p
=
self
.
top_p
,
top_k
=
self
.
top_k
,
)
self
.
engine
.
start
()
logger
.
info
(
f
"Engine initialized with model at
{
self
.
model_path
}
"
)
yield
self
.
engine
.
stop
()
app
=
FastAPI
(
lifespan
=
lifespan
)
self
.
_register_routes
(
app
)
return
app
def
_register_routes
(
self
,
app
:
FastAPI
):
"""Register API routes."""
@
app
.
post
(
"/chat/completions"
)
async
def
chat_completions
(
request
:
Request
):
try
:
data
=
await
request
.
json
()
logger
.
debug
(
f
"Received request data:
{
data
}
"
)
except
Exception
as
e
:
logger
.
error
(
f
"Failed to parse request JSON:
{
e
}
"
)
return
JSONResponse
(
content
=
{
"error"
:
"Invalid JSON"
},
status_code
=
400
)
if
not
data
.
get
(
"messages"
):
if
not
data
.
get
(
"prompt"
):
return
JSONResponse
(
content
=
{
"error"
:
"No message provided"
},
status_code
=
400
)
else
:
data
[
"messages"
]
=
[{
"role"
:
"user"
,
"content"
:
data
.
get
(
"prompt"
)}]
stream
=
data
.
get
(
"stream"
,
False
)
request_id
=
f
"cmpl-
{
uuid
.
uuid4
().
hex
}
"
if
stream
:
return
StreamingResponse
(
self
.
_stream_chat
(
request_id
,
data
,
request
),
media_type
=
"text/event-stream"
,
)
else
:
response
=
await
self
.
_chat
(
request_id
,
data
,
request
)
if
isinstance
(
response
,
JSONResponse
):
return
response
return
JSONResponse
(
content
=
response
)
@
app
.
get
(
"/health"
)
async
def
health
():
return
{
"status"
:
"healthy"
}
@
app
.
get
(
"/v1/models"
)
async
def
list_models
():
return
{
"object"
:
"list"
,
"data"
:
[
{
"id"
:
"jiuge"
,
"object"
:
"model"
,
"created"
:
int
(
time
.
time
()),
"owned_by"
:
"infinilm"
,
}
],
}
def
_build_sampling_params
(
self
,
data
:
dict
)
->
SamplingParams
:
"""Build SamplingParams from request data."""
return
SamplingParams
(
temperature
=
data
.
get
(
"temperature"
,
self
.
temperature
),
top_p
=
data
.
get
(
"top_p"
,
self
.
top_p
),
top_k
=
data
.
get
(
"top_k"
,
self
.
top_k
),
max_tokens
=
data
.
get
(
"max_tokens"
,
self
.
max_tokens
),
stop
=
data
.
get
(
"stop"
),
)
async
def
_stream_chat
(
self
,
request_id
:
str
,
data
:
dict
,
http_request
:
Request
):
"""Handle streaming chat request."""
req
=
None
start_time
=
time
.
time
()
try
:
messages
=
data
.
get
(
"messages"
,
[])
sampling_params
=
self
.
_build_sampling_params
(
data
)
req
=
self
.
engine
.
add_chat_request
(
messages
=
messages
,
sampling_params
=
sampling_params
,
request_id
=
request_id
,
request_data
=
data
,
http_request
=
http_request
,
)
async
for
token_output
in
self
.
engine
.
stream_request
(
req
,
timeout
=
DEFAULT_STREAM_TIMEOUT
):
# Check timeout
if
time
.
time
()
-
start_time
>
DEFAULT_REQUEST_TIMEOUT
:
logger
.
warning
(
f
"Request
{
request_id
}
timed out after
{
DEFAULT_REQUEST_TIMEOUT
}
s"
)
req
.
mark_timeout
()
error_chunk
=
json
.
dumps
(
chunk_json
(
request_id
,
content
=
"[Request timeout]"
,
finish_reason
=
"timeout"
,
),
ensure_ascii
=
False
,
)
yield
f
"data:
{
error_chunk
}
\n\n
"
break
# Check client disconnect
if
await
http_request
.
is_disconnected
():
logger
.
info
(
f
"Client disconnected for request
{
request_id
}
"
)
req
.
mark_canceled
()
break
# Send token
chunk
=
json
.
dumps
(
chunk_json
(
request_id
,
content
=
token_output
.
token_text
),
ensure_ascii
=
False
,
)
yield
f
"data:
{
chunk
}
\n\n
"
if
token_output
.
finished
:
finish_reason
=
self
.
_convert_finish_reason
(
token_output
.
finish_reason
)
chunk
=
json
.
dumps
(
chunk_json
(
request_id
,
finish_reason
=
finish_reason
),
ensure_ascii
=
False
,
)
yield
f
"data:
{
chunk
}
\n\n
"
break
except
Exception
as
e
:
logger
.
error
(
f
"Stream error for
{
request_id
}
:
{
e
}
"
,
exc_info
=
True
)
if
req
:
req
.
mark_failed
()
error_chunk
=
json
.
dumps
(
chunk_json
(
request_id
,
content
=
f
"[Error:
{
str
(
e
)
}
]"
,
finish_reason
=
"error"
),
ensure_ascii
=
False
,
)
yield
f
"data:
{
error_chunk
}
\n\n
"
finally
:
if
req
and
not
req
.
is_finished
():
req
.
mark_canceled
()
if
req
:
await
req
.
close
()
yield
"data: [DONE]
\n\n
"
async
def
_chat
(
self
,
request_id
:
str
,
data
:
dict
,
http_request
:
Request
):
"""Handle non-streaming chat request."""
req
=
None
start_time
=
time
.
time
()
try
:
messages
=
data
.
get
(
"messages"
,
[])
sampling_params
=
self
.
_build_sampling_params
(
data
)
req
=
self
.
engine
.
add_chat_request
(
messages
=
messages
,
sampling_params
=
sampling_params
,
request_id
=
request_id
,
request_data
=
data
,
http_request
=
http_request
,
)
# Collect all generated tokens
output_text
=
""
async
for
token_output
in
self
.
engine
.
stream_request
(
req
,
timeout
=
DEFAULT_STREAM_TIMEOUT
):
# Check timeout
if
time
.
time
()
-
start_time
>
DEFAULT_REQUEST_TIMEOUT
:
logger
.
warning
(
f
"Request
{
request_id
}
timed out"
)
req
.
mark_timeout
()
break
# Check client disconnect
if
await
http_request
.
is_disconnected
():
logger
.
info
(
f
"Client disconnected for request
{
request_id
}
"
)
req
.
mark_canceled
()
break
output_text
+=
token_output
.
token_text
if
token_output
.
finished
:
break
output_text
=
output_text
.
strip
()
finish_reason
=
self
.
_convert_finish_reason
(
req
.
finish_reason
)
response
=
chunk_json
(
request_id
,
content
=
output_text
,
role
=
"assistant"
,
finish_reason
=
finish_reason
or
"stop"
,
)
return
response
except
Exception
as
e
:
logger
.
error
(
f
"Chat error for
{
request_id
}
:
{
e
}
"
,
exc_info
=
True
)
if
req
:
req
.
mark_failed
()
return
JSONResponse
(
content
=
{
"error"
:
str
(
e
)},
status_code
=
500
)
finally
:
if
req
and
not
req
.
is_finished
():
req
.
mark_canceled
()
if
req
:
await
req
.
close
()
def
_convert_finish_reason
(
self
,
reason
:
FinishReason
)
->
str
:
"""Convert FinishReason enum to string."""
if
reason
is
None
:
return
None
if
reason
in
(
FinishReason
.
EOS_TOKEN
,
FinishReason
.
STOP_STRING
):
return
"stop"
return
reason
.
value
def
setup_logging
(
log_level
:
str
=
"INFO"
):
"""Configure logging system with proper formatting and handlers."""
log_format
=
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
date_format
=
"%Y-%m-%d %H:%M:%S"
logging
.
basicConfig
(
level
=
getattr
(
logging
,
log_level
.
upper
(),
logging
.
INFO
),
format
=
log_format
,
datefmt
=
date_format
,
handlers
=
[
logging
.
StreamHandler
(
sys
.
stdout
),
],
force
=
True
,
)
def
parse_args
():
"""Parse command line arguments."""
parser
=
argparse
.
ArgumentParser
(
description
=
"InfiniLM Inference Server"
)
parser
.
add_argument
(
"--model_path"
,
type
=
str
,
required
=
True
,
help
=
"Path to model directory"
)
parser
.
add_argument
(
"--tp"
,
type
=
int
,
default
=
1
,
help
=
"Tensor parallelism degree"
)
parser
.
add_argument
(
"--max_tokens"
,
type
=
int
,
default
=
512
,
help
=
"Maximum number of tokens to generate"
,
)
parser
.
add_argument
(
"--max_batch_size"
,
type
=
int
,
default
=
8
,
help
=
"Maximum batch size"
)
parser
.
add_argument
(
"--num_blocks"
,
type
=
int
,
default
=
8
*
1024
,
help
=
"Number of blocks for KV cache"
)
parser
.
add_argument
(
"--block_size"
,
type
=
int
,
default
=
16
,
help
=
"Block size for KV cache"
)
parser
.
add_argument
(
"--dtype"
,
type
=
str
,
default
=
"float16"
,
choices
=
[
"float32"
,
"float16"
,
"bfloat16"
],
help
=
"Data type"
,
)
parser
.
add_argument
(
"--temperature"
,
type
=
float
,
default
=
1.0
,
help
=
"Sampling temperature"
)
parser
.
add_argument
(
"--top_p"
,
type
=
float
,
default
=
0.8
,
help
=
"Top-p sampling parameter"
)
parser
.
add_argument
(
"--top_k"
,
type
=
int
,
default
=
1
,
help
=
"Top-k sampling parameter"
)
parser
.
add_argument
(
"--host"
,
type
=
str
,
default
=
"0.0.0.0"
,
help
=
"Server host"
)
parser
.
add_argument
(
"--port"
,
type
=
int
,
default
=
8000
,
help
=
"Server port"
)
parser
.
add_argument
(
"--cpu"
,
action
=
"store_true"
,
help
=
"Use CPU"
)
parser
.
add_argument
(
"--nvidia"
,
action
=
"store_true"
,
help
=
"Use NVIDIA GPU"
)
parser
.
add_argument
(
"--metax"
,
action
=
"store_true"
,
help
=
"Use MetaX device"
)
parser
.
add_argument
(
"--moore"
,
action
=
"store_true"
,
help
=
"Use Moore device"
)
parser
.
add_argument
(
"--iluvatar"
,
action
=
"store_true"
,
help
=
"Use Iluvatar device"
)
parser
.
add_argument
(
"--cambricon"
,
action
=
"store_true"
,
help
=
"Use Cambricon device"
)
parser
.
add_argument
(
"--log_level"
,
type
=
str
,
default
=
"INFO"
,
choices
=
[
"DEBUG"
,
"INFO"
,
"WARNING"
,
"ERROR"
,
"CRITICAL"
],
help
=
"Logging level"
,
)
return
parser
.
parse_args
()
def
main
():
args
=
parse_args
()
setup_logging
(
args
.
log_level
)
if
args
.
cpu
:
device
=
"cpu"
elif
args
.
nvidia
:
device
=
"cuda"
elif
args
.
metax
:
device
=
"cuda"
elif
args
.
moore
:
device
=
"moore"
elif
args
.
iluvatar
:
device
=
"cuda"
elif
args
.
cambricon
:
device
=
"mlu"
else
:
print
(
"Usage: python infinilm.server.inference_server [--cpu | --nvidia | --metax | --moore | --iluvatar | --cambricon] "
"--model_path=<path/to/model_dir> --max_tokens=MAX_TOKENS --max_batch_size=MAX_BATCH_SIZE"
"
\n
"
"Example: python infinilm.server.inference_server --nvidia --model_path=/data/shared/models/9G7B_MHA/ "
"--max_tokens=100 --max_batch_size=32 --tp=1 --temperature=1.0 --top_p=0.8 --top_k=1"
)
sys
.
exit
(
1
)
server
=
InferenceServer
(
model_path
=
args
.
model_path
,
device
=
device
,
dtype
=
args
.
dtype
,
tensor_parallel_size
=
args
.
tp
,
max_tokens
=
args
.
max_tokens
,
max_batch_size
=
args
.
max_batch_size
,
num_blocks
=
args
.
num_blocks
,
block_size
=
args
.
block_size
,
temperature
=
args
.
temperature
,
top_p
=
args
.
top_p
,
top_k
=
args
.
top_k
,
host
=
args
.
host
,
port
=
args
.
port
,
)
server
.
start
()
if
__name__
==
"__main__"
:
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment