jerrrrry / infinilm · Commit 6f624c94
authored Dec 04, 2025 by PanZezhong

issue/97 Inference script supports batching

parent 0da7b5db
Showing 4 changed files with 58 additions and 37 deletions (+58 −37)
csrc/cache/kv_cache.hpp                  +2  −0
csrc/models/llama/llama_attention.cpp    +14 −20
examples/llama.py                        +29 −11
python/infinilm/generation/utils.py      +13 −6
csrc/cache/kv_cache.hpp

@@ -10,6 +10,8 @@
 #include <stdexcept>
 #include <utility>
+#include <spdlog/spdlog.h>

 namespace infinilm::cache {

 /**
csrc/models/llama/llama_attention.cpp

@@ -72,44 +72,38 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
         throw std::runtime_error("Unexpected position_ids shape");
     }

-    // 4. Apply RoPE to full batch
-    auto q_for_rope = q_reshaped->view({batch_size * seq_len, num_attention_heads_, head_dim_});
-    auto k_for_rope = k_reshaped->view({batch_size * seq_len, num_key_value_heads_, head_dim_});
-
-    // Call RoPE on full batch (matching Python pattern)
-    auto q_rope_out = rotary_emb_->forward(q_for_rope, pos_ids_for_rope);
-    auto k_rope_out = rotary_emb_->forward(k_for_rope, pos_ids_for_rope);
-
-    // Reshape back to [batch_size, seq_len, num_heads, head_dim] (matching Python pattern)
-    q_rope_out = q_rope_out->view({batch_size, seq_len, num_attention_heads_, head_dim_});
-    k_rope_out = k_rope_out->view({batch_size, seq_len, num_key_value_heads_, head_dim_});
-
-    // 5. Process each batch item separately for attention computation
+    // 4. Process each batch item separately for attention computation
     infinilm::cache::KVCache *external_cache = static_cast<infinilm::cache::KVCache *>(kv_cache);

     // Convert to [batch, n_head, seq_len, head_dim] for cache
     // Ensure contiguous after permute for F16 compatibility with cache operations
-    auto q_rope = q_rope_out->permute({0, 2, 1, 3})->contiguous(); // [bs, n_q_head, seq_len, head_dim]
-    auto k_rope = k_rope_out->permute({0, 2, 1, 3});               // [bs, n_kv_head, seq_len, head_dim]
+    q_reshaped = q_reshaped->permute({0, 2, 1, 3})->contiguous(); // [bs, n_q_head, seq_len, head_dim]
+    auto k_permuted = k_reshaped->permute({0, 2, 1, 3});          // [bs, n_kv_head, seq_len, head_dim]
     auto v_permuted = v_reshaped->permute({0, 2, 1, 3});          // [bs, n_kv_head, seq_len, head_dim]

-    // 5. Prepare KV caches
+    // 4. Prepare KV caches
     infinicore::Tensor k_total; // [bs, n_kv_head, total_seq_len, head_dim]
     infinicore::Tensor v_total; // [bs, n_kv_head, total_seq_len, head_dim]
     if (external_cache != nullptr) {
-        auto [k_total_tmp, v_total_tmp] = external_cache->update(k_rope, v_permuted);
+        auto [k_total_tmp, v_total_tmp] = external_cache->update(k_permuted, v_permuted);
         k_total = k_total_tmp;
         v_total = v_total_tmp;
     } else {
-        auto [k_total_tmp, v_total_tmp] = internal_cache_.update(k_rope, v_permuted);
+        auto [k_total_tmp, v_total_tmp] = internal_cache_.update(k_permuted, v_permuted);
         k_total = k_total_tmp;
         v_total = v_total_tmp;
     }
     auto total_seq_len = k_total->shape()[2];

+    // 5. Apply RoPE to full batch
+    auto q_rope = q_reshaped->view({batch_size * num_attention_heads_, seq_len, head_dim_})->permute({1, 0, 2}); // [seq_len, bs * n_q_head, head_dim]
+    auto k_rope = k_total->narrow({{2, total_seq_len - seq_len, seq_len}})->view({batch_size * num_key_value_heads_, seq_len, head_dim_})->permute({1, 0, 2}); // [seq_len, bs * n_kv_head, head_dim]
+    rotary_emb_->forward(q_rope, pos_ids_for_rope, true);
+    rotary_emb_->forward(k_rope, pos_ids_for_rope, true);
+
     // 6. Compute attention
     size_t ngroup = num_attention_heads_ / num_key_value_heads_;
-    auto Q = q_rope->view({batch_size * num_key_value_heads_, ngroup * seq_len, head_dim_});
+    auto Q = q_reshaped->view({batch_size * num_key_value_heads_, ngroup * seq_len, head_dim_});
     auto K = k_total->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_});
     auto V = v_total->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_});
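The hunk above reorders the batched attention path: Q/K/V are first permuted to [bs, n_head, seq_len, head_dim], the KV cache is updated, and RoPE is then applied in place to Q and to the freshly appended K slice of the cache, before folding heads for grouped-query attention. The following NumPy sketch is only an illustration of that shape bookkeeping under the dimensions named in the diff; it is not the project's API, and the rotation itself is omitted.

import numpy as np

bs, seq_len, past_len = 2, 5, 3
n_q_head, n_kv_head, head_dim = 8, 2, 16
ngroup = n_q_head // n_kv_head

# Projections already reshaped to [bs, seq_len, n_head, head_dim].
q = np.random.rand(bs, seq_len, n_q_head, head_dim)
k = np.random.rand(bs, seq_len, n_kv_head, head_dim)
v = np.random.rand(bs, seq_len, n_kv_head, head_dim)

# Permute to [bs, n_head, seq_len, head_dim] before touching the cache.
q = np.ascontiguousarray(q.transpose(0, 2, 1, 3))
k = k.transpose(0, 2, 1, 3)
v = v.transpose(0, 2, 1, 3)

# Cache update: append along the sequence axis.
k_cache = np.random.rand(bs, n_kv_head, past_len, head_dim)
v_cache = np.random.rand(bs, n_kv_head, past_len, head_dim)
k_total = np.concatenate([k_cache, k], axis=2)  # [bs, n_kv_head, total_seq_len, head_dim]
v_total = np.concatenate([v_cache, v], axis=2)
total_seq_len = k_total.shape[2]

# RoPE operands, viewed as [seq_len, bs * n_head, head_dim]. In the commit the
# rotation runs in place on these views; NumPy would make copies here instead.
q_rope = q.reshape(bs * n_q_head, seq_len, head_dim).transpose(1, 0, 2)
k_rope = (k_total[:, :, total_seq_len - seq_len:, :]
          .reshape(bs * n_kv_head, seq_len, head_dim)
          .transpose(1, 0, 2))

# Grouped-query attention operands: one matmul batch per KV head, with the
# ngroup query heads folded into the row dimension.
Q = q.reshape(bs * n_kv_head, ngroup * seq_len, head_dim)
K = k_total.reshape(bs * n_kv_head, total_seq_len, head_dim)
V = v_total.reshape(bs * n_kv_head, total_seq_len, head_dim)
print(Q.shape, K.shape, V.shape)  # (4, 20, 16) (4, 8, 16) (4, 8, 16)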
examples/llama.py

@@ -63,11 +63,23 @@ def get_args():
         default="float32",
         help="float32, float16, bfloat16",
     )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=1,
+        help="number of prompts in a batch",
+    )
+    parser.add_argument(
+        "--prompt",
+        type=str,
+        default="How are you",
+        help="input prompt",
+    )
     return parser.parse_args()


 def test(
-    prompt,
+    prompts: str | list[str],
     model_path,
     max_new_tokens=100,
     infini_dtype=infinicore.bfloat16,

@@ -123,18 +135,24 @@ def test(
     # ---------------------------------------------------------------------------- #
     # token encoding
     # ---------------------------------------------------------------------------- #
     # prompt = "山东最高的山是?"
-    input_content = tokenizer.apply_chat_template(
-        conversation=[{"role": "user", "content": prompt}],
-        add_generation_prompt=True,
-        tokenize=False,
-    )
-    print(input_content, end="", flush=True)
-    input_ids = tokenizer.encode(input_content)
+    if isinstance(prompts, str):
+        prompts = [prompts]
+    input_contents = [
+        tokenizer.apply_chat_template(
+            conversation=[{"role": "user", "content": prompt}],
+            add_generation_prompt=True,
+            tokenize=False,
+        )
+        for prompt in prompts
+    ]
+    print(input_contents[0], end="", flush=True)
+    input_ids_list = tokenizer.batch_encode_plus(input_contents)["input_ids"]  # List: [[1, 1128, 526, 366, 29892]]

     # ---------------------------------------------------------------------------- #
     # autoregressive generation
     # ---------------------------------------------------------------------------- #
-    input_ids_list = [input_ids]  # List: [[1, 1128, 526, 366, 29892]]
     input_ids_infini = infinicore.from_list(input_ids_list)

     t1 = time.time()

@@ -175,7 +193,7 @@ if __name__ == "__main__":
             "such as, python examples/llama.py --nvidia --model_path=~/TinyLlama-1.1B-Chat-v1.0"
         )
         sys.exit(1)

-    prompt = "How are you"
+    prompts = [args.prompt for _ in range(args.batch_size)]
     model_path = args.model_path
     max_new_tokens = args.max_new_tokens

@@ -192,7 +210,7 @@ if __name__ == "__main__":
         raise ValueError(f"Unsupported dtype: {args.dtype}")

     test(
-        prompt,
+        prompts,
         model_path,
         max_new_tokens,
         infini_device=infini_device,
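The example script now builds its batch by repeating --prompt --batch_size times, chat-templating each prompt and batch-encoding the results. The sketch below reproduces that flow with Hugging Face transformers alone so it can be tried outside the repo; the hub id is only a stand-in for the local checkpoint the script's help text mentions, and the real script additionally wraps the id lists with infinicore.from_list().

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

prompts = ["How are you"] * 2  # what --batch_size 2 with the default --prompt produces
input_contents = [
    tokenizer.apply_chat_template(
        conversation=[{"role": "user", "content": p}],
        add_generation_prompt=True,
        tokenize=False,
    )
    for p in prompts
]
input_ids_list = tokenizer.batch_encode_plus(input_contents)["input_ids"]
print(len(input_ids_list), [len(ids) for ids in input_ids_list])  # one id list per prompt

Because every prompt in the batch is a copy of the same string, all sequences share one length, so no padding is needed and a single seq_len describes the whole prefill.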
python/infinilm/generation/utils.py

@@ -100,9 +100,11 @@ class GenerationMixin:
         # -------------------------------------------------------------------- #
         # required: token input_ids
         # -------------------------------------------------------------------- #
-        if kwargs.get("next_token_id", None) is not None:
-            next_token_id = kwargs["next_token_id"]
-            model_inputs["input_ids"] = infinicore.from_list([[next_token_id]])
+        if kwargs.get("next_token_ids", None) is not None:
+            next_token_ids = kwargs["next_token_ids"]
+            model_inputs["input_ids"] = infinicore.from_list(
+                [[id_] for id_ in next_token_ids],
+            )

         # -------------------------------------------------------------------- #
         # other

@@ -236,7 +238,7 @@ class GenerationMixin:
             token_id = next_tokens.to_numpy()[0]
             output_str = tokenizer.decode([token_id], skip_special_tokens=True)

-            model_kwargs["next_token_id"] = token_id
+            model_kwargs["next_token_ids"] = next_tokens.to_numpy().tolist()

             output_tokens_list.append(token_id)
             output_content += output_str

@@ -245,11 +247,16 @@ class GenerationMixin:
                 break

         print("\n</s>")
-        print(f"\n\n\nGeneration completed in {round(sum(time_list), 2)} ms")
-        print(
-            f"\n\n\nTime per step: prefill {round(time_list[0], 2)} ms/token\n",
-        )
-        print(
-            f"Time per step: decoder {round(sum(time_list[1:]) / (len(time_list) - 1), 2)} ms/token\n",
-        )
+        print(
+            f"Batchsize={batch_size} Per_Batch_Input_Len={seq_len} Per_Batch_New_Tokens={len(time_list)}\n"
+        )
+        print(
+            f"Prefill TTFT: {round(time_list[0], 2)} ms Throughput: {round((1000 * batch_size * seq_len) / time_list[0], 2)} tok/s\n",
+        )
+        if len(time_list) > 1:
+            print(
+                f"Decode Avg ITL: {round(sum(time_list[1:]) / (len(time_list) - 1), 2)} ms Throughput: {round((1000 * batch_size * (len(time_list) - 1)) / sum(time_list[1:]), 2)} tok/s\n",
+            )

         return output_tokens_list, output_content
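The reporting change scales both metrics by batch_size: prefill throughput counts all prompt tokens processed in the first step, decode throughput counts one generated token per batch row per decode step. A small worked example of the arithmetic, assuming time_list holds per-step latencies in milliseconds (the numbers are made up for illustration):

time_list = [80.0, 20.0, 21.0, 19.0]  # one prefill step, then decode steps
batch_size, seq_len = 4, 32

ttft = time_list[0]                                           # prefill latency (ms)
prefill_tput = 1000 * batch_size * seq_len / time_list[0]     # prompt tokens per second
itl = sum(time_list[1:]) / (len(time_list) - 1)               # average inter-token latency (ms)
decode_tput = 1000 * batch_size * (len(time_list) - 1) / sum(time_list[1:])  # generated tokens per second

print(f"Prefill TTFT: {round(ttft, 2)} ms Throughput: {round(prefill_tput, 2)} tok/s")
print(f"Decode Avg ITL: {round(itl, 2)} ms Throughput: {round(decode_tput, 2)} tok/s")
# Prefill TTFT: 80.0 ms Throughput: 1600.0 tok/s
# Decode Avg ITL: 20.0 ms Throughput: 200.0 tok/s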