jerrrrry / infinilm · Commits · Commit 9dd10678 (Unverified)

Authored Dec 04, 2025 by pengcheng888; committed by GitHub, Dec 04, 2025.

Merge pull request #101 from pengcheng888/issue/89

issue/89: use the matmul function in the Python llama implementation, and reduce the number of Tensor object creations

Parents: 36f8eab7, 7e59976b

Changes: 4
Showing 4 changed files with 37 additions and 38 deletions (+37 −38):

- examples/llama.py (+4 −6)
- python/infinilm/cache_utils.py (+4 −4)
- python/infinilm/generation/utils.py (+5 −8)
- python/infinilm/models/llama/modeling_llama.py (+24 −20)
examples/llama.py

```diff
@@ -86,6 +86,7 @@ def test(
     infini_device=infinicore.device("cpu", 0),
+    backend="python",
 ):
     model_path = os.path.expanduser(model_path)
     # ---------------------------------------------------------------------------- #
     # Create the model
     # ---------------------------------------------------------------------------- #
@@ -104,14 +105,12 @@ def test(
     model.load_state_dict(model_param_infini)
-    config = model.config
     # ---------------------------------------------------------------------------- #
     # Create the tokenizer
     # ---------------------------------------------------------------------------- #
-    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-    if "llama" == config.model_type:
+    if "llama" == model.config.model_type:
         backend = getattr(tokenizer, "backend_tokenizer", None)
         target = getattr(backend, "_tokenizer", backend)
         norm = getattr(target, "normalizer", None)
@@ -129,7 +128,7 @@ def test(
             ]
         )
     else:
-        raise ValueError(f"Unsupported model type: {config.model_type}")
+        raise ValueError(f"Unsupported model type: {model.config.model_type}")
     # ---------------------------------------------------------------------------- #
     # Token encoding
@@ -162,7 +161,6 @@ def test(
         max_new_tokens=max_new_tokens,
         device=infini_device,
         tokenizer=tokenizer,
-        config=config,
     )
     t2 = time.time()
```
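Dropping the local `config` variable and the `config=config` argument works because `GenerationMixin` (see python/infinilm/generation/utils.py below) now reads the configuration the model already carries. A minimal sketch of the pattern, with hypothetical class names standing in for the real infinilm types:

```python
# Minimal sketch of the pattern; ToyConfig / ToyModel are hypothetical
# stand-ins, not the real infinilm API.
class ToyConfig:
    model_type = "llama"
    eos_token_id = 2

class ToyModel:
    def __init__(self, config):
        self.config = config  # the model already owns its config

    def generate(self, input_ids, max_new_tokens, **kwargs):
        # Settings come from self.config instead of a duplicated argument.
        eos_token_id = self.config.eos_token_id
        ...

model = ToyModel(ToyConfig())
model.generate([1, 2, 3], max_new_tokens=8)  # no config= at the call site
```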
python/infinilm/cache_utils.py

```diff
@@ -65,12 +65,12 @@ class DynamicLayer(CacheLayerMixin):
         self.max_seq_len = max(self.max_position_embeddings, seq_len)
         self.keys = infinicore.empty(
-            [batch_size, self.max_seq_len, num_heads, head_dim],
+            (batch_size, self.max_seq_len, num_heads, head_dim),
             dtype=dtype,
             device=device,
         )
         self.values = infinicore.empty(
-            [batch_size, self.max_seq_len, num_heads, head_dim],
+            (batch_size, self.max_seq_len, num_heads, head_dim),
             dtype=dtype,
             device=device,
         )
@@ -80,12 +80,12 @@ class DynamicLayer(CacheLayerMixin):
         self.max_seq_len = max(self.max_seq_len * 2, self.cache_position + seq_len)
         keys_new = infinicore.empty(
-            [batch_size, self.max_seq_len, num_heads, head_dim],
+            (batch_size, self.max_seq_len, num_heads, head_dim),
             dtype=dtype,
             device=device,
         )
         values_new = infinicore.empty(
-            [batch_size, self.max_seq_len, num_heads, head_dim],
+            (batch_size, self.max_seq_len, num_heads, head_dim),
             dtype=dtype,
             device=device,
         )
```
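Two details in this file: the shape argument to `infinicore.empty` becomes a tuple instead of a freshly built list, and the surrounding `DynamicLayer` code sizes the buffers with a geometric growth rule, `max(self.max_seq_len * 2, self.cache_position + seq_len)`, so reallocations are amortized as the sequence grows. A standalone sketch of that rule in plain Python, with a hypothetical helper name:

```python
# Standalone sketch of DynamicLayer's growth rule; grown_capacity is a
# hypothetical helper name, not part of infinilm.
def grown_capacity(max_seq_len: int, cache_position: int, seq_len: int) -> int:
    # Double the current capacity, but never return less than the cache
    # actually needs for the incoming tokens.
    return max(max_seq_len * 2, cache_position + seq_len)

assert grown_capacity(1024, 1000, 30) == 2048   # doubling covers the demand
assert grown_capacity(1024, 2000, 100) == 2100  # demand exceeds doubling
```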
python/infinilm/generation/utils.py

```diff
@@ -121,7 +120,6 @@ class GenerationMixin:
         max_new_tokens: int,
         device: infinicore.device,
         tokenizer,
-        config,
         **kwargs,
     ):
         model_kwargs = kwargs
@@ -144,7 +143,6 @@ class GenerationMixin:
             max_new_tokens=max_new_tokens,
             device=device,
             tokenizer=tokenizer,
-            config=config,
             **model_kwargs,
         )
         return result
@@ -155,7 +153,6 @@ class GenerationMixin:
         max_new_tokens: int,
         device: infinicore.device,
         tokenizer,
-        config,
         **model_kwargs,
     ):
         r"""
@@ -170,7 +167,7 @@ class GenerationMixin:
         batch_size, seq_len = input_ids.shape[:2]
-        eos_token_id = config.eos_token_id
+        eos_token_id = self.config.eos_token_id
         eos_token_id_list = (
             [eos_token_id] if isinstance(eos_token_id, int) else eos_token_id
         )
@@ -216,7 +213,7 @@ class GenerationMixin:
                 device=token_scores.device,
             )
             for i in range(0, batch_size):
-                score = token_scores.narrow(0, i, 1).view([vocab_size])
+                score = token_scores.narrow(0, i, 1).view((vocab_size,))
                 out = next_tokens.narrow(0, i, 1).view([])
                 infinicore.nn.functional.random_sample(
                     score,
```
```diff
@@ -247,16 +244,16 @@ class GenerationMixin:
                 break
         print("\n</s>")
         print(f"\n\n\nGeneration completed in {round(sum(time_list), 2)} ms")
         print(
             f" Batchsize={batch_size} Per_Batch_Input_Len={seq_len}"
             f" Per_Batch_New_Tokens={len(time_list)}\n"
         )
         print(
             f" Prefill TTFT: {round(time_list[0], 2)} ms"
             f" Throughput: {round((1000 * batch_size * seq_len) / time_list[0], 2)} tok/s\n",
         )
         if len(time_list) > 1:
             print(
                 f" Decode Avg ITL: {round(sum(time_list[1:]) / (len(time_list) - 1), 2)} ms"
                 f" Throughput: {round((1000 * batch_size * (len(time_list) - 1)) / sum(time_list[1:]), 2)} tok/s\n",
             )
         return output_tokens_list, output_content
```
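The reported metrics follow directly from `time_list`, the per-step latencies in milliseconds: `time_list[0]` is the prefill step (TTFT), and the remaining entries are decode steps. A worked example with made-up timings:

```python
# Worked example with made-up step timings (milliseconds).
time_list = [120.0, 25.0, 25.0, 30.0]  # 1 prefill step + 3 decode steps
batch_size, seq_len = 4, 256

prefill_tok_s = (1000 * batch_size * seq_len) / time_list[0]
decode_itl = sum(time_list[1:]) / (len(time_list) - 1)
decode_tok_s = (1000 * batch_size * (len(time_list) - 1)) / sum(time_list[1:])

print(round(prefill_tok_s, 2))  # 8533.33 tok/s
print(round(decode_itl, 2))     # 26.67 ms average inter-token latency
print(round(decode_tok_s, 2))   # 150.0 tok/s
```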
python/infinilm/models/llama/modeling_llama.py

```diff
@@ -62,13 +62,8 @@ def multi_head_attention(
     # [num_heads, seq_len, head_dim] @ [num_heads, head_dim, total_seq_len]
     # => [num_heads, seq_len, total_seq_len]
-    attn_weight = Q @ K.permute((1, 2, 0))
-    scaling = infinicore.from_list(
-        [scaling], dtype=attn_weight.dtype, device=attn_weight.device
-    ).as_strided(attn_weight.shape, [0, 0, 0])
-    attn_weight = attn_weight * scaling
+    # Q @ K.T * scaling
+    attn_weight = infinicore.matmul(Q, K.permute((1, 2, 0)), alpha=scaling)
     infinicore.nn.functional.causal_softmax(attn_weight, out=attn_weight)
```
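The old path materialized a broadcast scaling tensor with `from_list(...).as_strided(...)` on every attention call; the new path folds the scale into the GEMM through `matmul`'s BLAS-style `alpha` parameter, so `matmul(Q, K.permute((1, 2, 0)), alpha=scaling)` computes `scaling * (Q @ K^T)` without the temporary. A NumPy sketch of the equivalence (NumPy stands in for infinicore here and has no `alpha`, so the scale that `alpha` fuses into the GEMM is written out explicitly):

```python
import numpy as np

# Shapes match the comment above: Q is [num_heads, seq_len, head_dim],
# K is [total_seq_len, num_heads, head_dim].
num_heads, seq_len, head_dim, total_seq_len = 2, 3, 4, 5
Q = np.random.rand(num_heads, seq_len, head_dim).astype(np.float32)
K = np.random.rand(total_seq_len, num_heads, head_dim).astype(np.float32)
scaling = 1.0 / np.sqrt(head_dim)

# Old path: matmul, then a separate broadcasted scale (one extra temporary).
old = (Q @ K.transpose(1, 2, 0)) * scaling
# What matmul(Q, K.permute((1, 2, 0)), alpha=scaling) computes in one call:
new = scaling * (Q @ K.transpose(1, 2, 0))

assert np.allclose(old, new)
assert new.shape == (num_heads, seq_len, total_seq_len)
```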
```diff
@@ -169,6 +164,8 @@ class LlamaAttention(infinicore.nn.Module):
             **kwargs,
         )
+        self.attn_output = None  # Variable reuse
+
     def forward(
         self,
         hidden_states: infinicore.Tensor,
@@ -184,7 +181,7 @@ class LlamaAttention(infinicore.nn.Module):
         values_shape = (bs, seq_len, self.num_key_value_heads, self.head_dim)
         # --------------------------------------------------------------------------------------- #
         # Project Q, K, V
         # --------------------------------------------------------------------------------------- #
         # => [bs, seq_len, num_attention_heads, head_dim]
         query_states = self.q_proj(hidden_states).view(querys_shape)
@@ -196,13 +193,9 @@ class LlamaAttention(infinicore.nn.Module):
         value_states = self.v_proj(hidden_states).view(values_shape)
         # --------------------------------------------------------------------------------------- #
         # Apply RoPE to Q and K
         # --------------------------------------------------------------------------------------- #
         position_ids = kwargs.pop("position_ids", None)
-        if position_ids is None:
-            raise KeyError("position_ids error")
-        if rope_instance is None:
-            raise KeyError("rope_instance error")
         query_states = rope_instance(query_states, position_ids)
         key_states = rope_instance(key_states, position_ids)
```
```diff
@@ -223,7 +216,14 @@ class LlamaAttention(infinicore.nn.Module):
         # Attention computation
         # --------------------------------------------------------------------------------------- #
         total_seq_len = key_states_total.shape[1]
-        attn_output = infinicore.empty_like(query_states)
+        if self.attn_output is None or self.attn_output.shape[1] != seq_len:
+            self.attn_output = infinicore.empty(
+                (bs, seq_len, self.num_attention_heads, self.head_dim),
+                dtype=query_states.dtype,
+                device=query_states.device,
+            )
         for i in range(0, bs):
             query_states_i = query_states.narrow(0, i, 1).view(
                 (seq_len, self.num_attention_heads, self.head_dim)
@@ -235,7 +235,7 @@ class LlamaAttention(infinicore.nn.Module):
                 (total_seq_len, self.num_key_value_heads, self.head_dim)
             )
-            attn_output_i = attn_output.narrow(0, i, 1).view(
+            attn_output_i = self.attn_output.narrow(0, i, 1).view(
                 (seq_len, self.num_attention_heads, self.head_dim)
             )
```
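Caching `self.attn_output` across forward calls means decode steps, where `seq_len` stays 1, reuse one preallocated buffer instead of calling `empty_like` on every step. A standalone sketch of the reuse guard, with NumPy standing in for infinicore tensors:

```python
import numpy as np

class ReusedBuffer:
    """Sketch of the self.attn_output reuse guard (NumPy stand-in)."""

    def __init__(self):
        self.attn_output = None

    def get(self, bs, seq_len, num_heads, head_dim, dtype=np.float16):
        # Reallocate only when the cached buffer is missing or the sequence
        # length changed (e.g. the prefill -> decode transition).
        if self.attn_output is None or self.attn_output.shape[1] != seq_len:
            self.attn_output = np.empty((bs, seq_len, num_heads, head_dim), dtype=dtype)
        return self.attn_output

buf = ReusedBuffer()
a = buf.get(2, 128, 8, 64)  # prefill: allocates
b = buf.get(2, 1, 8, 64)    # first decode step: reallocates (seq_len changed)
c = buf.get(2, 1, 8, 64)    # later decode steps: reuse the same buffer
assert b is c and a is not b
```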
```diff
@@ -249,8 +249,9 @@ class LlamaAttention(infinicore.nn.Module):
         # out project
         # --------------------------------------------------------------------------------------- #
         # [bs, seq_len, num_attention_heads, head_dim] ==> [bs, seq_len, hidden_size]
-        attn_output = attn_output.view(hidden_states_shape)
+        attn_output = self.attn_output.view(
+            (bs, seq_len, self.num_attention_heads * self.head_dim)
+        )
         # o_proj
         return self.o_proj(attn_output)
```
```diff
@@ -292,7 +293,7 @@ class LlamaDecoderLayer(infinicore.nn.Module):
             **kwargs,
         )
-        hidden_states = residual + hidden_states
+        hidden_states += residual
         # ------------------------------------------------ #
         # Fully Connected
```
```diff
@@ -303,7 +304,7 @@ class LlamaDecoderLayer(infinicore.nn.Module):
         hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
+        hidden_states += residual
         return hidden_states
```
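`hidden_states += residual` applies the residual connection in place, so the two skip connections in each decoder layer no longer allocate a fresh output tensor per layer (assuming infinicore's `+=` writes into the existing storage, as NumPy's does):

```python
import numpy as np

x = np.ones((2, 4), dtype=np.float32)              # hidden_states stand-in
residual = np.full((2, 4), 0.5, dtype=np.float32)

out_of_place = residual + x   # allocates a new array on every layer
buffer_id = id(x)
x += residual                 # in-place: writes into x's existing storage
assert id(x) == buffer_id
assert np.allclose(x, out_of_place)
```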
```diff
@@ -375,7 +376,10 @@ class LlamaModel(infinicore.nn.Module):
         # norm
         # --------------------------------------------------------- #
         seq_len = hidden_states.shape[1]
-        last_token = hidden_states.narrow(1, seq_len - 1, 1)
+        if seq_len > 1:
+            last_token = hidden_states.narrow(1, seq_len - 1, 1)
+        else:
+            last_token = hidden_states
         return self.norm(last_token)
```
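The `seq_len > 1` guard matters in the decode loop, where `hidden_states` already holds a single position: skipping `narrow` there avoids creating one more view object per generated token. In NumPy terms the selection is just "keep the last position":

```python
import numpy as np

hidden_states = np.random.rand(2, 7, 16).astype(np.float32)  # [bs, seq_len, hidden]
seq_len = hidden_states.shape[1]

if seq_len > 1:
    # Prefill: only the last position feeds the final norm.
    last_token = hidden_states[:, seq_len - 1 : seq_len, :]
else:
    # Decode: already a single position; no extra view needed.
    last_token = hidden_states

assert last_token.shape == (2, 1, 16)
```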