jerrrrry / infinilm · Commits

Commit e8245b7d
authored Dec 09, 2025 by wooway777

issue/116 - using batched rope

parent 81081f3c
Showing 1 changed file with 9 additions and 10 deletions

csrc/models/llama/llama_attention.cpp (+9, -10) · view file @ e8245b7d
@@ -94,12 +94,17 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
         throw std::runtime_error("Unexpected position_ids shape");
     }
 
-    // 4. Prepare KV caches
+    // 4. Apply RoPE to Q and K
+    auto q_rope = infinicore::Tensor::empty({batch_size, num_attention_heads_, seq_len, head_dim_}, q_reshaped->dtype(), q_reshaped->device())->permute({0, 2, 1, 3});
+    rotary_emb_->forward(q_rope, q_reshaped, pos_ids_for_rope); // [bs, seq_len, n_q_head, head_dim]
+    rotary_emb_->forward(k_reshaped, pos_ids_for_rope, true);   // [bs, seq_len, n_kv_head, head_dim]
+
+    // 5. Prepare KV caches
     // Convert to [batch, n_head, seq_len, head_dim] for cache
     // Ensure contiguous after permute for F16 compatibility with cache operations
-    q_reshaped = q_reshaped->permute({0, 2, 1, 3})->contiguous(); // [bs, n_q_head, seq_len, head_dim]
-    auto k_permuted = k_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
-    auto v_permuted = v_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
+    q_reshaped = q_rope->permute({0, 2, 1, 3}); // [bs, n_q_head, seq_len, head_dim]
+    auto k_permuted = k_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
+    auto v_permuted = v_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
 
     infinilm::cache::DynamicCache *external_cache = static_cast<infinilm::cache::DynamicCache *>(kv_cache);
 
     infinicore::Tensor k_total; // [bs, n_kv_head, total_seq_len, head_dim]
     infinicore::Tensor v_total; // [bs, n_kv_head, total_seq_len, head_dim]
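The rewritten step 4 applies rotary embeddings to Q and K in a single batched call over the [bs, seq_len, n_head, head_dim] layout, before the permute into cache layout, so the per-head flattening that used to happen later (removed in the next hunk) is no longer needed. Below is a rough standalone sketch of what such a batched RoPE call computes; it is illustrative math only, not the infinicore kernel. The half-split ("rotate_half") pairing, the base theta, and the [bs, seq] shape of the position ids are assumptions.

// Hedged sketch of batched Llama-style RoPE over a [bs, seq, heads, head_dim] buffer.
// Not the rotary_emb_ implementation; conventions and signatures here are assumed.
#include <cmath>
#include <cstddef>
#include <vector>

void batched_rope(std::vector<float> &x,            // [bs, seq, heads, head_dim], row-major
                  const std::vector<int> &pos_ids,  // [bs, seq] absolute positions
                  std::size_t bs, std::size_t seq, std::size_t heads,
                  std::size_t head_dim, float theta = 10000.0f) {
    const std::size_t half = head_dim / 2;
    for (std::size_t b = 0; b < bs; ++b)
        for (std::size_t s = 0; s < seq; ++s) {
            const float pos = static_cast<float>(pos_ids[b * seq + s]);
            for (std::size_t h = 0; h < heads; ++h) {
                float *v = x.data() + ((b * seq + s) * heads + h) * head_dim;
                for (std::size_t i = 0; i < half; ++i) {
                    const float freq = std::pow(theta, -2.0f * static_cast<float>(i) / static_cast<float>(head_dim));
                    const float c = std::cos(pos * freq), sn = std::sin(pos * freq);
                    const float lo = v[i], hi = v[i + half];
                    v[i]        = lo * c - hi * sn;  // rotate the (i, i + half) pair
                    v[i + half] = lo * sn + hi * c;
                }
            }
        }
}

One batched call of this form covers every sequence in the batch at once, which is why the later narrow/view/permute reshuffle into [seq_len, bs * n_head, head_dim] can be dropped.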
@@ -113,12 +118,6 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
     }
 
     auto total_seq_len = k_total->shape()[2];
 
-    // 5. Apply RoPE to full batch
-    auto q_rope = q_reshaped->view({batch_size * num_attention_heads_, seq_len, head_dim_})->permute({1, 0, 2}); // [seq_len, bs * n_q_head, head_dim]
-    auto k_rope = k_total->narrow({{2, total_seq_len - seq_len, seq_len}})->view({batch_size * num_key_value_heads_, seq_len, head_dim_})->permute({1, 0, 2}); // [seq_len, bs * n_kv_head, head_dim]
-    rotary_emb_->forward(q_rope, pos_ids_for_rope, true);
-    rotary_emb_->forward(k_rope, pos_ids_for_rope, true);
-
     // 6. Compute attention
     size_t ngroup = num_attention_heads_ / num_key_value_heads_;
     auto Q = q_reshaped->view({batch_size * num_key_value_heads_, ngroup * seq_len, head_dim_});
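With RoPE applied up front, the keys written into the cache already carry their rotation, so nothing needs to be sliced back out of k_total for a second pass. The unchanged "6. Compute attention" context then folds query heads onto their shared KV head via ngroup. A small standalone sketch of that grouped-query bookkeeping follows, using made-up head counts; the real values come from the model config, not from this diff.

// Hedged illustration of the grouped-query view shape computed in the context above.
#include <cstddef>
#include <iostream>

int main() {
    std::size_t batch_size = 2, seq_len = 16, head_dim = 128;      // assumed sizes
    std::size_t num_attention_heads = 32, num_key_value_heads = 8; // assumed sizes
    std::size_t ngroup = num_attention_heads / num_key_value_heads; // query heads per KV head
    // [bs, n_q_head, seq_len, head_dim] viewed as [bs * n_kv_head, ngroup * seq_len, head_dim]
    std::cout << "ngroup = " << ngroup << ", Q view = {"
              << batch_size * num_key_value_heads << ", "
              << ngroup * seq_len << ", " << head_dim << "}\n";    // {16, 64, 128}
    return 0;
}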