Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
f246c4f1
Commit
f246c4f1
authored
Dec 30, 2025
by
PanZezhong
Browse files
issue/168 InfiniLM接入paged attention
parent
96e53dbb
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
321 additions
and
81 deletions
+321
-81
csrc/cache/kv_cache.cpp
csrc/cache/kv_cache.cpp
+7
-3
csrc/engine/infer_engine.cpp
csrc/engine/infer_engine.cpp
+44
-2
csrc/engine/rank_worker.cpp
csrc/engine/rank_worker.cpp
+10
-5
csrc/engine/rank_worker.hpp
csrc/engine/rank_worker.hpp
+1
-1
csrc/models/llama/llama_attention.cpp
csrc/models/llama/llama_attention.cpp
+120
-21
csrc/models/llama/llama_attention.hpp
csrc/models/llama/llama_attention.hpp
+18
-1
csrc/models/llama/llama_for_causal_lm.cpp
csrc/models/llama/llama_for_causal_lm.cpp
+1
-2
csrc/models/llama/llama_model.cpp
csrc/models/llama/llama_model.cpp
+1
-9
csrc/pybind11/engine/engine.hpp
csrc/pybind11/engine/engine.hpp
+2
-0
examples/jiuge.py
examples/jiuge.py
+31
-6
python/infinilm/auto_config.py
python/infinilm/auto_config.py
+2
-0
python/infinilm/cache/__init__.py
python/infinilm/cache/__init__.py
+2
-2
python/infinilm/generation/utils.py
python/infinilm/generation/utils.py
+2
-0
python/infinilm/infer_engine.py
python/infinilm/infer_engine.py
+80
-29
No files found.
csrc/cache/kv_cache.cpp
View file @
f246c4f1
#include "kv_cache.hpp"
#include "kv_cache.hpp"
#include "../utils.hpp"
#include "../utils.hpp"
#include "infinicore/ops.hpp"
#include <stdexcept>
#include <stdexcept>
namespace
infinilm
::
cache
{
namespace
infinilm
::
cache
{
...
@@ -155,6 +155,7 @@ PagedKVCache::PagedKVCache(
...
@@ -155,6 +155,7 @@ PagedKVCache::PagedKVCache(
num_blocks_per_layer_
=
config
.
max_kv_memory_bytes
()
num_blocks_per_layer_
=
config
.
max_kv_memory_bytes
()
/
(
k_dim
*
num_rank_k_heads_
+
v_dim
*
num_rank_v_heads_
)
/
(
k_dim
*
num_rank_k_heads_
+
v_dim
*
num_rank_v_heads_
)
/
block_size_
/
block_size_
/
rank_num_layers_
/
infinicore
::
dsize
(
dtype_
);
/
infinicore
::
dsize
(
dtype_
);
if
(
num_blocks_per_layer_
==
0
)
{
if
(
num_blocks_per_layer_
==
0
)
{
throw
std
::
runtime_error
(
"Not enough memory for KV cache"
);
throw
std
::
runtime_error
(
"Not enough memory for KV cache"
);
...
@@ -190,8 +191,11 @@ std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::update(
...
@@ -190,8 +191,11 @@ std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::update(
auto
k_cache_layer
=
k_caches_
->
narrow
({{
0
,
layer_idx
,
1
}})
->
squeeze
(
0
);
auto
k_cache_layer
=
k_caches_
->
narrow
({{
0
,
layer_idx
,
1
}})
->
squeeze
(
0
);
auto
v_cache_layer
=
v_caches_
->
narrow
({{
0
,
layer_idx
,
1
}})
->
squeeze
(
0
);
auto
v_cache_layer
=
v_caches_
->
narrow
({{
0
,
layer_idx
,
1
}})
->
squeeze
(
0
);
/// @todo: implement paged cache update here
infinicore
::
op
::
paged_caching_
(
k
,
v
,
k_cache_layer
,
v_cache_layer
,
slot_mapping
);
return
{
k_cache_layer
,
v_cache_layer
};
return
{
k_cache_layer
,
v_cache_layer
};
}
}
}
// namespace infinilm::cache
}
// namespace infinilm::cache
csrc/engine/infer_engine.cpp
View file @
f246c4f1
...
@@ -56,8 +56,50 @@ std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> InferEng
...
@@ -56,8 +56,50 @@ std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> InferEng
//------------------------------------------------------
//------------------------------------------------------
// forward
// forward
//------------------------------------------------------
//------------------------------------------------------
infinilm
::
InfinilmModel
::
Input
InferEngine
::
Input
::
to_model_input
()
const
{
infinilm
::
InfinilmModel
::
Input
InferEngine
::
Input
::
to_model_input
(
infinicore
::
Device
device
)
const
{
return
{
input_ids
,
position_ids
,
cache_lengths
,
input_lengths
,
input_offsets
,
block_tables
,
slot_mapping
};
std
::
optional
<
infinicore
::
Tensor
>
position_ids_on_device
;
if
(
position_ids
.
has_value
())
{
position_ids_on_device
=
position_ids
.
value
()
->
to
(
device
);
}
std
::
optional
<
infinicore
::
Tensor
>
cache_lengths_on_device
;
if
(
cache_lengths
.
has_value
())
{
if
(
block_tables
.
has_value
())
{
cache_lengths_on_device
=
cache_lengths
.
value
()
->
to
(
device
);
}
else
{
// @todo: only paged kv cache support device tensor so far
cache_lengths_on_device
=
cache_lengths
.
value
();
}
}
std
::
optional
<
infinicore
::
Tensor
>
input_lengths_on_device
;
if
(
input_lengths
.
has_value
())
{
input_lengths_on_device
=
input_lengths
.
value
()
->
to
(
device
);
}
std
::
optional
<
infinicore
::
Tensor
>
input_offsets_on_device
;
if
(
input_offsets
.
has_value
())
{
input_offsets_on_device
=
input_offsets
.
value
()
->
to
(
device
);
}
std
::
optional
<
infinicore
::
Tensor
>
block_tables_on_device
;
if
(
block_tables
.
has_value
())
{
block_tables_on_device
=
block_tables
.
value
()
->
to
(
device
);
}
std
::
optional
<
infinicore
::
Tensor
>
slot_mapping_on_device
;
if
(
slot_mapping
.
has_value
())
{
slot_mapping_on_device
=
slot_mapping
.
value
()
->
to
(
device
);
}
return
{
input_ids
,
// @todo: on device in the future
position_ids_on_device
,
cache_lengths_on_device
,
input_lengths_on_device
,
input_offsets_on_device
,
block_tables_on_device
,
slot_mapping_on_device
};
}
}
InferEngine
::
Output
InferEngine
::
forward
(
const
InferEngine
::
Input
&
input
)
{
InferEngine
::
Output
InferEngine
::
forward
(
const
InferEngine
::
Input
&
input
)
{
...
...
csrc/engine/rank_worker.cpp
View file @
f246c4f1
...
@@ -206,7 +206,7 @@ void RankWorker::thread_loop() {
...
@@ -206,7 +206,7 @@ void RankWorker::thread_loop() {
local_param_name
=
pending_param_name_
;
local_param_name
=
pending_param_name_
;
local_param
=
pending_param_
;
local_param
=
pending_param_
;
}
else
if
(
local_cmd
==
Command
::
RUN
)
{
}
else
if
(
local_cmd
==
Command
::
RUN
)
{
local_args
=
pending_args_
.
to_model_input
();
local_args
=
pending_args_
.
to_model_input
(
rank_info_
.
device
);
}
else
if
(
local_cmd
==
Command
::
RESET_CACHE
)
{
}
else
if
(
local_cmd
==
Command
::
RESET_CACHE
)
{
if
(
pending_cache_config_
!=
nullptr
)
{
if
(
pending_cache_config_
!=
nullptr
)
{
local_cache_config
=
pending_cache_config_
->
unique_copy
();
local_cache_config
=
pending_cache_config_
->
unique_copy
();
...
@@ -254,13 +254,18 @@ void RankWorker::thread_loop() {
...
@@ -254,13 +254,18 @@ void RankWorker::thread_loop() {
auto
random_val
{
pending_args_
.
random_val
};
auto
random_val
{
pending_args_
.
random_val
};
const
auto
&
logits_shape
{
logits
->
shape
()};
const
auto
&
logits_shape
{
logits
->
shape
()};
const
auto
&
batch_size
{
logits_shape
[
0
]};
const
auto
&
vocab_size
{
logits_shape
[
2
]};
const
auto
&
vocab_size
{
logits_shape
[
2
]};
const
auto
&
total_len
{
logits_shape
[
1
]};
const
auto
&
batch_size
{
logits_shape
[
0
]};
auto
n_req
=
pending_args_
.
input_offsets
.
value
()
->
size
(
0
);
int64_t
*
input_lengths
=
(
int64_t
*
)
pending_args_
.
input_lengths
.
value
()
->
data
();
int64_t
*
input_offsets
=
(
int64_t
*
)
pending_args_
.
input_offsets
.
value
()
->
data
();
auto
output_ids
{
infinicore
::
Tensor
::
empty
({
batch_size
},
infinicore
::
DataType
::
I
32
,
rank_info_
.
device
)};
auto
output_ids
{
infinicore
::
Tensor
::
empty
({
n_req
},
infinicore
::
DataType
::
I
64
,
rank_info_
.
device
)};
for
(
auto
i
{
decltype
(
batch_size
)(
0
)};
i
<
batch_size
;
++
i
)
{
for
(
auto
i
{
decltype
(
n_req
)(
0
)};
i
<
n_req
;
++
i
)
{
auto
score
{
logits
->
narrow
({{
0
,
i
,
1
}})
->
view
({
vocab_size
})};
auto
score
{
logits
->
view
({
batch_size
*
total_len
,
vocab_size
})
->
narrow
({{
0
,
size_t
(
input_offsets
[
i
]
+
input_lengths
[
i
]
-
1
)
,
1
}})
->
view
({
vocab_size
})};
auto
out
{
output_ids
->
narrow
({{
0
,
i
,
1
}})
->
view
({})};
auto
out
{
output_ids
->
narrow
({{
0
,
i
,
1
}})
->
view
({})};
infinicore
::
op
::
random_sample_
(
infinicore
::
op
::
random_sample_
(
out
,
score
,
random_val
,
top_p
,
top_k
,
temperature
);
out
,
score
,
random_val
,
top_p
,
top_k
,
temperature
);
...
...
csrc/engine/rank_worker.hpp
View file @
f246c4f1
...
@@ -47,7 +47,7 @@ public:
...
@@ -47,7 +47,7 @@ public:
float
random_val
{
0.1
};
float
random_val
{
0.1
};
infinilm
::
InfinilmModel
::
Input
to_model_input
()
const
;
infinilm
::
InfinilmModel
::
Input
to_model_input
(
infinicore
::
Device
device
)
const
;
};
};
struct
Output
{
struct
Output
{
...
...
csrc/models/llama/llama_attention.cpp
View file @
f246c4f1
#include "llama_attention.hpp"
#include "llama_attention.hpp"
#include "../../utils.hpp"
#include "infinicore/nn/linear.hpp"
#include "infinicore/nn/linear.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/ops.hpp"
#include "infinicore/ops.hpp"
...
@@ -43,6 +44,7 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
...
@@ -43,6 +44,7 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
}
else
{
}
else
{
throw
std
::
runtime_error
(
"num_attention_heads / tp_size error."
);
throw
std
::
runtime_error
(
"num_attention_heads / tp_size error."
);
}
}
scaling_
=
1.0
f
/
std
::
sqrt
(
static_cast
<
float
>
(
head_dim_
));
// Initialize projection layers
// Initialize projection layers
INFINILM_QKV_LINEAR_INIT
(
qkv_proj
,
"q_proj"
,
"k_proj"
,
"v_proj"
,
hidden_size_
,
head_dim_
,
config
.
num_attention_heads
,
config
.
num_key_value_heads
,
use_bias_
,
INFINILM_QKV_LINEAR_INIT
(
qkv_proj
,
"q_proj"
,
"k_proj"
,
"v_proj"
,
hidden_size_
,
head_dim_
,
config
.
num_attention_heads
,
config
.
num_key_value_heads
,
use_bias_
,
...
@@ -52,17 +54,10 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
...
@@ -52,17 +54,10 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
dtype
,
device
,
tp_rank
,
tp_size
,
rank_info
.
comm
);
dtype
,
device
,
tp_rank
,
tp_size
,
rank_info
.
comm
);
}
}
infinicore
::
Tensor
LlamaAttention
::
forward
(
const
infinicore
::
Tensor
&
hidden_states
,
infinicore
::
Tensor
LlamaAttention
::
forward_
(
const
infinicore
::
Tensor
&
hidden_states
,
const
infinicore
::
Tensor
&
position_ids
,
const
infinicore
::
Tensor
&
position_ids
,
std
::
shared_ptr
<
cache
::
Cache
>
kv_cache
,
std
::
shared_ptr
<
infinilm
::
cache
::
Cache
>
kv_cache
,
std
::
optional
<
infinicore
::
Tensor
>
cache_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
cache_lengths
)
const
{
std
::
optional
<
infinicore
::
Tensor
>
input_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
input_offsets
,
std
::
optional
<
infinicore
::
Tensor
>
block_tables
,
std
::
optional
<
infinicore
::
Tensor
>
slot_mapping
)
const
{
if
(
!
rotary_emb_
)
{
throw
std
::
runtime_error
(
"LlamaAttention: rotary_emb not configured"
);
}
// Input shape: [batch, seq_len, hidden_size]
// Input shape: [batch, seq_len, hidden_size]
auto
hidden_states_mutable
=
hidden_states
;
auto
hidden_states_mutable
=
hidden_states
;
auto
shape
=
hidden_states
->
shape
();
auto
shape
=
hidden_states
->
shape
();
...
@@ -73,7 +68,6 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
...
@@ -73,7 +68,6 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
auto
[
q
,
k
,
v
]
=
qkv_proj_
->
forward_split
(
hidden_states_mutable
);
auto
[
q
,
k
,
v
]
=
qkv_proj_
->
forward_split
(
hidden_states_mutable
);
// 2. Reshape for multi-head attention
// 2. Reshape for multi-head attention
// Reshape Q, K, V to include batch dimension
// Reshape Q, K, V to include batch dimension
// Python: query_states = self.q_proj(hidden_states).view(querys_shape)
// Python: query_states = self.q_proj(hidden_states).view(querys_shape)
// The view operation requires the tensor to be contiguous in the required dimensions
// The view operation requires the tensor to be contiguous in the required dimensions
...
@@ -114,13 +108,6 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
...
@@ -114,13 +108,6 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
auto
[
k_total_tmp
,
v_total_tmp
]
=
static_kv_cache
->
update
(
layer_idx_
,
k_permuted
,
v_permuted
,
cache_lengths
.
value
());
auto
[
k_total_tmp
,
v_total_tmp
]
=
static_kv_cache
->
update
(
layer_idx_
,
k_permuted
,
v_permuted
,
cache_lengths
.
value
());
k_total
=
k_total_tmp
;
k_total
=
k_total_tmp
;
v_total
=
v_total_tmp
;
v_total
=
v_total_tmp
;
}
else
if
(
auto
paged_kv_cache
=
std
::
dynamic_pointer_cast
<
cache
::
PagedKVCache
>
(
kv_cache
))
{
auto
[
k_total_tmp
,
v_total_tmp
]
=
paged_kv_cache
->
update
(
layer_idx_
,
k_permuted
,
v_permuted
,
slot_mapping
.
value
());
k_total
=
k_total_tmp
;
v_total
=
v_total_tmp
;
/// @todo Implement paged attention here.
throw
std
::
runtime_error
(
"LlamaAttention: Paged attention not implemented"
);
}
else
{
}
else
{
throw
std
::
runtime_error
(
"LlamaAttention: Unsupported kvcache type"
);
throw
std
::
runtime_error
(
"LlamaAttention: Unsupported kvcache type"
);
}
}
...
@@ -134,8 +121,7 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
...
@@ -134,8 +121,7 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
auto
K_transposed
=
K
->
permute
({
0
,
2
,
1
});
// [bs * n_kv_head, head_dim, total_seq_len]
auto
K_transposed
=
K
->
permute
({
0
,
2
,
1
});
// [bs * n_kv_head, head_dim, total_seq_len]
float
scaling
=
1.0
f
/
std
::
sqrt
(
static_cast
<
float
>
(
head_dim_
));
auto
attn_weight
=
infinicore
::
op
::
matmul
(
Q
,
K_transposed
,
scaling_
);
// [bs * n_kv_head, ng * seq_len, total_seq_len]
auto
attn_weight
=
infinicore
::
op
::
matmul
(
Q
,
K_transposed
,
scaling
);
// [bs * n_kv_head, ng * seq_len, total_seq_len]
auto
attn_weight_softmax
=
attn_weight
->
view
({
batch_size
*
num_attention_heads_
,
seq_len
,
total_seq_len
});
auto
attn_weight_softmax
=
attn_weight
->
view
({
batch_size
*
num_attention_heads_
,
seq_len
,
total_seq_len
});
infinicore
::
op
::
causal_softmax_
(
attn_weight_softmax
,
attn_weight_softmax
);
infinicore
::
op
::
causal_softmax_
(
attn_weight_softmax
,
attn_weight_softmax
);
...
@@ -152,6 +138,119 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
...
@@ -152,6 +138,119 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
return
output
;
return
output
;
}
}
infinicore
::
Tensor
LlamaAttention
::
forward_paged_
(
const
infinicore
::
Tensor
&
hidden_states
,
const
infinicore
::
Tensor
&
position_ids
,
std
::
shared_ptr
<
infinilm
::
cache
::
PagedKVCache
>
paged_kv_cache
,
std
::
optional
<
infinicore
::
Tensor
>
cache_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
input_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
input_offsets
,
std
::
optional
<
infinicore
::
Tensor
>
block_tables
,
std
::
optional
<
infinicore
::
Tensor
>
slot_mapping
)
const
{
ASSERT
(
block_tables
.
has_value
());
ASSERT
(
input_lengths
.
has_value
());
ASSERT
(
slot_mapping
.
has_value
());
// Input shape: [batch, seq_len, hidden_size]
auto
hidden_states_mutable
=
hidden_states
;
auto
shape
=
hidden_states
->
shape
();
size_t
batch_size
=
shape
[
0
];
size_t
seq_len
=
shape
[
1
];
// Only support batchsize==1, all requests should be flattened along seqlen dimension
ASSERT_EQ
(
batch_size
,
1
);
// Decode only if total_len == num_requests
bool
is_prefill
=
(
seq_len
!=
input_lengths
.
value
()
->
shape
()[
0
]);
// 1. Project Q, K, V
auto
[
q
,
k
,
v
]
=
qkv_proj_
->
forward_split
(
hidden_states_mutable
);
// 2. Reshape for multi-head attention
// Reshape Q, K, V to include batch dimension
// Python: query_states = self.q_proj(hidden_states).view(querys_shape)
// The view operation requires the tensor to be contiguous in the required dimensions
auto
q_reshaped
=
q
->
view
({
seq_len
,
num_attention_heads_
,
head_dim_
});
auto
k_reshaped
=
k
->
view
({
seq_len
,
num_key_value_heads_
,
head_dim_
});
auto
v_reshaped
=
v
->
view
({
seq_len
,
num_key_value_heads_
,
head_dim_
});
// 3. Prepare position_ids for RoPE - align with Python pattern
auto
pos_shape
=
position_ids
->
shape
();
infinicore
::
Tensor
pos_ids_for_rope
=
position_ids
;
if
(
pos_shape
.
size
()
==
2
)
{
auto
pos_narrowed
=
position_ids
->
narrow
({{
0
,
0
,
1
}});
pos_ids_for_rope
=
pos_narrowed
->
view
({
pos_shape
[
1
]});
}
else
if
(
pos_shape
.
size
()
==
1
)
{
pos_ids_for_rope
=
position_ids
;
}
else
{
throw
std
::
runtime_error
(
"Unexpected position_ids shape"
);
}
// 4. Apply RoPE to Q and K
rotary_emb_
->
forward
(
q_reshaped
,
pos_ids_for_rope
,
true
);
// [bs, seq_len, n_q_head, head_dim]
rotary_emb_
->
forward
(
k_reshaped
,
pos_ids_for_rope
,
true
);
// [bs, seq_len, n_kv_head, head_dim]
// 5. Prepare KV caches
// Ensure contiguous after permute for F16 compatibility with cache operations
auto
[
k_total
,
v_total
]
=
paged_kv_cache
->
update
(
layer_idx_
,
k_reshaped
,
v_reshaped
,
slot_mapping
.
value
());
// 6. Compute attention
infinicore
::
Tensor
attn_output
=
infinicore
::
Tensor
::
empty
({
seq_len
,
num_attention_heads_
,
head_dim_
},
q_reshaped
->
dtype
(),
q_reshaped
->
device
());
if
(
is_prefill
)
{
infinicore
::
op
::
paged_attention_prefill_
(
attn_output
,
q_reshaped
,
k_total
,
v_total
,
block_tables
.
value
(),
cache_lengths
.
value
(),
input_lengths
.
value
(),
input_offsets
.
value
(),
std
::
nullopt
,
scaling_
);
}
else
{
infinicore
::
op
::
paged_attention_
(
attn_output
,
q_reshaped
,
k_total
,
v_total
,
block_tables
.
value
(),
cache_lengths
.
value
(),
std
::
nullopt
,
scaling_
);
}
// 7. Project output
attn_output
=
attn_output
->
view
({
1
,
seq_len
,
num_attention_heads_
*
head_dim_
});
return
o_proj_
->
forward
(
attn_output
);
}
infinicore
::
Tensor
LlamaAttention
::
forward
(
const
infinicore
::
Tensor
&
hidden_states
,
const
infinicore
::
Tensor
&
position_ids
,
std
::
shared_ptr
<
cache
::
Cache
>
kv_cache
,
std
::
optional
<
infinicore
::
Tensor
>
cache_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
input_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
input_offsets
,
std
::
optional
<
infinicore
::
Tensor
>
block_tables
,
std
::
optional
<
infinicore
::
Tensor
>
slot_mapping
)
const
{
if
(
!
rotary_emb_
)
{
throw
std
::
runtime_error
(
"LlamaAttention: rotary_emb not configured"
);
}
infinicore
::
Tensor
output
;
if
(
auto
paged_kv_cache
=
std
::
dynamic_pointer_cast
<
cache
::
PagedKVCache
>
(
kv_cache
))
{
output
=
forward_paged_
(
hidden_states
,
position_ids
,
paged_kv_cache
,
cache_lengths
,
input_lengths
,
input_offsets
,
block_tables
,
slot_mapping
);
}
else
{
output
=
forward_
(
hidden_states
,
position_ids
,
kv_cache
,
cache_lengths
);
}
return
output
;
}
void
LlamaAttention
::
set_rotary_emb
(
const
std
::
shared_ptr
<
infinicore
::
nn
::
RoPE
>
&
rotary_emb
)
{
void
LlamaAttention
::
set_rotary_emb
(
const
std
::
shared_ptr
<
infinicore
::
nn
::
RoPE
>
&
rotary_emb
)
{
rotary_emb_
=
rotary_emb
;
rotary_emb_
=
rotary_emb
;
}
}
...
...
csrc/models/llama/llama_attention.hpp
View file @
f246c4f1
...
@@ -55,7 +55,7 @@ public:
...
@@ -55,7 +55,7 @@ public:
std
::
optional
<
infinicore
::
Tensor
>
input_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
input_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
input_offsets
,
std
::
optional
<
infinicore
::
Tensor
>
input_offsets
,
std
::
optional
<
infinicore
::
Tensor
>
block_tables
,
std
::
optional
<
infinicore
::
Tensor
>
block_tables
,
std
::
optional
<
infinicore
::
Tensor
>
slot_mappin
)
const
;
std
::
optional
<
infinicore
::
Tensor
>
slot_mappin
g
)
const
;
/**
/**
* @brief Get the layer index
* @brief Get the layer index
...
@@ -73,6 +73,21 @@ public:
...
@@ -73,6 +73,21 @@ public:
size_t
head_dim
()
const
{
return
head_dim_
;
}
size_t
head_dim
()
const
{
return
head_dim_
;
}
size_t
hidden_size
()
const
{
return
hidden_size_
;
}
size_t
hidden_size
()
const
{
return
hidden_size_
;
}
private:
infinicore
::
Tensor
forward_
(
const
infinicore
::
Tensor
&
hidden_states
,
const
infinicore
::
Tensor
&
position_ids
,
std
::
shared_ptr
<
infinilm
::
cache
::
Cache
>
kv_cache
,
std
::
optional
<
infinicore
::
Tensor
>
cache_lengths
)
const
;
infinicore
::
Tensor
forward_paged_
(
const
infinicore
::
Tensor
&
hidden_states
,
const
infinicore
::
Tensor
&
position_ids
,
std
::
shared_ptr
<
infinilm
::
cache
::
PagedKVCache
>
kv_cache
,
std
::
optional
<
infinicore
::
Tensor
>
cache_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
input_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
input_offsets
,
std
::
optional
<
infinicore
::
Tensor
>
block_tables
,
std
::
optional
<
infinicore
::
Tensor
>
slot_mapping
)
const
;
protected:
protected:
// Projection layers
// Projection layers
INFINICORE_NN_MODULE
(
infinilm
::
layers
::
QKVParallelLinear
,
qkv_proj
);
INFINICORE_NN_MODULE
(
infinilm
::
layers
::
QKVParallelLinear
,
qkv_proj
);
...
@@ -93,6 +108,8 @@ private:
...
@@ -93,6 +108,8 @@ private:
bool
use_bias_
;
// Bias for Q/K/V projections
bool
use_bias_
;
// Bias for Q/K/V projections
bool
use_output_bias_
;
// Bias for output projection (o_proj)
bool
use_output_bias_
;
// Bias for output projection (o_proj)
size_t
max_position_embeddings_
;
// For cache initialization (deprecated, kept for compatibility)
size_t
max_position_embeddings_
;
// For cache initialization (deprecated, kept for compatibility)
float
scaling_
;
};
};
}
// namespace infinilm::models::llama
}
// namespace infinilm::models::llama
csrc/models/llama/llama_for_causal_lm.cpp
View file @
f246c4f1
...
@@ -35,8 +35,7 @@ LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
...
@@ -35,8 +35,7 @@ LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
auto
slot_mapping
=
input
.
slot_mapping
;
auto
slot_mapping
=
input
.
slot_mapping
;
// 1. Forward through base model to get hidden states
// 1. Forward through base model to get hidden states
auto
position_ids_device
=
position_ids
->
to
(
device_
);
auto
hidden_states
=
model_
->
forward
(
input_ids
,
position_ids
,
cache_lengths
,
input_lengths
,
input_offsets
,
block_tables
,
slot_mapping
);
auto
hidden_states
=
model_
->
forward
(
input_ids
,
position_ids_device
,
cache_lengths
,
input_lengths
,
input_offsets
,
block_tables
,
slot_mapping
);
// 2. Apply language modeling head to get logits
// 2. Apply language modeling head to get logits
auto
logits
=
lm_head_
->
forward
(
hidden_states
);
auto
logits
=
lm_head_
->
forward
(
hidden_states
);
...
...
csrc/models/llama/llama_model.cpp
View file @
f246c4f1
...
@@ -59,15 +59,7 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
...
@@ -59,15 +59,7 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
hidden_states
=
layers_
.
at
(
i
)
->
forward
(
hidden_states
,
position_ids
,
kv_cache_
,
cache_lengths
,
input_lengths
,
input_offsets
,
block_tables
,
slot_mapping
);
hidden_states
=
layers_
.
at
(
i
)
->
forward
(
hidden_states
,
position_ids
,
kv_cache_
,
cache_lengths
,
input_lengths
,
input_offsets
,
block_tables
,
slot_mapping
);
}
}
// 3. Apply final layer normalization to last token only (aligns with transformers)
return
norm_
->
forward
(
hidden_states
);
// Narrow to last token: [batch, seq_len, hidden_size] -> [batch, 1, hidden_size]
auto
shape
=
hidden_states
->
shape
();
size_t
seq_len
=
shape
[
1
];
auto
last_token
=
hidden_states
->
narrow
({{
1
,
seq_len
-
1
,
1
}});
auto
normalized_last_token
=
norm_
->
forward
(
last_token
);
return
normalized_last_token
;
}
}
void
LlamaModel
::
reset_cache
(
const
cache
::
CacheConfig
*
cache_config
)
{
void
LlamaModel
::
reset_cache
(
const
cache
::
CacheConfig
*
cache_config
)
{
...
...
csrc/pybind11/engine/engine.hpp
View file @
f246c4f1
...
@@ -90,6 +90,8 @@ inline void bind_infer_engine(py::module &m) {
...
@@ -90,6 +90,8 @@ inline void bind_infer_engine(py::module &m) {
std
::
move
(
input_ids
),
std
::
move
(
input_ids
),
std
::
move
(
position_ids
),
std
::
move
(
position_ids
),
std
::
move
(
cache_lengths
),
std
::
move
(
cache_lengths
),
std
::
move
(
input_lengths
),
std
::
move
(
input_offsets
),
std
::
move
(
block_tables
),
std
::
move
(
block_tables
),
std
::
move
(
slot_mapping
)}};
std
::
move
(
slot_mapping
)}};
...
...
examples/jiuge.py
View file @
f246c4f1
...
@@ -9,6 +9,7 @@ import sys
...
@@ -9,6 +9,7 @@ import sys
import
time
import
time
import
os
import
os
import
numpy
as
np
import
numpy
as
np
from
infinilm.cache
import
StaticKVCacheConfig
,
PagedKVCacheConfig
sys
.
path
.
insert
(
0
,
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"../python"
))
sys
.
path
.
insert
(
0
,
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"../python"
))
...
@@ -82,6 +83,18 @@ def get_args():
...
@@ -82,6 +83,18 @@ def get_args():
default
=
1
,
default
=
1
,
help
=
"total rank for tensor parallel"
,
help
=
"total rank for tensor parallel"
,
)
)
parser
.
add_argument
(
"--enable-paged-attn"
,
action
=
"store_true"
,
help
=
"use paged cache"
,
)
parser
.
add_argument
(
"--max-kvcache-size"
,
type
=
int
,
default
=
8
*
1024
*
1024
*
1024
,
help
=
"max size (in bytes) allocated to paged kv cache"
,
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
@@ -92,6 +105,7 @@ def test(
...
@@ -92,6 +105,7 @@ def test(
max_new_tokens
=
100
,
max_new_tokens
=
100
,
infini_device
=
infinicore
.
device
(
"cpu"
,
0
),
infini_device
=
infinicore
.
device
(
"cpu"
,
0
),
tp
=
1
,
tp
=
1
,
enable_paged_attn
=
False
,
):
):
model_path
=
os
.
path
.
expanduser
(
model_path
)
model_path
=
os
.
path
.
expanduser
(
model_path
)
# ---------------------------------------------------------------------------- #
# ---------------------------------------------------------------------------- #
...
@@ -150,11 +164,21 @@ def test(
...
@@ -150,11 +164,21 @@ def test(
"input_ids"
"input_ids"
]
# List: [[1, 1128, 526, 366, 29892]]
]
# List: [[1, 1128, 526, 366, 29892]]
# 根据输入长度和最长输出长度创建KVCache
# ---------------------------------------------------------------------------- #
model
.
reset_cache
(
# 创建KVCache
1
if
prompts
is
str
else
len
(
prompts
),
# ---------------------------------------------------------------------------- #
max_new_tokens
+
len
(
input_ids_list
[
0
]),
if
enable_paged_attn
:
)
cache_config
=
PagedKVCacheConfig
(
max_kv_memory_bytes
=
args
.
max_kvcache_size
,
block_size
=
16
)
else
:
batch_size
=
1
if
prompts
is
str
else
len
(
prompts
)
initial_capacity
=
max_new_tokens
+
len
(
input_ids_list
[
0
])
cache_config
=
StaticKVCacheConfig
(
max_batch_size
=
batch_size
,
max_cache_len
=
initial_capacity
)
model
.
reset_cache
(
cache_config
)
# ---------------------------------------------------------------------------- #
# ---------------------------------------------------------------------------- #
# 自回归生成
# 自回归生成
...
@@ -211,7 +235,7 @@ if __name__ == "__main__":
...
@@ -211,7 +235,7 @@ if __name__ == "__main__":
max_new_tokens
=
args
.
max_new_tokens
max_new_tokens
=
args
.
max_new_tokens
backend
=
args
.
backend
backend
=
args
.
backend
tp
=
args
.
tp
tp
=
args
.
tp
enable_paged_attn
=
args
.
enable_paged_attn
if
backend
!=
"cpp"
:
if
backend
!=
"cpp"
:
raise
ValueError
(
f
"Unsupported backend:
{
backend
}
."
)
raise
ValueError
(
f
"Unsupported backend:
{
backend
}
."
)
...
@@ -223,4 +247,5 @@ if __name__ == "__main__":
...
@@ -223,4 +247,5 @@ if __name__ == "__main__":
max_new_tokens
,
max_new_tokens
,
infini_device
=
infini_device
,
infini_device
=
infini_device
,
tp
=
tp
,
tp
=
tp
,
enable_paged_attn
=
enable_paged_attn
,
)
)
python/infinilm/auto_config.py
View file @
f246c4f1
...
@@ -21,5 +21,7 @@ class AutoConfig:
...
@@ -21,5 +21,7 @@ class AutoConfig:
if
config_dict
[
"model_type"
]
==
"llama"
:
if
config_dict
[
"model_type"
]
==
"llama"
:
return
LlamaConfig
(
**
config_dict
)
return
LlamaConfig
(
**
config_dict
)
elif
config_dict
[
"model_type"
]
==
"qwen2"
:
return
LlamaConfig
(
**
config_dict
)
raise
ValueError
(
f
"Unsupported model type `
{
config_dict
[
'model_type'
]
}
`."
)
raise
ValueError
(
f
"Unsupported model type `
{
config_dict
[
'model_type'
]
}
`."
)
python/infinilm/cache/__init__.py
View file @
f246c4f1
from
.cache
import
CacheConfig
,
StaticKVCacheConfig
from
.cache
import
CacheConfig
,
StaticKVCacheConfig
,
PagedKVCacheConfig
__all__
=
[
"CacheConfig"
,
"StaticKVCacheConfig"
]
__all__
=
[
"CacheConfig"
,
"StaticKVCacheConfig"
,
"PagedKVCacheConfig"
]
python/infinilm/generation/utils.py
View file @
f246c4f1
...
@@ -13,6 +13,8 @@ def infini_to_ctype_dtype(infini_dtype):
...
@@ -13,6 +13,8 @@ def infini_to_ctype_dtype(infini_dtype):
return
ctypes
.
c_int32
return
ctypes
.
c_int32
elif
infini_dtype
==
infinicore
.
float32
:
elif
infini_dtype
==
infinicore
.
float32
:
return
ctypes
.
c_float
return
ctypes
.
c_float
elif
infini_dtype
==
infinicore
.
int64
:
return
ctypes
.
c_int64
else
:
else
:
raise
ValueError
(
f
"Unsupported py_dtype:
{
infini_dtype
}
"
)
raise
ValueError
(
f
"Unsupported py_dtype:
{
infini_dtype
}
"
)
...
...
python/infinilm/infer_engine.py
View file @
f246c4f1
...
@@ -4,7 +4,7 @@ from dataclasses import dataclass
...
@@ -4,7 +4,7 @@ from dataclasses import dataclass
import
infinicore
import
infinicore
from
infinilm.auto_config
import
AutoConfig
from
infinilm.auto_config
import
AutoConfig
from
infinilm.cache
import
StaticKVCacheConfig
from
infinilm.cache
import
StaticKVCacheConfig
,
PagedKVCacheConfig
from
infinilm.distributed
import
DistConfig
from
infinilm.distributed
import
DistConfig
from
infinilm.lib
import
_infinilm
from
infinilm.lib
import
_infinilm
...
@@ -18,6 +18,7 @@ class GenerationConfig:
...
@@ -18,6 +18,7 @@ class GenerationConfig:
top_p
:
float
=
1.0
top_p
:
float
=
1.0
eos_token_id
:
list
[
int
]
|
None
=
None
eos_token_id
:
list
[
int
]
|
None
=
None
stop_on_eos
:
bool
=
True
class
InferEngine
(
_infinilm
.
InferEngine
):
class
InferEngine
(
_infinilm
.
InferEngine
):
...
@@ -42,6 +43,8 @@ class InferEngine(_infinilm.InferEngine):
...
@@ -42,6 +43,8 @@ class InferEngine(_infinilm.InferEngine):
self
.
use_cache
=
False
self
.
use_cache
=
False
self
.
enable_paged_attn
=
isinstance
(
cache_config
,
PagedKVCacheConfig
)
def
__call__
(
self
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
*
args
,
**
kwargs
):
return
self
.
forward
(
*
args
,
**
kwargs
)
return
self
.
forward
(
*
args
,
**
kwargs
)
...
@@ -93,15 +96,11 @@ class InferEngine(_infinilm.InferEngine):
...
@@ -93,15 +96,11 @@ class InferEngine(_infinilm.InferEngine):
else
:
else
:
eos_token_id
=
generation_config
.
eos_token_id
eos_token_id
=
generation_config
.
eos_token_id
# TODO: Remove the `to_numpy` calls and simplify the corresponding code.
past_seq_len
=
0
batch_size
,
seq_len
=
input_ids
.
shape
[:
2
]
position_ids
=
infinicore
.
from_list
(
[
list
(
range
(
0
,
seq_len
))
for
_
in
range
(
batch_size
)],
dtype
=
infinicore
.
int64
)
cache_lengths
=
infinicore
.
from_list
([
0
],
dtype
=
infinicore
.
int64
)
output_ids
=
[]
output_ids
=
[]
initial_batch_size
,
initial_seqlen
=
input_ids
.
shape
[:
2
]
seq_len
=
initial_seqlen
batch_size
=
initial_batch_size
if
batch_size
!=
1
and
generation_config
.
max_new_tokens
is
None
:
if
batch_size
!=
1
and
generation_config
.
max_new_tokens
is
None
:
raise
ValueError
(
raise
ValueError
(
...
@@ -111,14 +110,76 @@ class InferEngine(_infinilm.InferEngine):
...
@@ -111,14 +110,76 @@ class InferEngine(_infinilm.InferEngine):
if
_measure_and_log_time
:
if
_measure_and_log_time
:
time_measurements
=
[]
time_measurements
=
[]
for
_
in
range
(
0
,
generation_config
.
max_new_tokens
):
for
iter
in
range
(
0
,
generation_config
.
max_new_tokens
):
if
_measure_and_log_time
:
if
_measure_and_log_time
:
start_time
=
time
.
perf_counter
()
start_time
=
time
.
perf_counter
()
batch_size
,
seq_len
=
input_ids
.
shape
[:
2
]
if
self
.
enable_paged_attn
:
input_ids
=
input_ids
.
view
([
1
,
batch_size
*
seq_len
])
position_ids
=
infinicore
.
from_list
(
list
(
range
(
past_seq_len
,
past_seq_len
+
seq_len
))
*
batch_size
,
dtype
=
infinicore
.
int64
,
)
cache_lengths
=
infinicore
.
from_list
(
[
past_seq_len
]
*
batch_size
,
dtype
=
infinicore
.
int64
)
input_lengths
=
infinicore
.
from_list
(
[
seq_len
]
*
batch_size
,
dtype
=
infinicore
.
int64
)
input_offsets
=
infinicore
.
from_list
(
[
seq_len
*
i
for
i
in
range
(
batch_size
)],
dtype
=
infinicore
.
int64
)
block_tables
=
infinicore
.
from_list
(
[
[
i
*
batch_size
+
b
for
i
in
range
((
past_seq_len
+
seq_len
+
15
)
//
16
)
]
for
b
in
range
(
batch_size
)
],
dtype
=
infinicore
.
int64
,
)
slot_mapping
=
infinicore
.
from_list
(
[
((
past_seq_len
+
i
+
15
)
//
16
)
*
batch_size
+
b
+
(
past_seq_len
+
i
+
15
)
%
16
for
i
in
range
(
seq_len
)
for
b
in
range
(
batch_size
)
],
dtype
=
infinicore
.
int64
,
)
else
:
position_ids
=
infinicore
.
from_list
(
[
list
(
range
(
past_seq_len
,
past_seq_len
+
seq_len
))
for
_
in
range
(
batch_size
)
],
dtype
=
infinicore
.
int64
,
)
cache_lengths
=
infinicore
.
from_list
(
[
past_seq_len
],
dtype
=
infinicore
.
int64
)
input_lengths
=
infinicore
.
from_list
(
[
seq_len
]
*
batch_size
,
dtype
=
infinicore
.
int64
)
input_offsets
=
infinicore
.
from_list
(
[
seq_len
*
i
for
i
in
range
(
batch_size
)],
dtype
=
infinicore
.
int64
)
block_tables
=
None
slot_mapping
=
None
output_id
=
self
(
output_id
=
self
(
input_ids
,
input_ids
=
input_ids
,
position_ids
=
position_ids
,
position_ids
=
position_ids
,
cache_lengths
=
cache_lengths
,
cache_lengths
=
cache_lengths
,
input_lengths
=
input_lengths
,
input_offsets
=
input_offsets
,
block_tables
=
block_tables
,
slot_mapping
=
slot_mapping
,
temperature
=
generation_config
.
temperature
,
temperature
=
generation_config
.
temperature
,
top_k
=
generation_config
.
top_k
,
top_k
=
generation_config
.
top_k
,
top_p
=
generation_config
.
top_p
,
top_p
=
generation_config
.
top_p
,
...
@@ -127,24 +188,16 @@ class InferEngine(_infinilm.InferEngine):
...
@@ -127,24 +188,16 @@ class InferEngine(_infinilm.InferEngine):
output_ids
.
append
(
output_id
)
output_ids
.
append
(
output_id
)
if
(
if
(
generation_config
.
max_new_tokens
is
not
None
generation_config
.
stop_on_eos
and
generation_config
.
max_new_tokens
is
not
None
and
output_id
.
to_numpy
()[
0
]
in
eos_token_id
and
output_id
.
to_numpy
()[
0
]
in
eos_token_id
):
):
break
break
seq_len
=
position_ids
.
shape
[
-
1
]
input_ids
=
infinicore
.
from_list
(
input_ids
=
infinicore
.
from_list
(
[[
output_id
]
for
output_id
in
output_id
.
to_numpy
().
tolist
()]
[[
output_id
]
for
output_id
in
output_id
.
to_numpy
().
tolist
()]
)
)
position_ids
=
infinicore
.
from_list
(
past_seq_len
=
past_seq_len
+
seq_len
[
1
for
_
in
range
(
batch_size
)],
dtype
=
position_ids
.
dtype
,
device
=
position_ids
.
device
,
).
view
((
batch_size
,
1
))
+
position_ids
.
narrow
(
1
,
seq_len
-
1
,
1
)
cache_lengths
+=
infinicore
.
from_list
(
[
seq_len
],
dtype
=
cache_lengths
.
dtype
,
device
=
cache_lengths
.
device
)
if
_measure_and_log_time
:
if
_measure_and_log_time
:
end_time
=
time
.
perf_counter
()
end_time
=
time
.
perf_counter
()
...
@@ -156,23 +209,21 @@ class InferEngine(_infinilm.InferEngine):
...
@@ -156,23 +209,21 @@ class InferEngine(_infinilm.InferEngine):
f
"
\n\n\n
Generation completed in
{
round
(
sum
(
time_measurements
)
*
1000
,
2
)
}
ms"
f
"
\n\n\n
Generation completed in
{
round
(
sum
(
time_measurements
)
*
1000
,
2
)
}
ms"
)
)
print
(
print
(
f
" Batchsize=
{
batch_size
}
Per_Batch_Input_Len=
{
seq
_
len
}
Per_Batch_New_Tokens=
{
len
(
time_measurements
)
}
\n
"
f
" Batchsize=
{
initial_
batch_size
}
Per_Batch_Input_Len=
{
initial_
seqlen
}
Per_Batch_New_Tokens=
{
len
(
time_measurements
)
}
\n
"
)
)
print
(
print
(
f
" Prefill TTFT:
{
round
(
time_measurements
[
0
],
2
)
}
ms Throughput:
{
round
((
batch_size
*
seq
_
len
)
/
time_measurements
[
0
],
2
)
}
tok/s
\n
"
,
f
" Prefill TTFT:
{
round
(
time_measurements
[
0
],
2
)
}
ms Throughput:
{
round
((
initial_
batch_size
*
initial_
seqlen
)
/
time_measurements
[
0
],
2
)
}
tok/s
\n
"
,
)
)
if
len
(
time_measurements
)
>
1
:
if
len
(
time_measurements
)
>
1
:
print
(
print
(
f
" Decode Avg ITL:
{
round
(
sum
(
time_measurements
[
1
:])
*
1000
/
(
len
(
time_measurements
)
-
1
),
2
)
}
ms Throughput:
{
round
((
batch_size
*
(
len
(
time_measurements
)
-
1
))
/
sum
(
time_measurements
[
1
:]),
2
)
}
tok/s
\n
"
,
f
" Decode Avg ITL:
{
round
(
sum
(
time_measurements
[
1
:])
*
1000
/
(
len
(
time_measurements
)
-
1
),
2
)
}
ms Throughput:
{
round
((
initial_
batch_size
*
(
len
(
time_measurements
)
-
1
))
/
sum
(
time_measurements
[
1
:]),
2
)
}
tok/s
\n
"
,
)
)
return
output_ids
return
output_ids
def reset_cache(self, cache_config):
    """Reset the engine's KV cache from a cache configuration object.

    Replaces the previous ``reset_cache(batch_size, initial_capacity)``
    signature: the caller now constructs the cache config (e.g. a
    ``StaticKVCacheConfig`` or ``PagedKVCacheConfig``) and passes it in
    directly, so a single entry point supports both cache layouts.

    Parameters
    ----------
    cache_config : object
        KV-cache configuration. When it is a ``PagedKVCacheConfig``
        instance, paged attention is enabled for subsequent ``generate``
        calls (which then build block tables / slot mappings); any other
        config type disables the paged path.
    """
    # Drain all in-flight device work before the old cache is torn down,
    # so no kernel still references the buffers being replaced.
    infinicore.sync_device()
    # generate() branches on this flag to decide between the paged-attention
    # input layout (flattened ids + block_tables/slot_mapping) and the
    # static per-batch layout.
    # NOTE(review): assumes PagedKVCacheConfig is imported at module scope
    # in this file — confirm against the file's import block.
    self.enable_paged_attn = isinstance(cache_config, PagedKVCacheConfig)
    # Delegate the actual cache (re)allocation to the C++ engine base class.
    super().reset_cache(cache_config)
def
state_dict_keyname
(
self
):
def
state_dict_keyname
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment