Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
c1a3ab29
Unverified
Commit
c1a3ab29
authored
Jan 14, 2026
by
Haojie Wang
Committed by
GitHub
Jan 14, 2026
Browse files
Merge pull request #173 from InfiniTensor/issue/168
Issue/168 InfiniLM接入paged attention接口
parents
96e53dbb
09ab8fa4
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
101 additions
and
37 deletions
+101
-37
python/infinilm/generation/utils.py
python/infinilm/generation/utils.py
+2
-0
python/infinilm/infer_engine.py
python/infinilm/infer_engine.py
+99
-37
No files found.
python/infinilm/generation/utils.py
View file @
c1a3ab29
...
...
@@ -13,6 +13,8 @@ def infini_to_ctype_dtype(infini_dtype):
return
ctypes
.
c_int32
elif
infini_dtype
==
infinicore
.
float32
:
return
ctypes
.
c_float
elif
infini_dtype
==
infinicore
.
int64
:
return
ctypes
.
c_int64
else
:
raise
ValueError
(
f
"Unsupported py_dtype:
{
infini_dtype
}
"
)
...
...
python/infinilm/infer_engine.py
View file @
c1a3ab29
...
...
@@ -4,7 +4,7 @@ from dataclasses import dataclass
import
infinicore
from
infinilm.auto_config
import
AutoConfig
from
infinilm.cache
import
StaticKVCacheConfig
from
infinilm.cache
import
StaticKVCacheConfig
,
PagedKVCacheConfig
from
infinilm.distributed
import
DistConfig
from
infinilm.lib
import
_infinilm
...
...
@@ -18,6 +18,7 @@ class GenerationConfig:
top_p
:
float
=
1.0
eos_token_id
:
list
[
int
]
|
None
=
None
stop_on_eos
:
bool
=
True
class
InferEngine
(
_infinilm
.
InferEngine
):
...
...
@@ -42,6 +43,8 @@ class InferEngine(_infinilm.InferEngine):
self
.
use_cache
=
False
self
.
enable_paged_attn
=
isinstance
(
cache_config
,
PagedKVCacheConfig
)
def __call__(self, *args, **kwargs):
    """Make the engine instance directly callable.

    Delegates every positional and keyword argument unchanged to
    :meth:`forward`, mirroring the calling convention used by common
    ML frameworks (call the model object to run inference).

    Returns whatever ``forward`` returns.
    """
    return self.forward(*args, **kwargs)
...
...
@@ -50,8 +53,8 @@ class InferEngine(_infinilm.InferEngine):
input_ids
,
*
,
position_ids
=
None
,
cache
_lengths
=
None
,
input
_lengths
=
None
,
past_kv
_lengths
=
None
,
total_kv
_lengths
=
None
,
input_offsets
=
None
,
block_tables
=
None
,
slot_mapping
=
None
,
...
...
@@ -62,8 +65,12 @@ class InferEngine(_infinilm.InferEngine):
# TODO: Remove `_underlying` and simplify the corresponding code.
input_ids
=
input_ids
.
_underlying
if
input_ids
is
not
None
else
None
position_ids
=
position_ids
.
_underlying
if
position_ids
is
not
None
else
None
cache_lengths
=
cache_lengths
.
_underlying
if
cache_lengths
is
not
None
else
None
input_lengths
=
input_lengths
.
_underlying
if
input_lengths
is
not
None
else
None
past_kv_lengths
=
(
past_kv_lengths
.
_underlying
if
past_kv_lengths
is
not
None
else
None
)
total_kv_lengths
=
(
total_kv_lengths
.
_underlying
if
past_kv_lengths
is
not
None
else
None
)
input_offsets
=
input_offsets
.
_underlying
if
input_offsets
is
not
None
else
None
block_tables
=
block_tables
.
_underlying
if
block_tables
is
not
None
else
None
slot_mapping
=
slot_mapping
.
_underlying
if
slot_mapping
is
not
None
else
None
...
...
@@ -74,8 +81,8 @@ class InferEngine(_infinilm.InferEngine):
super
().
Input
(
input_ids
,
position_ids
=
position_ids
,
cach
e_lengths
=
cache
_lengths
,
input
_lengths
=
input
_lengths
,
past_sequenc
e_lengths
=
past_kv
_lengths
,
total_sequence
_lengths
=
total_kv
_lengths
,
input_offsets
=
input_offsets
,
block_tables
=
block_tables
,
slot_mapping
=
slot_mapping
,
...
...
@@ -87,21 +94,24 @@ class InferEngine(_infinilm.InferEngine):
.
output_ids
)
def
generate
(
self
,
input_ids
,
generation_config
,
*
,
_measure_and_log_time
=
False
):
def
generate
(
self
,
input_ids
,
generation_config
,
*
,
_measure_and_log_time
=
False
,
paged_block_size
=
16
,
):
if
generation_config
.
eos_token_id
is
None
:
eos_token_id
=
self
.
config
.
eos_token_id
else
:
eos_token_id
=
generation_config
.
eos_token_id
# TODO: Remove the `to_numpy` calls and simplify the corresponding code.
batch_size
,
seq_len
=
input_ids
.
shape
[:
2
]
position_ids
=
infinicore
.
from_list
(
[
list
(
range
(
0
,
seq_len
))
for
_
in
range
(
batch_size
)],
dtype
=
infinicore
.
int64
)
cache_lengths
=
infinicore
.
from_list
([
0
],
dtype
=
infinicore
.
int64
)
past_seq_len
=
0
output_ids
=
[]
initial_batch_size
,
initial_seqlen
=
input_ids
.
shape
[:
2
]
seq_len
=
initial_seqlen
batch_size
=
initial_batch_size
if
batch_size
!=
1
and
generation_config
.
max_new_tokens
is
None
:
raise
ValueError
(
...
...
@@ -111,14 +121,75 @@ class InferEngine(_infinilm.InferEngine):
if
_measure_and_log_time
:
time_measurements
=
[]
for
_
in
range
(
0
,
generation_config
.
max_new_tokens
):
for
iter
in
range
(
0
,
generation_config
.
max_new_tokens
):
if
_measure_and_log_time
:
start_time
=
time
.
perf_counter
()
batch_size
,
seq_len
=
input_ids
.
shape
[:
2
]
if
self
.
enable_paged_attn
:
input_ids
=
input_ids
.
view
([
1
,
batch_size
*
seq_len
])
position_ids
=
infinicore
.
from_list
(
list
(
range
(
past_seq_len
,
past_seq_len
+
seq_len
))
*
batch_size
,
dtype
=
infinicore
.
int64
,
)
block_tables_list
=
[
[
i
*
batch_size
+
b
for
i
in
range
(
(
past_seq_len
+
seq_len
+
paged_block_size
-
1
)
//
paged_block_size
)
]
for
b
in
range
(
batch_size
)
]
slot_mapping_list
=
[
(((
past_seq_len
+
i
)
//
paged_block_size
)
*
batch_size
+
b
)
*
paged_block_size
+
(
past_seq_len
+
i
)
%
paged_block_size
for
b
in
range
(
batch_size
)
for
i
in
range
(
seq_len
)
]
block_tables
=
infinicore
.
from_list
(
block_tables_list
,
dtype
=
infinicore
.
int64
,
)
slot_mapping
=
infinicore
.
from_list
(
slot_mapping_list
,
dtype
=
infinicore
.
int64
,
)
else
:
position_ids
=
infinicore
.
from_list
(
[
list
(
range
(
past_seq_len
,
past_seq_len
+
seq_len
))
for
_
in
range
(
batch_size
)
],
dtype
=
infinicore
.
int64
,
)
block_tables
=
None
slot_mapping
=
None
past_kv_lengths
=
infinicore
.
from_list
(
[
past_seq_len
]
*
batch_size
,
dtype
=
infinicore
.
int64
)
total_kv_lengths
=
infinicore
.
from_list
(
[
past_seq_len
+
seq_len
]
*
batch_size
,
dtype
=
infinicore
.
int64
)
input_offsets
=
infinicore
.
from_list
(
[
seq_len
*
i
for
i
in
range
(
batch_size
+
1
)],
dtype
=
infinicore
.
int64
)
output_id
=
self
(
input_ids
,
input_ids
=
input_ids
,
position_ids
=
position_ids
,
cache_lengths
=
cache_lengths
,
past_kv_lengths
=
past_kv_lengths
,
total_kv_lengths
=
total_kv_lengths
,
input_offsets
=
input_offsets
,
block_tables
=
block_tables
,
slot_mapping
=
slot_mapping
,
temperature
=
generation_config
.
temperature
,
top_k
=
generation_config
.
top_k
,
top_p
=
generation_config
.
top_p
,
...
...
@@ -127,24 +198,17 @@ class InferEngine(_infinilm.InferEngine):
output_ids
.
append
(
output_id
)
if
(
generation_config
.
max_new_tokens
is
not
None
initial_batch_size
==
1
and
generation_config
.
stop_on_eos
and
generation_config
.
max_new_tokens
is
not
None
and
output_id
.
to_numpy
()[
0
]
in
eos_token_id
):
break
seq_len
=
position_ids
.
shape
[
-
1
]
input_ids
=
infinicore
.
from_list
(
[[
output_id
]
for
output_id
in
output_id
.
to_numpy
().
tolist
()]
)
position_ids
=
infinicore
.
from_list
(
[
1
for
_
in
range
(
batch_size
)],
dtype
=
position_ids
.
dtype
,
device
=
position_ids
.
device
,
).
view
((
batch_size
,
1
))
+
position_ids
.
narrow
(
1
,
seq_len
-
1
,
1
)
cache_lengths
+=
infinicore
.
from_list
(
[
seq_len
],
dtype
=
cache_lengths
.
dtype
,
device
=
cache_lengths
.
device
)
past_seq_len
=
past_seq_len
+
seq_len
if
_measure_and_log_time
:
end_time
=
time
.
perf_counter
()
...
...
@@ -156,23 +220,21 @@ class InferEngine(_infinilm.InferEngine):
f
"
\n\n\n
Generation completed in
{
round
(
sum
(
time_measurements
)
*
1000
,
2
)
}
ms"
)
print
(
f
" Batchsize=
{
batch_size
}
Per_Batch_Input_Len=
{
seq
_
len
}
Per_Batch_New_Tokens=
{
len
(
time_measurements
)
}
\n
"
f
" Batchsize=
{
initial_
batch_size
}
Per_Batch_Input_Len=
{
initial_
seqlen
}
Per_Batch_New_Tokens=
{
len
(
time_measurements
)
}
\n
"
)
print
(
f
" Prefill TTFT:
{
round
(
time_measurements
[
0
],
2
)
}
ms Throughput:
{
round
((
batch_size
*
seq
_
len
)
/
time_measurements
[
0
],
2
)
}
tok/s
\n
"
,
f
" Prefill TTFT:
{
round
(
time_measurements
[
0
],
2
)
}
ms Throughput:
{
round
((
initial_
batch_size
*
initial_
seqlen
)
/
time_measurements
[
0
],
2
)
}
tok/s
\n
"
,
)
if
len
(
time_measurements
)
>
1
:
print
(
f
" Decode Avg ITL:
{
round
(
sum
(
time_measurements
[
1
:])
*
1000
/
(
len
(
time_measurements
)
-
1
),
2
)
}
ms Throughput:
{
round
((
batch_size
*
(
len
(
time_measurements
)
-
1
))
/
sum
(
time_measurements
[
1
:]),
2
)
}
tok/s
\n
"
,
f
" Decode Avg ITL:
{
round
(
sum
(
time_measurements
[
1
:])
*
1000
/
(
len
(
time_measurements
)
-
1
),
2
)
}
ms Throughput:
{
round
((
initial_
batch_size
*
(
len
(
time_measurements
)
-
1
))
/
sum
(
time_measurements
[
1
:]),
2
)
}
tok/s
\n
"
,
)
return
output_ids
def
reset_cache
(
self
,
batch_size
:
int
,
initial_capacity
:
int
=
1024
):
def reset_cache(self, cache_config):
    """Re-initialize the KV cache from an explicit cache configuration.

    Synchronizes the device before touching cache state, records whether
    paged attention should be used (true iff *cache_config* is a
    ``PagedKVCacheConfig``), then forwards the configuration to the
    native engine's ``reset_cache``.
    """
    # Ensure no in-flight device work still references the old cache.
    infinicore.sync_device()
    # Paged-attention mode is derived solely from the config's type.
    self.enable_paged_attn = isinstance(cache_config, PagedKVCacheConfig)
    super().reset_cache(cache_config)
def
state_dict_keyname
(
self
):
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment