Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
c1a3ab29
Unverified
Commit
c1a3ab29
authored
Jan 14, 2026
by
Haojie Wang
Committed by
GitHub
Jan 14, 2026
Browse files
Merge pull request #173 from InfiniTensor/issue/168
Issue/168 InfiniLM接入paged attention接口
parents
96e53dbb
09ab8fa4
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
406 additions
and
134 deletions
+406
-134
csrc/cache/kv_cache.cpp
csrc/cache/kv_cache.cpp
+78
-17
csrc/cache/kv_cache.hpp
csrc/cache/kv_cache.hpp
+40
-6
csrc/engine/infer_engine.cpp
csrc/engine/infer_engine.cpp
+17
-2
csrc/engine/rank_worker.cpp
csrc/engine/rank_worker.cpp
+18
-13
csrc/engine/rank_worker.hpp
csrc/engine/rank_worker.hpp
+4
-4
csrc/models/infinilm_model.hpp
csrc/models/infinilm_model.hpp
+4
-4
csrc/models/llama/llama_attention.cpp
csrc/models/llama/llama_attention.cpp
+119
-22
csrc/models/llama/llama_attention.hpp
csrc/models/llama/llama_attention.hpp
+20
-3
csrc/models/llama/llama_decoder_layer.cpp
csrc/models/llama/llama_decoder_layer.cpp
+3
-3
csrc/models/llama/llama_decoder_layer.hpp
csrc/models/llama/llama_decoder_layer.hpp
+2
-2
csrc/models/llama/llama_for_causal_lm.cpp
csrc/models/llama/llama_for_causal_lm.cpp
+4
-4
csrc/models/llama/llama_model.cpp
csrc/models/llama/llama_model.cpp
+4
-12
csrc/models/llama/llama_model.hpp
csrc/models/llama/llama_model.hpp
+5
-5
csrc/pybind11/cache/cache.hpp
csrc/pybind11/cache/cache.hpp
+3
-3
csrc/pybind11/engine/engine.hpp
csrc/pybind11/engine/engine.hpp
+41
-18
examples/bench.py
examples/bench.py
+7
-1
examples/jiuge.py
examples/jiuge.py
+31
-11
python/infinilm/auto_config.py
python/infinilm/auto_config.py
+2
-0
python/infinilm/cache/__init__.py
python/infinilm/cache/__init__.py
+2
-2
python/infinilm/cache/cache.py
python/infinilm/cache/cache.py
+2
-2
No files found.
csrc/cache/kv_cache.cpp
View file @
c1a3ab29
#include "kv_cache.hpp"
#include "../utils.hpp"
#include "infinicore/ops.hpp"
#include <stdexcept>
namespace
infinilm
::
cache
{
...
...
@@ -80,12 +80,12 @@ std::tuple<infinicore::Tensor, infinicore::Tensor>
StaticKVCache
::
update
(
size_t
layer_idx
,
const
infinicore
::
Tensor
&
k
,
const
infinicore
::
Tensor
&
v
,
const
infinicore
::
Tensor
&
cach
e_lengths
)
{
const
infinicore
::
Tensor
&
past_sequenc
e_lengths
)
{
ASSERT
(
layer_idx
<
rank_num_layers_
);
auto
batch_size
=
k
->
size
(
0
);
auto
update_len
=
k
->
size
(
2
);
size_t
cache_pos
=
reinterpret_cast
<
int64_t
*>
(
cach
e_lengths
->
to
(
infinicore
::
Device
::
cpu
())
->
data
())[
0
];
size_t
cache_pos
=
reinterpret_cast
<
int64_t
*>
(
past_sequenc
e_lengths
->
to
(
infinicore
::
Device
::
cpu
())
->
data
())[
0
];
auto
result_len
=
cache_pos
+
update_len
;
ASSERT
(
result_len
<=
cache_len_
);
...
...
@@ -111,9 +111,9 @@ StaticKVCache::update(size_t layer_idx,
// PagedKVCacheConfig
// ==========================
PagedKVCacheConfig
::
PagedKVCacheConfig
(
size_t
max_kv_memory_byte
s
,
size_t
num_block
s
,
size_t
block_size
)
:
max_kv_memory_bytes_
(
max_kv_memory_byte
s
),
:
num_blocks_
(
num_block
s
),
block_size_
(
block_size
)
{
}
...
...
@@ -123,8 +123,8 @@ PagedKVCacheConfig::unique_copy() const {
}
size_t
PagedKVCacheConfig
::
max_kv_memory_byte
s
()
const
{
return
max_kv_memory_byte
s_
;
PagedKVCacheConfig
::
num_block
s
()
const
{
return
num_block
s_
;
}
size_t
...
...
@@ -151,15 +151,8 @@ PagedKVCache::PagedKVCache(
num_rank_v_heads_
(
num_v_heads
/
rank_info
.
tp_size
),
rank_num_layers_
(
num_layers
),
dtype_
(
dtype
),
num_blocks_per_layer_
(
config
.
num_blocks
()),
block_size_
(
config
.
block_size
())
{
num_blocks_per_layer_
=
config
.
max_kv_memory_bytes
()
/
(
k_dim
*
num_rank_k_heads_
+
v_dim
*
num_rank_v_heads_
)
/
block_size_
/
infinicore
::
dsize
(
dtype_
);
if
(
num_blocks_per_layer_
==
0
)
{
throw
std
::
runtime_error
(
"Not enough memory for KV cache"
);
}
// [num_layers, num_blocks, num_rank_k_heads, block_size, k_dim]
k_caches_
=
infinicore
::
Tensor
::
empty
(
{
rank_num_layers_
,
...
...
@@ -187,11 +180,79 @@ std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::update(
const
infinicore
::
Tensor
&
v
,
const
infinicore
::
Tensor
&
slot_mapping
)
{
auto
&&
[
k_cache_layer
,
v_cache_layer
]
=
get_paged_kv
(
layer_idx
);
infinicore
::
op
::
paged_caching_
(
k_cache_layer
,
v_cache_layer
,
k
,
v
,
slot_mapping
);
return
{
k_cache_layer
,
v_cache_layer
};
}
std
::
tuple
<
infinicore
::
Tensor
,
infinicore
::
Tensor
>
PagedKVCache
::
get_paged_kv
(
size_t
layer_idx
)
{
auto
k_cache_layer
=
k_caches_
->
narrow
({{
0
,
layer_idx
,
1
}})
->
squeeze
(
0
);
auto
v_cache_layer
=
v_caches_
->
narrow
({{
0
,
layer_idx
,
1
}})
->
squeeze
(
0
);
return
{
k_cache_layer
,
v_cache_layer
};
}
/// @todo: implement paged cache update here
std
::
tuple
<
infinicore
::
Tensor
,
infinicore
::
Tensor
>
PagedKVCache
::
get_contiguous_kv
(
size_t
layer_idx
,
const
infinicore
::
Tensor
block_tables
,
const
infinicore
::
Tensor
cache_lens
,
const
infinicore
::
Tensor
input_offsets
,
size_t
request_id
)
{
ASSERT_EQ
(
block_tables
->
dtype
(),
infinicore
::
DataType
::
I64
);
ASSERT_EQ
(
cache_lens
->
dtype
(),
infinicore
::
DataType
::
I64
);
ASSERT_EQ
(
input_offsets
->
dtype
(),
infinicore
::
DataType
::
I64
);
return
{
k_cache_layer
,
v_cache_layer
};
auto
nreq
=
block_tables
->
size
(
0
);
auto
block_tables_cpu
=
block_tables
->
to
(
infinicore
::
Device
::
cpu
());
auto
cache_lens_cpu
=
cache_lens
->
to
(
infinicore
::
Device
::
cpu
());
auto
input_offsets_cpu
=
input_offsets
->
to
(
infinicore
::
Device
::
cpu
());
infinicore
::
context
::
syncDevice
();
// [num_blocks, num_rank_v_heads, block_size, v_dim]
auto
&&
[
k_cache_layer
,
v_cache_layer
]
=
get_paged_kv
(
layer_idx
);
auto
req
=
request_id
;
auto
cache_lens_ptr
=
reinterpret_cast
<
const
int64_t
*>
(
cache_lens_cpu
->
data
());
auto
input_offsets_ptr
=
reinterpret_cast
<
const
int64_t
*>
(
input_offsets_cpu
->
data
());
int64_t
total_len
=
cache_lens_ptr
[
req
]
+
(
input_offsets_ptr
[
req
+
1
]
-
input_offsets_ptr
[
req
]);
auto
full_k
=
infinicore
::
Tensor
::
empty
(
{
num_rank_k_heads_
,
(
size_t
)
total_len
,
k_dim_
},
k_cache_layer
->
dtype
(),
k_cache_layer
->
device
());
auto
full_v
=
infinicore
::
Tensor
::
empty
(
{
num_rank_v_heads_
,
(
size_t
)
total_len
,
v_dim_
},
v_cache_layer
->
dtype
(),
v_cache_layer
->
device
());
size_t
nblocks
=
total_len
/
block_size_
;
size_t
r
=
total_len
%
block_size_
;
for
(
size_t
b
=
0
;
b
<
nblocks
;
b
++
)
{
size_t
bid
=
*
((
int64_t
*
)(
block_tables_cpu
->
narrow
({{
0
,
req
,
1
},
{
1
,
b
,
1
}})
->
data
()));
full_k
->
narrow
({{
1
,
b
*
block_size_
,
block_size_
}})
->
copy_from
(
k_cache_layer
->
narrow
({{
0
,
bid
,
1
}})
->
squeeze
(
0
));
full_v
->
narrow
({{
1
,
b
*
block_size_
,
block_size_
}})
->
copy_from
(
v_cache_layer
->
narrow
({{
0
,
bid
,
1
}})
->
squeeze
(
0
));
}
if
(
r
>
0
)
{
size_t
bid
=
*
((
int64_t
*
)(
block_tables_cpu
->
narrow
({{
0
,
req
,
1
},
{
1
,
nblocks
,
1
}})
->
data
()));
full_k
->
narrow
({{
1
,
nblocks
*
block_size_
,
r
}})
->
copy_from
(
k_cache_layer
->
narrow
({{
0
,
bid
,
1
}})
->
squeeze
(
0
)
->
narrow
({{
1
,
0
,
r
}}));
full_v
->
narrow
({{
1
,
nblocks
*
block_size_
,
r
}})
->
copy_from
(
v_cache_layer
->
narrow
({{
0
,
bid
,
1
}})
->
squeeze
(
0
)
->
narrow
({{
1
,
0
,
r
}}));
}
return
{
full_k
,
full_v
};
}
}
// namespace infinilm::cache
csrc/cache/kv_cache.hpp
View file @
c1a3ab29
...
...
@@ -61,7 +61,7 @@ public:
update
(
size_t
layer_idx
,
const
infinicore
::
Tensor
&
k
,
const
infinicore
::
Tensor
&
v
,
const
infinicore
::
Tensor
&
cach
e_lengths
);
const
infinicore
::
Tensor
&
past_sequenc
e_lengths
);
~
StaticKVCache
()
override
=
default
;
...
...
@@ -85,15 +85,15 @@ private:
class
PagedKVCacheConfig
final
:
public
CacheConfig
{
public:
PagedKVCacheConfig
(
size_t
max_kv_memory_byte
s
,
size_t
num_block
s
,
size_t
block_size
=
16
);
std
::
unique_ptr
<
CacheConfig
>
unique_copy
()
const
override
;
size_t
max_kv_memory_byte
s
()
const
;
size_t
num_block
s
()
const
;
size_t
block_size
()
const
;
private:
size_t
max_kv_memory_byte
s_
;
size_t
num_block
s_
;
size_t
block_size_
;
};
...
...
@@ -113,7 +113,7 @@ public:
/**
* @brief Update Paged KV cache at a given layer given slot info for each token.
*
* @param layer_idx Which
transformer
layer
* @param layer_idx Which
paged attention
layer
* @param k [num_rank_k_heads, seq_len, k_dim]
* @param v [num_rank_v_heads, seq_len, v_dim]
* @param slot_mapping [seq_len]
...
...
@@ -128,7 +128,41 @@ public:
const
infinicore
::
Tensor
&
v
,
const
infinicore
::
Tensor
&
slot_mapping
);
~
PagedKVCache
()
override
=
default
;
/**
* @brief Get Paged KV cache at a given layer.
*
* @param layer_idx Which paged attention layer
*
* @return (full_k, full_v)
* full_k: [num_blocks, num_rank_k_heads, block_size, k_dim]
* full_v: [num_blocks, num_rank_v_heads, block_size, v_dim]
*/
std
::
tuple
<
infinicore
::
Tensor
,
infinicore
::
Tensor
>
get_paged_kv
(
size_t
layer_idx
);
/**
* @brief Get contiguous KV cache at a given layer, given the request info
* among a continuous request batch.
*
* @param layer_idx Which paged attention layer
* @param block_tables [num_requests, max_blocks_per_request]
* @param cache_lens [num_requests]
* @param input_offsets [num_requests + 1]
* @param request_id Which request among a continuous batch of requests
*
* @return (full_k, full_v)
* full_k: [num_rank_k_heads, total_len, k_dim]
* full_v: [num_rank_v_heads, total_len, v_dim]
*/
std
::
tuple
<
infinicore
::
Tensor
,
infinicore
::
Tensor
>
get_contiguous_kv
(
size_t
layer_idx
,
const
infinicore
::
Tensor
block_tables
,
const
infinicore
::
Tensor
cache_lens
,
const
infinicore
::
Tensor
input_offsets
,
size_t
request_id
=
0
);
~
PagedKVCache
()
override
=
default
;
private:
infinicore
::
Size
k_dim_
;
...
...
csrc/engine/infer_engine.cpp
View file @
c1a3ab29
...
...
@@ -56,8 +56,23 @@ std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> InferEng
//------------------------------------------------------
// forward
//------------------------------------------------------
infinilm
::
InfinilmModel
::
Input
InferEngine
::
Input
::
to_model_input
()
const
{
return
{
input_ids
,
position_ids
,
cache_lengths
,
input_lengths
,
input_offsets
,
block_tables
,
slot_mapping
};
infinilm
::
InfinilmModel
::
Input
InferEngine
::
Input
::
to_model_input
(
infinicore
::
Device
device
)
const
{
auto
to_device
=
[
&
](
const
std
::
optional
<
infinicore
::
Tensor
>
&
t
)
->
std
::
optional
<
infinicore
::
Tensor
>
{
return
t
.
has_value
()
?
t
.
value
()
->
to
(
device
)
:
t
;
};
return
{
input_ids
,
// @todo: on device in the future
to_device
(
position_ids
),
past_sequence_lengths
,
// @todo: on device in the future
to_device
(
total_sequence_lengths
),
to_device
(
input_offsets
),
to_device
(
block_tables
),
to_device
(
slot_mapping
),
};
}
InferEngine
::
Output
InferEngine
::
forward
(
const
InferEngine
::
Input
&
input
)
{
...
...
csrc/engine/rank_worker.cpp
View file @
c1a3ab29
...
...
@@ -188,7 +188,7 @@ void RankWorker::thread_loop() {
Command
local_cmd
=
Command
::
INIT
;
std
::
string
local_param_name
;
infinicore
::
Tensor
local_param
;
InfinilmModel
::
Input
local_args
;
Input
local_args
;
std
::
unique_ptr
<
cache
::
CacheConfig
>
local_cache_config
;
// Wait for a job or exit
...
...
@@ -206,7 +206,7 @@ void RankWorker::thread_loop() {
local_param_name
=
pending_param_name_
;
local_param
=
pending_param_
;
}
else
if
(
local_cmd
==
Command
::
RUN
)
{
local_args
=
pending_args_
.
to_model_input
()
;
local_args
=
pending_args_
;
}
else
if
(
local_cmd
==
Command
::
RESET_CACHE
)
{
if
(
pending_cache_config_
!=
nullptr
)
{
local_cache_config
=
pending_cache_config_
->
unique_copy
();
...
...
@@ -244,23 +244,28 @@ void RankWorker::thread_loop() {
{
std
::
lock_guard
<
std
::
mutex
>
lk
(
mutex_
);
auto
logits
{
model_
->
forward
(
local_args
).
logits
};
auto
model_args
=
local_args
.
to_model_input
(
rank_info_
.
device
);
// Forward calculation
auto
logits
{
model_
->
forward
(
model_args
).
logits
};
// Random sampling (rank 0 only)
if
(
rank_info_
.
tp_rank
==
0
)
{
// Perform random sampling.
auto
temperature
{
pending_args_
.
temperature
};
auto
top_p
{
pending_args_
.
top_p
};
auto
top_k
{
pending_args_
.
top_k
};
auto
random_val
{
pending_args_
.
random_val
};
auto
temperature
{
local_args
.
temperature
};
auto
top_p
{
local_args
.
top_p
};
auto
top_k
{
local_args
.
top_k
};
auto
random_val
{
local_args
.
random_val
};
const
auto
&
logits_shape
{
logits
->
shape
()};
const
auto
&
batch_size
{
logits_shape
[
0
]};
const
auto
&
vocab_size
{
logits_shape
[
2
]};
const
auto
&
total_len
{
logits_shape
[
1
]};
const
auto
&
batch_size
{
logits_shape
[
0
]};
auto
n_req
=
local_args
.
input_offsets
.
value
()
->
size
(
0
)
-
1
;
int64_t
*
input_offsets
=
(
int64_t
*
)
local_args
.
input_offsets
.
value
()
->
data
();
auto
output_ids
{
infinicore
::
Tensor
::
empty
({
batch_size
},
infinicore
::
DataType
::
I
32
,
rank_info_
.
device
)};
auto
output_ids
{
infinicore
::
Tensor
::
empty
({
n_req
},
infinicore
::
DataType
::
I
64
,
rank_info_
.
device
)};
for
(
auto
i
{
decltype
(
batch_size
)(
0
)};
i
<
batch_size
;
++
i
)
{
auto
score
{
logits
->
narrow
({{
0
,
i
,
1
}})
->
view
({
vocab_size
})};
for
(
auto
i
{
decltype
(
n_req
)(
0
)};
i
<
n_req
;
++
i
)
{
auto
score
{
logits
->
view
({
batch_size
*
total_len
,
vocab_size
})
->
narrow
({{
0
,
size_t
(
input_offsets
[
i
+
1
]
-
1
)
,
1
}})
->
view
({
vocab_size
})};
auto
out
{
output_ids
->
narrow
({{
0
,
i
,
1
}})
->
view
({})};
infinicore
::
op
::
random_sample_
(
out
,
score
,
random_val
,
top_p
,
top_k
,
temperature
);
...
...
csrc/engine/rank_worker.hpp
View file @
c1a3ab29
...
...
@@ -29,9 +29,9 @@ public:
/// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`.
std
::
optional
<
infinicore
::
Tensor
>
position_ids
;
/// Past Lengths of cached sequence for each request, of shape `[num_requests]`.
std
::
optional
<
infinicore
::
Tensor
>
cach
e_lengths
;
///
Input
Lengths
o
f each request
in a continous-batched
sequence, of shape `[num_requests]`.
std
::
optional
<
infinicore
::
Tensor
>
input
_lengths
;
std
::
optional
<
infinicore
::
Tensor
>
past_sequenc
e_lengths
;
///
ToTal
Lengths f
or
each request sequence, of shape `[num_requests]`.
std
::
optional
<
infinicore
::
Tensor
>
total_sequence
_lengths
;
/// Offsets of each request in a continous-batched sequence, of shape `[num_requests]`.
std
::
optional
<
infinicore
::
Tensor
>
input_offsets
;
/// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache.
...
...
@@ -47,7 +47,7 @@ public:
float
random_val
{
0.1
};
infinilm
::
InfinilmModel
::
Input
to_model_input
()
const
;
infinilm
::
InfinilmModel
::
Input
to_model_input
(
infinicore
::
Device
device
)
const
;
};
struct
Output
{
...
...
csrc/models/infinilm_model.hpp
View file @
c1a3ab29
...
...
@@ -23,10 +23,10 @@ public:
/// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`.
std
::
optional
<
infinicore
::
Tensor
>
position_ids
;
/// Past Lengths of cached sequence for each request, of shape `[num_requests]`.
std
::
optional
<
infinicore
::
Tensor
>
cach
e_lengths
;
///
Input
Lengths
o
f each request
in a continous-batched
sequence, of shape `[num_requests]`.
std
::
optional
<
infinicore
::
Tensor
>
input
_lengths
;
/// Offsets of each request in a continous-batched sequence, of shape `[num_requests]`.
std
::
optional
<
infinicore
::
Tensor
>
past_sequenc
e_lengths
;
///
ToTal
Lengths f
or
each request sequence, of shape `[num_requests]`.
std
::
optional
<
infinicore
::
Tensor
>
total_sequence
_lengths
;
/// Offsets of each request in a continous-batched sequence, of shape `[num_requests
+ 1
]`.
std
::
optional
<
infinicore
::
Tensor
>
input_offsets
;
/// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache.
std
::
optional
<
infinicore
::
Tensor
>
block_tables
;
...
...
csrc/models/llama/llama_attention.cpp
View file @
c1a3ab29
#include "llama_attention.hpp"
#include "../../utils.hpp"
#include "infinicore/nn/linear.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/ops.hpp"
...
...
@@ -43,6 +44,7 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
}
else
{
throw
std
::
runtime_error
(
"num_attention_heads / tp_size error."
);
}
scaling_
=
1.0
f
/
std
::
sqrt
(
static_cast
<
float
>
(
head_dim_
));
// Initialize projection layers
INFINILM_QKV_LINEAR_INIT
(
qkv_proj
,
"q_proj"
,
"k_proj"
,
"v_proj"
,
hidden_size_
,
head_dim_
,
config
.
num_attention_heads
,
config
.
num_key_value_heads
,
use_bias_
,
...
...
@@ -52,17 +54,11 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
dtype
,
device
,
tp_rank
,
tp_size
,
rank_info
.
comm
);
}
infinicore
::
Tensor
LlamaAttention
::
forward
(
const
infinicore
::
Tensor
&
hidden_states
,
const
infinicore
::
Tensor
&
position_ids
,
std
::
shared_ptr
<
cache
::
Cache
>
kv_cache
,
std
::
optional
<
infinicore
::
Tensor
>
cache_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
input_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
input_offsets
,
std
::
optional
<
infinicore
::
Tensor
>
block_tables
,
std
::
optional
<
infinicore
::
Tensor
>
slot_mapping
)
const
{
if
(
!
rotary_emb_
)
{
throw
std
::
runtime_error
(
"LlamaAttention: rotary_emb not configured"
);
}
infinicore
::
Tensor
LlamaAttention
::
forward_
(
const
infinicore
::
Tensor
&
hidden_states
,
const
infinicore
::
Tensor
&
position_ids
,
std
::
shared_ptr
<
infinilm
::
cache
::
Cache
>
kv_cache
,
std
::
optional
<
infinicore
::
Tensor
>
past_sequence_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
total_sequence_lengths
)
const
{
// Input shape: [batch, seq_len, hidden_size]
auto
hidden_states_mutable
=
hidden_states
;
auto
shape
=
hidden_states
->
shape
();
...
...
@@ -73,7 +69,6 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
auto
[
q
,
k
,
v
]
=
qkv_proj_
->
forward_split
(
hidden_states_mutable
);
// 2. Reshape for multi-head attention
// Reshape Q, K, V to include batch dimension
// Python: query_states = self.q_proj(hidden_states).view(querys_shape)
// The view operation requires the tensor to be contiguous in the required dimensions
...
...
@@ -111,16 +106,9 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
k_total
=
k_permuted
;
v_total
=
v_permuted
;
}
else
if
(
auto
static_kv_cache
=
std
::
dynamic_pointer_cast
<
cache
::
StaticKVCache
>
(
kv_cache
))
{
auto
[
k_total_tmp
,
v_total_tmp
]
=
static_kv_cache
->
update
(
layer_idx_
,
k_permuted
,
v_permuted
,
cache_lengths
.
value
());
k_total
=
k_total_tmp
;
v_total
=
v_total_tmp
;
}
else
if
(
auto
paged_kv_cache
=
std
::
dynamic_pointer_cast
<
cache
::
PagedKVCache
>
(
kv_cache
))
{
auto
[
k_total_tmp
,
v_total_tmp
]
=
paged_kv_cache
->
update
(
layer_idx_
,
k_permuted
,
v_permuted
,
slot_mapping
.
value
());
auto
[
k_total_tmp
,
v_total_tmp
]
=
static_kv_cache
->
update
(
layer_idx_
,
k_permuted
,
v_permuted
,
past_sequence_lengths
.
value
());
k_total
=
k_total_tmp
;
v_total
=
v_total_tmp
;
/// @todo Implement paged attention here.
throw
std
::
runtime_error
(
"LlamaAttention: Paged attention not implemented"
);
}
else
{
throw
std
::
runtime_error
(
"LlamaAttention: Unsupported kvcache type"
);
}
...
...
@@ -134,8 +122,7 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
auto
K_transposed
=
K
->
permute
({
0
,
2
,
1
});
// [bs * n_kv_head, head_dim, total_seq_len]
float
scaling
=
1.0
f
/
std
::
sqrt
(
static_cast
<
float
>
(
head_dim_
));
auto
attn_weight
=
infinicore
::
op
::
matmul
(
Q
,
K_transposed
,
scaling
);
// [bs * n_kv_head, ng * seq_len, total_seq_len]
auto
attn_weight
=
infinicore
::
op
::
matmul
(
Q
,
K_transposed
,
scaling_
);
// [bs * n_kv_head, ng * seq_len, total_seq_len]
auto
attn_weight_softmax
=
attn_weight
->
view
({
batch_size
*
num_attention_heads_
,
seq_len
,
total_seq_len
});
infinicore
::
op
::
causal_softmax_
(
attn_weight_softmax
,
attn_weight_softmax
);
...
...
@@ -152,6 +139,116 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
return
output
;
}
infinicore
::
Tensor
LlamaAttention
::
forward_paged_
(
const
infinicore
::
Tensor
&
hidden_states
,
const
infinicore
::
Tensor
&
position_ids
,
std
::
shared_ptr
<
infinilm
::
cache
::
PagedKVCache
>
paged_kv_cache
,
std
::
optional
<
infinicore
::
Tensor
>
total_sequence_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
input_offsets
,
std
::
optional
<
infinicore
::
Tensor
>
block_tables
,
std
::
optional
<
infinicore
::
Tensor
>
slot_mapping
)
const
{
ASSERT
(
block_tables
.
has_value
());
ASSERT
(
slot_mapping
.
has_value
());
// Input shape: [batch, seq_len, hidden_size]
auto
hidden_states_mutable
=
hidden_states
;
auto
shape
=
hidden_states
->
shape
();
size_t
batch_size
=
shape
[
0
];
size_t
seq_len
=
shape
[
1
];
// Only support batchsize==1, all requests should be flattened along seqlen dimension
ASSERT_EQ
(
batch_size
,
1
);
// Decode only if total_len == num_requests
bool
is_prefill
=
(
seq_len
!=
total_sequence_lengths
.
value
()
->
shape
()[
0
]);
// 1. Project Q, K, V
auto
[
q
,
k
,
v
]
=
qkv_proj_
->
forward_split
(
hidden_states_mutable
);
// 2. Reshape for multi-head attention
// Reshape Q, K, V to include batch dimension
// Python: query_states = self.q_proj(hidden_states).view(querys_shape)
// The view operation requires the tensor to be contiguous in the required dimensions
auto
q_reshaped
=
q
->
view
({
seq_len
,
num_attention_heads_
,
head_dim_
});
auto
k_reshaped
=
k
->
view
({
seq_len
,
num_key_value_heads_
,
head_dim_
});
auto
v_reshaped
=
v
->
view
({
seq_len
,
num_key_value_heads_
,
head_dim_
});
// 3. Prepare position_ids for RoPE - align with Python pattern
auto
pos_shape
=
position_ids
->
shape
();
infinicore
::
Tensor
pos_ids_for_rope
=
position_ids
;
if
(
pos_shape
.
size
()
==
2
)
{
auto
pos_narrowed
=
position_ids
->
narrow
({{
0
,
0
,
1
}});
pos_ids_for_rope
=
pos_narrowed
->
view
({
pos_shape
[
1
]});
}
else
if
(
pos_shape
.
size
()
==
1
)
{
pos_ids_for_rope
=
position_ids
;
}
else
{
throw
std
::
runtime_error
(
"Unexpected position_ids shape"
);
}
// 4. Apply RoPE to Q and K
rotary_emb_
->
forward
(
q_reshaped
,
pos_ids_for_rope
,
true
);
// [bs, seq_len, n_q_head, head_dim]
rotary_emb_
->
forward
(
k_reshaped
,
pos_ids_for_rope
,
true
);
// [bs, seq_len, n_kv_head, head_dim]
// 5. Prepare KV caches
// Ensure contiguous after permute for F16 compatibility with cache operations
auto
[
k_total
,
v_total
]
=
paged_kv_cache
->
update
(
layer_idx_
,
k_reshaped
,
v_reshaped
,
slot_mapping
.
value
());
// 6. Compute attention
infinicore
::
Tensor
attn_output
=
infinicore
::
Tensor
::
empty
({
seq_len
,
num_attention_heads_
,
head_dim_
},
q_reshaped
->
dtype
(),
q_reshaped
->
device
());
if
(
is_prefill
)
{
infinicore
::
op
::
paged_attention_prefill_
(
attn_output
,
q_reshaped
,
k_total
,
v_total
,
block_tables
.
value
(),
total_sequence_lengths
.
value
(),
input_offsets
.
value
(),
std
::
nullopt
,
scaling_
);
}
else
{
infinicore
::
op
::
paged_attention_
(
attn_output
,
q_reshaped
,
k_total
,
v_total
,
block_tables
.
value
(),
total_sequence_lengths
.
value
(),
std
::
nullopt
,
scaling_
);
}
// 7. Project output
attn_output
=
attn_output
->
view
({
1
,
seq_len
,
num_attention_heads_
*
head_dim_
});
return
o_proj_
->
forward
(
attn_output
);
}
infinicore
::
Tensor
LlamaAttention
::
forward
(
const
infinicore
::
Tensor
&
hidden_states
,
const
infinicore
::
Tensor
&
position_ids
,
std
::
shared_ptr
<
cache
::
Cache
>
kv_cache
,
std
::
optional
<
infinicore
::
Tensor
>
past_sequence_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
total_sequence_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
input_offsets
,
std
::
optional
<
infinicore
::
Tensor
>
block_tables
,
std
::
optional
<
infinicore
::
Tensor
>
slot_mapping
)
const
{
if
(
!
rotary_emb_
)
{
throw
std
::
runtime_error
(
"LlamaAttention: rotary_emb not configured"
);
}
infinicore
::
Tensor
output
;
if
(
auto
paged_kv_cache
=
std
::
dynamic_pointer_cast
<
cache
::
PagedKVCache
>
(
kv_cache
))
{
output
=
forward_paged_
(
hidden_states
,
position_ids
,
paged_kv_cache
,
total_sequence_lengths
,
input_offsets
,
block_tables
,
slot_mapping
);
}
else
{
output
=
forward_
(
hidden_states
,
position_ids
,
kv_cache
,
past_sequence_lengths
,
total_sequence_lengths
);
}
return
output
;
}
void
LlamaAttention
::
set_rotary_emb
(
const
std
::
shared_ptr
<
infinicore
::
nn
::
RoPE
>
&
rotary_emb
)
{
rotary_emb_
=
rotary_emb
;
}
...
...
csrc/models/llama/llama_attention.hpp
View file @
c1a3ab29
...
...
@@ -51,11 +51,11 @@ public:
infinicore
::
Tensor
forward
(
const
infinicore
::
Tensor
&
hidden_states
,
const
infinicore
::
Tensor
&
position_ids
,
std
::
shared_ptr
<
infinilm
::
cache
::
Cache
>
kv_cache
,
std
::
optional
<
infinicore
::
Tensor
>
cach
e_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
input
_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
past_sequenc
e_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
total_sequence
_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
input_offsets
,
std
::
optional
<
infinicore
::
Tensor
>
block_tables
,
std
::
optional
<
infinicore
::
Tensor
>
slot_mappin
)
const
;
std
::
optional
<
infinicore
::
Tensor
>
slot_mappin
g
)
const
;
/**
* @brief Get the layer index
...
...
@@ -73,6 +73,21 @@ public:
size_t
head_dim
()
const
{
return
head_dim_
;
}
size_t
hidden_size
()
const
{
return
hidden_size_
;
}
private:
infinicore
::
Tensor
forward_
(
const
infinicore
::
Tensor
&
hidden_states
,
const
infinicore
::
Tensor
&
position_ids
,
std
::
shared_ptr
<
infinilm
::
cache
::
Cache
>
kv_cache
,
std
::
optional
<
infinicore
::
Tensor
>
past_sequence_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
total_sequence_lengths
)
const
;
infinicore
::
Tensor
forward_paged_
(
const
infinicore
::
Tensor
&
hidden_states
,
const
infinicore
::
Tensor
&
position_ids
,
std
::
shared_ptr
<
infinilm
::
cache
::
PagedKVCache
>
kv_cache
,
std
::
optional
<
infinicore
::
Tensor
>
total_sequence_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
input_offsets
,
std
::
optional
<
infinicore
::
Tensor
>
block_tables
,
std
::
optional
<
infinicore
::
Tensor
>
slot_mapping
)
const
;
protected:
// Projection layers
INFINICORE_NN_MODULE
(
infinilm
::
layers
::
QKVParallelLinear
,
qkv_proj
);
...
...
@@ -93,6 +108,8 @@ private:
bool
use_bias_
;
// Bias for Q/K/V projections
bool
use_output_bias_
;
// Bias for output projection (o_proj)
size_t
max_position_embeddings_
;
// For cache initialization (deprecated, kept for compatibility)
float
scaling_
;
};
}
// namespace infinilm::models::llama
csrc/models/llama/llama_decoder_layer.cpp
View file @
c1a3ab29
...
...
@@ -26,8 +26,8 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
infinicore
::
Tensor
LlamaDecoderLayer
::
forward
(
const
infinicore
::
Tensor
&
hidden_states
,
const
infinicore
::
Tensor
&
position_ids
,
std
::
shared_ptr
<
infinilm
::
cache
::
Cache
>
kv_cache
,
std
::
optional
<
infinicore
::
Tensor
>
cach
e_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
input
_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
past_sequenc
e_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
total_sequence
_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
input_offsets
,
std
::
optional
<
infinicore
::
Tensor
>
block_tables
,
std
::
optional
<
infinicore
::
Tensor
>
slot_mapping
)
const
{
...
...
@@ -38,7 +38,7 @@ infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_s
auto
normed_states
=
input_layernorm_
->
forward
(
hidden_states
);
// 2. Self-attention with residual connection
auto
attn_output
=
self_attn_
->
forward
(
normed_states
,
position_ids
,
kv_cache
,
cache_lengths
,
input
_lengths
,
input_offsets
,
block_tables
,
slot_mapping
);
auto
attn_output
=
self_attn_
->
forward
(
normed_states
,
position_ids
,
kv_cache
,
past_sequence_lengths
,
total_sequence
_lengths
,
input_offsets
,
block_tables
,
slot_mapping
);
// Add residual: hidden_states = hidden_states + attn_output
auto
output
=
infinicore
::
op
::
add
(
residual
,
attn_output
);
...
...
csrc/models/llama/llama_decoder_layer.hpp
View file @
c1a3ab29
...
...
@@ -49,8 +49,8 @@ public:
infinicore
::
Tensor
forward
(
const
infinicore
::
Tensor
&
hidden_states
,
const
infinicore
::
Tensor
&
position_ids
,
std
::
shared_ptr
<
infinilm
::
cache
::
Cache
>
kv_cache
,
std
::
optional
<
infinicore
::
Tensor
>
cach
e_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
input
_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
past_sequenc
e_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
total_sequence
_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
input_offsets
,
std
::
optional
<
infinicore
::
Tensor
>
block_tables
,
std
::
optional
<
infinicore
::
Tensor
>
slot_mappin
)
const
;
...
...
csrc/models/llama/llama_for_causal_lm.cpp
View file @
c1a3ab29
...
...
@@ -28,15 +28,15 @@ LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config,
LlamaForCausalLM
::
Output
LlamaForCausalLM
::
forward
(
const
Input
&
input
)
const
{
auto
input_ids
=
input
.
input_ids
.
value
();
auto
position_ids
=
input
.
position_ids
.
value
();
auto
cach
e_lengths
=
input
.
cach
e_lengths
;
auto
input
_length
s
=
input
.
input
_lengths
;
auto
past_sequenc
e_lengths
=
input
.
past_sequenc
e_lengths
;
auto
total_sequence
_length
=
input
.
total_sequence
_lengths
;
auto
input_offsets
=
input
.
input_offsets
;
auto
block_tables
=
input
.
block_tables
;
auto
slot_mapping
=
input
.
slot_mapping
;
// 1. Forward through base model to get hidden states
auto
position_ids_device
=
position_ids
->
to
(
device_
);
auto
hidden_states
=
model_
->
forward
(
input_ids
,
position_ids_device
,
cache_lengths
,
input
_length
s
,
input_offsets
,
block_tables
,
slot_mapping
);
auto
hidden_states
=
model_
->
forward
(
input_ids
,
position_ids
,
past_sequence_lengths
,
total_sequence
_length
,
input_offsets
,
block_tables
,
slot_mapping
);
// 2. Apply language modeling head to get logits
auto
logits
=
lm_head_
->
forward
(
hidden_states
);
...
...
csrc/models/llama/llama_model.cpp
View file @
c1a3ab29
...
...
@@ -45,8 +45,8 @@ LlamaModel::LlamaModel(const LlamaConfig &config,
infinicore
::
Tensor
LlamaModel
::
forward
(
const
infinicore
::
Tensor
&
input_ids
,
const
infinicore
::
Tensor
&
position_ids
,
std
::
optional
<
infinicore
::
Tensor
>
cach
e_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
input
_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
past_sequenc
e_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
total_sequence
_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
input_offsets
,
std
::
optional
<
infinicore
::
Tensor
>
block_tables
,
std
::
optional
<
infinicore
::
Tensor
>
slot_mapping
)
const
{
...
...
@@ -56,18 +56,10 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
// 2. Process through all decoder layers
size_t
num_layers
=
layers_
.
size
();
for
(
size_t
i
=
0
;
i
<
num_layers
;
++
i
)
{
hidden_states
=
layers_
.
at
(
i
)
->
forward
(
hidden_states
,
position_ids
,
kv_cache_
,
cache_lengths
,
input
_lengths
,
input_offsets
,
block_tables
,
slot_mapping
);
hidden_states
=
layers_
.
at
(
i
)
->
forward
(
hidden_states
,
position_ids
,
kv_cache_
,
past_sequence_lengths
,
total_sequence
_lengths
,
input_offsets
,
block_tables
,
slot_mapping
);
}
// 3. Apply final layer normalization to last token only (aligns with transformers)
// Narrow to last token: [batch, seq_len, hidden_size] -> [batch, 1, hidden_size]
auto
shape
=
hidden_states
->
shape
();
size_t
seq_len
=
shape
[
1
];
auto
last_token
=
hidden_states
->
narrow
({{
1
,
seq_len
-
1
,
1
}});
auto
normalized_last_token
=
norm_
->
forward
(
last_token
);
return
normalized_last_token
;
return
norm_
->
forward
(
hidden_states
);
}
void
LlamaModel
::
reset_cache
(
const
cache
::
CacheConfig
*
cache_config
)
{
...
...
csrc/models/llama/llama_model.hpp
View file @
c1a3ab29
...
...
@@ -48,15 +48,15 @@ public:
* @param input_ids Token IDs tensor of shape [batch, seq_len]. Batch is 1 when continuous batch is used,
* and tokens from all requests are concatenated along seq_len dimension.
* @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len]
* @param
cach
e_lengths Cache positions tensor of shape [n_req]
* @param
input_lengths Input lengths tensor in a continuous batch
of shape [n_req]
* @param input_offsets Input offsets (starting position) of each request in a continuous batch of shape [n_req]
* @param
past_sequenc
e_lengths Cache positions tensor of shape [n_req]
* @param
total_sequence_lengths Total sequence lengths tensor
of shape [n_req]
* @param input_offsets Input offsets (starting position) of each request in a continuous batch of shape [n_req
+ 1
]
* @return Output tensor of shape [batch, seq_len, hidden_size]
*/
infinicore
::
Tensor
forward
(
const
infinicore
::
Tensor
&
input_ids
,
const
infinicore
::
Tensor
&
position_ids
,
std
::
optional
<
infinicore
::
Tensor
>
cach
e_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
input
_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
past_sequenc
e_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
total_sequence
_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
input_offsets
,
std
::
optional
<
infinicore
::
Tensor
>
block_tables
,
std
::
optional
<
infinicore
::
Tensor
>
slot_mapping
)
const
;
...
...
csrc/pybind11/cache/cache.hpp
View file @
c1a3ab29
...
...
@@ -36,11 +36,11 @@ inline void bind_cache(py::module &m) {
std
::
shared_ptr
<
infinilm
::
cache
::
PagedKVCacheConfig
>>
(
m
,
"PagedKVCacheConfig"
)
.
def
(
py
::
init
<
size_t
,
size_t
>
(),
py
::
arg
(
"
max_kv_memory_byte
s"
),
py
::
arg
(
"
num_block
s"
),
py
::
arg
(
"block_size"
)
=
16
)
.
def
(
"
max_kv_memory_byte
s"
,
&
infinilm
::
cache
::
PagedKVCacheConfig
::
max_kv_memory_byte
s
)
"
num_block
s"
,
&
infinilm
::
cache
::
PagedKVCacheConfig
::
num_block
s
)
.
def
(
"block_size"
,
&
infinilm
::
cache
::
PagedKVCacheConfig
::
block_size
)
...
...
csrc/pybind11/engine/engine.hpp
View file @
c1a3ab29
...
...
@@ -80,28 +80,48 @@ inline void bind_infer_engine(py::module &m) {
py
::
init
([](
std
::
optional
<
infinicore
::
Tensor
>
input_ids
,
std
::
optional
<
infinicore
::
Tensor
>
position_ids
,
std
::
optional
<
infinicore
::
Tensor
>
cach
e_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
input
_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
past_sequenc
e_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
total_sequence
_lengths
,
std
::
optional
<
infinicore
::
Tensor
>
input_offsets
,
std
::
optional
<
infinicore
::
Tensor
>
block_tables
,
std
::
optional
<
infinicore
::
Tensor
>
slot_mapping
,
py
::
kwargs
kwargs
)
{
auto
input
{
InferEngine
::
Input
{
InferEngine
::
Input
input
{
std
::
move
(
input_ids
),
std
::
move
(
position_ids
),
std
::
move
(
cache_lengths
),
std
::
move
(
past_sequence_lengths
),
std
::
move
(
total_sequence_lengths
),
std
::
move
(
input_offsets
),
std
::
move
(
block_tables
),
std
::
move
(
slot_mapping
)}};
std
::
move
(
slot_mapping
),
};
if
(
kwargs
)
{
if
(
kwargs
.
contains
(
"temperature"
))
{
input
.
temperature
=
kwargs
[
"temperature"
].
cast
<
float
>
();
}
if
(
kwargs
.
contains
(
"top_k"
))
{
input
.
top_k
=
kwargs
[
"top_k"
].
cast
<
int
>
();
// Explicit defaults
input
.
temperature
=
1.0
f
;
input
.
top_p
=
1.0
f
;
input
.
top_k
=
1
;
// Allowed keyword arguments
static
const
std
::
unordered_set
<
std
::
string
>
allowed_kwargs
=
{
"temperature"
,
"top_p"
,
"top_k"
,
};
for
(
auto
&
item
:
kwargs
)
{
const
std
::
string
key
=
py
::
cast
<
std
::
string
>
(
item
.
first
);
if
(
allowed_kwargs
.
find
(
key
)
==
allowed_kwargs
.
end
())
{
throw
py
::
value_error
(
"InferEngine.Input got an unexpected keyword argument '"
+
key
+
"'"
);
}
if
(
kwargs
.
contains
(
"top_p"
))
{
input
.
top_p
=
kwargs
[
"top_p"
].
cast
<
float
>
();
if
(
key
==
"temperature"
)
{
input
.
temperature
=
py
::
cast
<
float
>
(
item
.
second
);
}
else
if
(
key
==
"top_p"
)
{
input
.
top_p
=
py
::
cast
<
float
>
(
item
.
second
);
}
else
if
(
key
==
"top_k"
)
{
input
.
top_k
=
py
::
cast
<
int
>
(
item
.
second
);
}
}
...
...
@@ -109,18 +129,21 @@ inline void bind_infer_engine(py::module &m) {
}),
py
::
arg
(
"input_ids"
)
=
std
::
nullopt
,
py
::
arg
(
"position_ids"
)
=
std
::
nullopt
,
py
::
arg
(
"
cach
e_lengths"
)
=
std
::
nullopt
,
py
::
arg
(
"
input
_lengths"
)
=
std
::
nullopt
,
py
::
arg
(
"
past_sequenc
e_lengths"
)
=
std
::
nullopt
,
py
::
arg
(
"
total_sequence
_lengths"
)
=
std
::
nullopt
,
py
::
arg
(
"input_offsets"
)
=
std
::
nullopt
,
py
::
arg
(
"block_tables"
)
=
std
::
nullopt
,
py
::
arg
(
"slot_mapping"
)
=
std
::
nullopt
)
.
def_readwrite
(
"input_ids"
,
&
InferEngine
::
Input
::
input_ids
)
.
def_readwrite
(
"position_ids"
,
&
InferEngine
::
Input
::
position_ids
)
.
def_readwrite
(
"
cach
e_lengths"
,
&
InferEngine
::
Input
::
cach
e_lengths
)
.
def_readwrite
(
"
input
_lengths"
,
&
InferEngine
::
Input
::
input
_lengths
)
.
def_readwrite
(
"
past_sequenc
e_lengths"
,
&
InferEngine
::
Input
::
past_sequenc
e_lengths
)
.
def_readwrite
(
"
total_sequence
_lengths"
,
&
InferEngine
::
Input
::
total_sequence
_lengths
)
.
def_readwrite
(
"input_offsets"
,
&
InferEngine
::
Input
::
input_offsets
)
.
def_readwrite
(
"block_tables"
,
&
InferEngine
::
Input
::
block_tables
)
.
def_readwrite
(
"slot_mapping"
,
&
InferEngine
::
Input
::
slot_mapping
);
.
def_readwrite
(
"slot_mapping"
,
&
InferEngine
::
Input
::
slot_mapping
)
.
def_readwrite
(
"temperature"
,
&
InferEngine
::
Input
::
temperature
)
.
def_readwrite
(
"top_k"
,
&
InferEngine
::
Input
::
top_k
)
.
def_readwrite
(
"top_p"
,
&
InferEngine
::
Input
::
top_p
);
py
::
class_
<
InferEngine
::
Output
>
(
infer_engine
,
"Output"
)
.
def_readwrite
(
"output_ids"
,
&
InferEngine
::
Output
::
output_ids
,
"Output tensor"
);
...
...
examples/bench.py
View file @
c1a3ab29
...
...
@@ -3,6 +3,7 @@ from transformers import AutoTokenizer
from
infinilm.modeling_utils
import
load_model_state_dict_by_file
from
infinilm.distributed
import
DistConfig
from
infinilm.infer_engine
import
GenerationConfig
,
InferEngine
from
infinilm.cache
import
StaticKVCacheConfig
import
argparse
import
sys
import
time
...
...
@@ -260,6 +261,7 @@ class TestModel:
output_ids
=
self
.
model
.
generate
(
input_ids_infini
,
GenerationConfig
(
max_new_tokens
=
output_len
,
eos_token_id
=
[]),
_measure_and_log_time
=
True
,
)
t2
=
time
.
time
()
...
...
@@ -336,7 +338,11 @@ if __name__ == "__main__":
# reset cache for each case
initial_capacity
=
input_len
+
output_len
test
.
model
.
reset_cache
(
batch_size
=
batch_size
,
initial_capacity
=
initial_capacity
)
test
.
model
.
reset_cache
(
StaticKVCacheConfig
(
max_batch_size
=
batch_size
,
max_cache_len
=
initial_capacity
)
)
# run test one case
test
.
run
(
...
...
examples/jiuge.py
View file @
c1a3ab29
...
...
@@ -9,6 +9,7 @@ import sys
import
time
import
os
import
numpy
as
np
from
infinilm.cache
import
StaticKVCacheConfig
,
PagedKVCacheConfig
sys
.
path
.
insert
(
0
,
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
"../python"
))
...
...
@@ -82,6 +83,11 @@ def get_args():
default
=
1
,
help
=
"total rank for tensor parallel"
,
)
parser
.
add_argument
(
"--enable-paged-attn"
,
action
=
"store_true"
,
help
=
"use paged cache"
,
)
return
parser
.
parse_args
()
...
...
@@ -92,10 +98,11 @@ def test(
max_new_tokens
=
100
,
infini_device
=
infinicore
.
device
(
"cpu"
,
0
),
tp
=
1
,
enable_paged_attn
=
False
,
):
model_path
=
os
.
path
.
expanduser
(
model_path
)
# ---------------------------------------------------------------------------- #
#
创建模型,
#
Create Model
# ---------------------------------------------------------------------------- #
model
=
InferEngine
(
model_path
,
...
...
@@ -104,12 +111,12 @@ def test(
)
# ---------------------------------------------------------------------------- #
#
加载权重
#
Load Weights
# ---------------------------------------------------------------------------- #
load_model_state_dict_by_file
(
model
,
model_path
,
dtype
=
model
.
config
.
dtype
)
# ---------------------------------------------------------------------------- #
#
创建
tokenizer
#
create
tokenizer
# ---------------------------------------------------------------------------- #
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_path
,
trust_remote_code
=
True
)
...
...
@@ -132,7 +139,7 @@ def test(
)
# ---------------------------------------------------------------------------- #
# token
编码
# token
ize
# ---------------------------------------------------------------------------- #
# prompt = "山东最高的山是?"
if
isinstance
(
prompts
,
str
):
...
...
@@ -150,14 +157,26 @@ def test(
"input_ids"
]
# List: [[1, 1128, 526, 366, 29892]]
# 根据输入长度和最长输出长度创建KVCache
model
.
reset_cache
(
1
if
prompts
is
str
else
len
(
prompts
),
max_new_tokens
+
len
(
input_ids_list
[
0
]),
)
# ---------------------------------------------------------------------------- #
# Create KVCache
# ---------------------------------------------------------------------------- #
if
enable_paged_attn
:
batch_size
=
1
if
prompts
is
str
else
len
(
prompts
)
max_total_tokens
=
max_new_tokens
+
len
(
input_ids_list
[
0
])
cache_config
=
PagedKVCacheConfig
(
num_blocks
=
(
max_total_tokens
//
16
+
1
)
*
batch_size
,
block_size
=
16
)
else
:
batch_size
=
1
if
prompts
is
str
else
len
(
prompts
)
initial_capacity
=
max_new_tokens
+
len
(
input_ids_list
[
0
])
cache_config
=
StaticKVCacheConfig
(
max_batch_size
=
batch_size
,
max_cache_len
=
initial_capacity
)
model
.
reset_cache
(
cache_config
)
# ---------------------------------------------------------------------------- #
#
自回归生成
#
Generate
# ---------------------------------------------------------------------------- #
print
(
input_contents
[
0
],
end
=
""
,
flush
=
True
)
input_ids_infini
=
infinicore
.
from_list
(
input_ids_list
)
...
...
@@ -211,7 +230,7 @@ if __name__ == "__main__":
max_new_tokens
=
args
.
max_new_tokens
backend
=
args
.
backend
tp
=
args
.
tp
enable_paged_attn
=
args
.
enable_paged_attn
if
backend
!=
"cpp"
:
raise
ValueError
(
f
"Unsupported backend:
{
backend
}
."
)
...
...
@@ -223,4 +242,5 @@ if __name__ == "__main__":
max_new_tokens
,
infini_device
=
infini_device
,
tp
=
tp
,
enable_paged_attn
=
enable_paged_attn
,
)
python/infinilm/auto_config.py
View file @
c1a3ab29
...
...
@@ -21,5 +21,7 @@ class AutoConfig:
if
config_dict
[
"model_type"
]
==
"llama"
:
return
LlamaConfig
(
**
config_dict
)
elif
config_dict
[
"model_type"
]
==
"qwen2"
:
return
LlamaConfig
(
**
config_dict
)
raise
ValueError
(
f
"Unsupported model type `
{
config_dict
[
'model_type'
]
}
`."
)
python/infinilm/cache/__init__.py
View file @
c1a3ab29
"""Public cache-configuration API for ``infinilm.cache``.

Re-exports the KV-cache configuration classes so callers can write
``from infinilm.cache import PagedKVCacheConfig`` instead of importing
from the private ``cache`` submodule directly.
"""

from .cache import CacheConfig, PagedKVCacheConfig, StaticKVCacheConfig

__all__ = [
    "CacheConfig",
    "StaticKVCacheConfig",
    "PagedKVCacheConfig",
]
python/infinilm/cache/cache.py
View file @
c1a3ab29
...
...
@@ -16,11 +16,11 @@ class StaticKVCacheConfig(CacheConfig, _infinilm.StaticKVCacheConfig):
class PagedKVCacheConfig(CacheConfig, _infinilm.PagedKVCacheConfig):
    """Configuration for a paged (block-based) KV cache.

    Thin Python-side wrapper over the pybind11-bound
    ``_infinilm.PagedKVCacheConfig``; it only forwards the constructor
    arguments to the C++ object.

    Args:
        num_blocks: Total number of cache blocks to allocate.
        block_size: Number of token slots per block (default 16).
    """

    def __init__(self, num_blocks: int, block_size: int = 16):
        # Initialize only the bound C++ base explicitly, matching the
        # original code. NOTE(review): ``CacheConfig.__init__`` is not
        # invoked here — confirm that base requires no initialization.
        _infinilm.PagedKVCacheConfig.__init__(self, num_blocks, block_size)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment