jerrrrry / infinilm · Commits

Commit 831e8a67
authored Jan 08, 2026 by PanZezhong
parent e48b5b0d

issue/168 support fixed paged attention api

Showing 14 changed files with 144 additions and 108 deletions (+144 -108)
csrc/cache/kv_cache.cpp                      +2   -2
csrc/cache/kv_cache.hpp                      +1   -1
csrc/engine/infer_engine.cpp                 +13  -34
csrc/engine/rank_worker.hpp                  +3   -1
csrc/models/infinilm_model.hpp               +3   -1
csrc/models/llama/llama_attention.cpp        +11  -9
csrc/models/llama/llama_attention.hpp        +5   -3
csrc/models/llama/llama_decoder_layer.cpp    +3   -2
csrc/models/llama/llama_decoder_layer.hpp    +2   -1
csrc/models/llama/llama_for_causal_lm.cpp    +4   -2
csrc/models/llama/llama_model.cpp            +3   -2
csrc/models/llama/llama_model.hpp            +4   -2
csrc/pybind11/engine/engine.hpp              +40  -15
python/infinilm/infer_engine.py              +50  -33
csrc/cache/kv_cache.cpp

@@ -80,12 +80,12 @@ std::tuple<infinicore::Tensor, infinicore::Tensor>
 StaticKVCache::update(size_t layer_idx,
                       const infinicore::Tensor &k,
                       const infinicore::Tensor &v,
-                      const infinicore::Tensor &cache_lengths) {
+                      const infinicore::Tensor &past_sequence_lengths) {
     ASSERT(layer_idx < rank_num_layers_);
     auto batch_size = k->size(0);
     auto update_len = k->size(2);
-    size_t cache_pos = reinterpret_cast<int64_t *>(cache_lengths->to(infinicore::Device::cpu())->data())[0];
+    size_t cache_pos = reinterpret_cast<int64_t *>(past_sequence_lengths->to(infinicore::Device::cpu())->data())[0];
     auto result_len = cache_pos + update_len;
     ASSERT(result_len <= cache_len_);
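The renamed argument makes the bookkeeping explicit: the static cache reads past_sequence_lengths[0] as the write position for the incoming K/V block and requires that the past length plus update_len still fit in the preallocated cache. A minimal NumPy sketch of that arithmetic, for illustration only; the buffer layout and names below are assumptions, not the actual StaticKVCache implementation:

import numpy as np

def static_cache_append(cache_k, k_new, past_len):
    # cache_k: preallocated [num_heads, cache_len, head_dim] buffer (assumed layout)
    # k_new:   [num_heads, update_len, head_dim] keys produced in this step
    # past_len: tokens already cached, i.e. past_sequence_lengths[0]
    update_len = k_new.shape[1]
    result_len = past_len + update_len
    assert result_len <= cache_k.shape[1], "static cache overflow"
    cache_k[:, past_len:result_len, :] = k_new  # append after the cached prefix
    return cache_k[:, :result_len, :]           # past + new tokens, as update() returns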
csrc/cache/kv_cache.hpp

@@ -61,7 +61,7 @@ public:
     update(size_t layer_idx,
            const infinicore::Tensor &k,
            const infinicore::Tensor &v,
-           const infinicore::Tensor &cache_lengths);
+           const infinicore::Tensor &past_sequence_lengths);

     ~StaticKVCache() override = default;
csrc/engine/infer_engine.cpp

@@ -56,44 +56,23 @@ std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> InferEng
 //------------------------------------------------------
 // forward
 //------------------------------------------------------
 infinilm::InfinilmModel::Input InferEngine::Input::to_model_input(infinicore::Device device) const {
-    std::optional<infinicore::Tensor> position_ids_on_device;
-    if (position_ids.has_value()) {
-        position_ids_on_device = position_ids.value()->to(device);
-    }
-    std::optional<infinicore::Tensor> cache_lengths_on_device;
-    if (cache_lengths.has_value()) {
-        if (block_tables.has_value()) {
-            cache_lengths_on_device = cache_lengths.value()->to(device);
-        } else {
-            // @todo: only paged kv cache support device tensor so far
-            cache_lengths_on_device = cache_lengths.value();
-        }
-    }
-    std::optional<infinicore::Tensor> input_offsets_on_device;
-    if (input_offsets.has_value()) {
-        input_offsets_on_device = input_offsets.value()->to(device);
-    }
-    std::optional<infinicore::Tensor> block_tables_on_device;
-    if (block_tables.has_value()) {
-        block_tables_on_device = block_tables.value()->to(device);
-    }
-    std::optional<infinicore::Tensor> slot_mapping_on_device;
-    if (slot_mapping.has_value()) {
-        slot_mapping_on_device = slot_mapping.value()->to(device);
-    }
-    return {
-        input_ids, // @todo: on device in the future
-        position_ids_on_device,
-        cache_lengths_on_device,
-        input_offsets_on_device,
-        block_tables_on_device,
-        slot_mapping_on_device};
+    auto to_device = [&](const std::optional<infinicore::Tensor> &t)
+        -> std::optional<infinicore::Tensor> {
+        return t.has_value() ? t.value()->to(device) : t;
+    };
+    return {
+        input_ids, // @todo: on device in the future
+        to_device(position_ids),
+        past_sequence_lengths, // @todo: on device in the future
+        to_device(total_sequence_lengths),
+        to_device(input_offsets),
+        to_device(block_tables),
+        to_device(slot_mapping),
+    };
 }

 InferEngine::Output InferEngine::forward(const InferEngine::Input &input) {
csrc/engine/rank_worker.hpp

@@ -29,7 +29,9 @@ public:
     /// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`.
     std::optional<infinicore::Tensor> position_ids;
     /// Past Lengths of cached sequence for each request, of shape `[num_requests]`.
-    std::optional<infinicore::Tensor> cache_lengths;
+    std::optional<infinicore::Tensor> past_sequence_lengths;
+    /// ToTal Lengths for each request sequence, of shape `[num_requests]`.
+    std::optional<infinicore::Tensor> total_sequence_lengths;
     /// Offsets of each request in a continous-batched sequence, of shape `[num_requests]`.
     std::optional<infinicore::Tensor> input_offsets;
     /// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache.
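The Input struct now carries two length vectors instead of one: past_sequence_lengths counts the tokens already held in the cache for each request, while total_sequence_lengths is that count plus the tokens submitted in the current step; input_offsets still marks where each request's tokens start in the flattened, continuously batched sequence. A small sketch of how the three relate, with example values only:

# Illustrative only: two requests in one continuous batch.
# Request 0 has 5 cached tokens and submits 4 new ones; request 1 has 7 cached
# tokens and submits a single new token (a decode step for that request).
new_token_counts = [4, 1]
past_sequence_lengths = [5, 7]
total_sequence_lengths = [p + n for p, n in zip(past_sequence_lengths, new_token_counts)]
# -> [9, 8]

# Offsets of each request in the flattened token stream (length num_requests + 1).
input_offsets = [0]
for n in new_token_counts:
    input_offsets.append(input_offsets[-1] + n)
# -> [0, 4, 5]; the flattened seq_len equals input_offsets[-1] == 5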
csrc/models/infinilm_model.hpp

@@ -23,7 +23,9 @@ public:
     /// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`.
     std::optional<infinicore::Tensor> position_ids;
     /// Past Lengths of cached sequence for each request, of shape `[num_requests]`.
-    std::optional<infinicore::Tensor> cache_lengths;
+    std::optional<infinicore::Tensor> past_sequence_lengths;
+    /// ToTal Lengths for each request sequence, of shape `[num_requests]`.
+    std::optional<infinicore::Tensor> total_sequence_lengths;
     /// Offsets of each request in a continous-batched sequence, of shape `[num_requests + 1]`.
     std::optional<infinicore::Tensor> input_offsets;
     /// Block ids for each request `[batch, max_block_table_length]`. Used for paged cache.
csrc/models/llama/llama_attention.cpp

@@ -57,7 +57,8 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
 infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_states,
                                             const infinicore::Tensor &position_ids,
                                             std::shared_ptr<infinilm::cache::Cache> kv_cache,
-                                            std::optional<infinicore::Tensor> cache_lengths) const {
+                                            std::optional<infinicore::Tensor> past_sequence_lengths,
+                                            std::optional<infinicore::Tensor> total_sequence_lengths) const {
     // Input shape: [batch, seq_len, hidden_size]
     auto hidden_states_mutable = hidden_states;
     auto shape = hidden_states->shape();

@@ -105,7 +106,7 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
         k_total = k_permuted;
         v_total = v_permuted;
     } else if (auto static_kv_cache = std::dynamic_pointer_cast<cache::StaticKVCache>(kv_cache)) {
-        auto [k_total_tmp, v_total_tmp] = static_kv_cache->update(layer_idx_, k_permuted, v_permuted, cache_lengths.value());
+        auto [k_total_tmp, v_total_tmp] = static_kv_cache->update(layer_idx_, k_permuted, v_permuted, past_sequence_lengths.value());
         k_total = k_total_tmp;
         v_total = v_total_tmp;
     } else {

@@ -141,7 +142,7 @@ infinicore::Tensor LlamaAttention::forward_(const infinicore::Tensor &hidden_sta
 infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidden_states,
                                                   const infinicore::Tensor &position_ids,
                                                   std::shared_ptr<infinilm::cache::PagedKVCache> paged_kv_cache,
-                                                  std::optional<infinicore::Tensor> cache_lengths,
+                                                  std::optional<infinicore::Tensor> total_sequence_lengths,
                                                   std::optional<infinicore::Tensor> input_offsets,
                                                   std::optional<infinicore::Tensor> block_tables,
                                                   std::optional<infinicore::Tensor> slot_mapping) const {

@@ -157,7 +158,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
     // Only support batchsize==1, all requests should be flattened along seqlen dimension
     ASSERT_EQ(batch_size, 1);
     // Decode only if total_len == num_requests
-    bool is_prefill = (seq_len != cache_lengths.value()->shape()[0]);
+    bool is_prefill = (seq_len != total_sequence_lengths.value()->shape()[0]);
     // 1. Project Q, K, V
     auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable);

@@ -204,7 +205,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
         k_total,
         v_total,
         block_tables.value(),
-        cache_lengths.value(),
+        total_sequence_lengths.value(),
         input_offsets.value(),
         std::nullopt,
         scaling_);

@@ -216,7 +217,7 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
         k_total,
         v_total,
         block_tables.value(),
-        cache_lengths.value(),
+        total_sequence_lengths.value(),
         std::nullopt,
         scaling_);
 }

@@ -229,7 +230,8 @@ infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidd
 infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states,
                                            const infinicore::Tensor &position_ids,
                                            std::shared_ptr<cache::Cache> kv_cache,
-                                           std::optional<infinicore::Tensor> cache_lengths,
+                                           std::optional<infinicore::Tensor> past_sequence_lengths,
+                                           std::optional<infinicore::Tensor> total_sequence_lengths,
                                            std::optional<infinicore::Tensor> input_offsets,
                                            std::optional<infinicore::Tensor> block_tables,
                                            std::optional<infinicore::Tensor> slot_mapping) const {

@@ -239,10 +241,10 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
     infinicore::Tensor output;
     if (auto paged_kv_cache = std::dynamic_pointer_cast<cache::PagedKVCache>(kv_cache)) {
-        output = forward_paged_(hidden_states, position_ids, paged_kv_cache, cache_lengths, input_offsets, block_tables, slot_mapping);
+        output = forward_paged_(hidden_states, position_ids, paged_kv_cache, total_sequence_lengths, input_offsets, block_tables, slot_mapping);
     } else {
-        output = forward_(hidden_states, position_ids, kv_cache, cache_lengths);
+        output = forward_(hidden_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths);
     }
     return output;
 }
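In the paged path the batch is flattened to batch_size == 1, so the code decides between prefill and decode by comparing the flattened seq_len against the number of requests (the length of total_sequence_lengths): a decode step submits exactly one new token per request. A short sketch of that check, with illustrative values only:

# Illustrative only: how the paged path distinguishes prefill from decode.
def is_prefill(seq_len, total_sequence_lengths):
    # Decode contributes one token per request, so the flattened seq_len
    # equals the number of requests; any other length is a prefill step.
    return seq_len != len(total_sequence_lengths)

print(is_prefill(seq_len=5, total_sequence_lengths=[9, 8]))   # True  (prefill)
print(is_prefill(seq_len=2, total_sequence_lengths=[10, 9]))  # False (decode)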
csrc/models/llama/llama_attention.hpp

@@ -51,7 +51,8 @@ public:
     infinicore::Tensor forward(const infinicore::Tensor &hidden_states,
                                const infinicore::Tensor &position_ids,
                                std::shared_ptr<infinilm::cache::Cache> kv_cache,
-                               std::optional<infinicore::Tensor> cache_lengths,
+                               std::optional<infinicore::Tensor> past_sequence_lengths,
+                               std::optional<infinicore::Tensor> total_sequence_lengths,
                                std::optional<infinicore::Tensor> input_offsets,
                                std::optional<infinicore::Tensor> block_tables,
                                std::optional<infinicore::Tensor> slot_mapping) const;

@@ -76,12 +77,13 @@ private:
     infinicore::Tensor forward_(const infinicore::Tensor &hidden_states,
                                 const infinicore::Tensor &position_ids,
                                 std::shared_ptr<infinilm::cache::Cache> kv_cache,
-                                std::optional<infinicore::Tensor> cache_lengths) const;
+                                std::optional<infinicore::Tensor> past_sequence_lengths,
+                                std::optional<infinicore::Tensor> total_sequence_lengths) const;

     infinicore::Tensor forward_paged_(const infinicore::Tensor &hidden_states,
                                       const infinicore::Tensor &position_ids,
                                       std::shared_ptr<infinilm::cache::PagedKVCache> kv_cache,
-                                      std::optional<infinicore::Tensor> cache_lengths,
+                                      std::optional<infinicore::Tensor> total_sequence_lengths,
                                       std::optional<infinicore::Tensor> input_offsets,
                                       std::optional<infinicore::Tensor> block_tables,
                                       std::optional<infinicore::Tensor> slot_mapping) const;
csrc/models/llama/llama_decoder_layer.cpp

@@ -26,7 +26,8 @@ LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
 infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_states,
                                               const infinicore::Tensor &position_ids,
                                               std::shared_ptr<infinilm::cache::Cache> kv_cache,
-                                              std::optional<infinicore::Tensor> cache_lengths,
+                                              std::optional<infinicore::Tensor> past_sequence_lengths,
+                                              std::optional<infinicore::Tensor> total_sequence_lengths,
                                               std::optional<infinicore::Tensor> input_offsets,
                                               std::optional<infinicore::Tensor> block_tables,
                                               std::optional<infinicore::Tensor> slot_mapping) const {

@@ -37,7 +38,7 @@ infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_s
     auto normed_states = input_layernorm_->forward(hidden_states);

     // 2. Self-attention with residual connection
-    auto attn_output = self_attn_->forward(normed_states, position_ids, kv_cache, cache_lengths, input_offsets, block_tables, slot_mapping);
+    auto attn_output = self_attn_->forward(normed_states, position_ids, kv_cache, past_sequence_lengths, total_sequence_lengths, input_offsets, block_tables, slot_mapping);

     // Add residual: hidden_states = hidden_states + attn_output
     auto output = infinicore::op::add(residual, attn_output);
csrc/models/llama/llama_decoder_layer.hpp

@@ -49,7 +49,8 @@ public:
     infinicore::Tensor forward(const infinicore::Tensor &hidden_states,
                                const infinicore::Tensor &position_ids,
                                std::shared_ptr<infinilm::cache::Cache> kv_cache,
-                               std::optional<infinicore::Tensor> cache_lengths,
+                               std::optional<infinicore::Tensor> past_sequence_lengths,
+                               std::optional<infinicore::Tensor> total_sequence_lengths,
                                std::optional<infinicore::Tensor> input_offsets,
                                std::optional<infinicore::Tensor> block_tables,
                                std::optional<infinicore::Tensor> slot_mappin) const;
csrc/models/llama/llama_for_causal_lm.cpp

@@ -28,13 +28,15 @@ LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config,
 LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
     auto input_ids = input.input_ids.value();
     auto position_ids = input.position_ids.value();
-    auto cache_lengths = input.cache_lengths;
+    auto past_sequence_lengths = input.past_sequence_lengths;
+    auto total_sequence_length = input.total_sequence_lengths;
     auto input_offsets = input.input_offsets;
     auto block_tables = input.block_tables;
     auto slot_mapping = input.slot_mapping;

     // 1. Forward through base model to get hidden states
-    auto hidden_states = model_->forward(input_ids, position_ids, cache_lengths, input_offsets, block_tables, slot_mapping);
+    auto hidden_states = model_->forward(input_ids, position_ids, past_sequence_lengths, total_sequence_length, input_offsets, block_tables, slot_mapping);

     // 2. Apply language modeling head to get logits
     auto logits = lm_head_->forward(hidden_states);
csrc/models/llama/llama_model.cpp

@@ -45,7 +45,8 @@ LlamaModel::LlamaModel(const LlamaConfig &config,
 infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
                                        const infinicore::Tensor &position_ids,
-                                       std::optional<infinicore::Tensor> cache_lengths,
+                                       std::optional<infinicore::Tensor> past_sequence_lengths,
+                                       std::optional<infinicore::Tensor> total_sequence_lengths,
                                        std::optional<infinicore::Tensor> input_offsets,
                                        std::optional<infinicore::Tensor> block_tables,
                                        std::optional<infinicore::Tensor> slot_mapping) const {

@@ -55,7 +56,7 @@ infinicore::Tensor LlamaModel::forward(const infinicore::Tensor &input_ids,
     // 2. Process through all decoder layers
     size_t num_layers = layers_.size();
     for (size_t i = 0; i < num_layers; ++i) {
-        hidden_states = layers_.at(i)->forward(hidden_states, position_ids, kv_cache_, cache_lengths, input_offsets, block_tables, slot_mapping);
+        hidden_states = layers_.at(i)->forward(hidden_states, position_ids, kv_cache_, past_sequence_lengths, total_sequence_lengths, input_offsets, block_tables, slot_mapping);
     }

     return norm_->forward(hidden_states);
csrc/models/llama/llama_model.hpp

@@ -48,13 +48,15 @@ public:
      * @param input_ids Token IDs tensor of shape [batch, seq_len]. Batch is 1 when continuous batch is used,
      *                  and tokens from all requests are concatenated along seq_len dimension.
      * @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len]
-     * @param cache_lengths Cache positions tensor of shape [n_req]
+     * @param past_sequence_lengths Cache positions tensor of shape [n_req]
+     * @param total_sequence_lengths Total sequence lengths tensor of shape [n_req]
      * @param input_offsets Input offsets (starting position) of each request in a continuous batch of shape [n_req + 1]
      * @return Output tensor of shape [batch, seq_len, hidden_size]
      */
     infinicore::Tensor forward(const infinicore::Tensor &input_ids,
                                const infinicore::Tensor &position_ids,
-                               std::optional<infinicore::Tensor> cache_lengths,
+                               std::optional<infinicore::Tensor> past_sequence_lengths,
+                               std::optional<infinicore::Tensor> total_sequence_lengths,
                                std::optional<infinicore::Tensor> input_offsets,
                                std::optional<infinicore::Tensor> block_tables,
                                std::optional<infinicore::Tensor> slot_mapping) const;
csrc/pybind11/engine/engine.hpp

@@ -80,28 +80,48 @@ inline void bind_infer_engine(py::module &m) {
             py::init([](
                          std::optional<infinicore::Tensor> input_ids,
                          std::optional<infinicore::Tensor> position_ids,
-                         std::optional<infinicore::Tensor> cache_lengths,
+                         std::optional<infinicore::Tensor> past_sequence_lengths,
+                         std::optional<infinicore::Tensor> total_sequence_lengths,
                          std::optional<infinicore::Tensor> input_offsets,
                          std::optional<infinicore::Tensor> block_tables,
                          std::optional<infinicore::Tensor> slot_mapping,
                          py::kwargs kwargs) {
-                auto input{InferEngine::Input{
-                    std::move(input_ids),
-                    std::move(position_ids),
-                    std::move(cache_lengths),
-                    std::move(input_offsets),
-                    std::move(block_tables),
-                    std::move(slot_mapping)}};
-                // Explicit defaults
-                input.temperature = 1.0f;
-                input.top_p = 1.0f;
-                input.top_k = 1;
-                if (kwargs.contains("temperature")) {
-                    input.temperature = kwargs["temperature"].cast<float>();
-                }
-                if (kwargs.contains("top_k")) {
-                    input.top_k = kwargs["top_k"].cast<int>();
-                }
-                if (kwargs.contains("top_p")) {
-                    input.top_p = kwargs["top_p"].cast<float>();
-                }
+                InferEngine::Input input{
+                    std::move(input_ids),
+                    std::move(position_ids),
+                    std::move(past_sequence_lengths),
+                    std::move(total_sequence_lengths),
+                    std::move(input_offsets),
+                    std::move(block_tables),
+                    std::move(slot_mapping),
+                };
+                if (kwargs) {
+                    // Allowed keyword arguments
+                    static const std::unordered_set<std::string> allowed_kwargs = {
+                        "temperature",
+                        "top_p",
+                        "top_k",
+                    };
+                    for (auto &item : kwargs) {
+                        const std::string key = py::cast<std::string>(item.first);
+                        if (allowed_kwargs.find(key) == allowed_kwargs.end()) {
+                            throw py::value_error("InferEngine.Input got an unexpected keyword argument '" + key + "'");
+                        }
+                        if (key == "temperature") {
+                            input.temperature = py::cast<float>(item.second);
+                        } else if (key == "top_p") {
+                            input.top_p = py::cast<float>(item.second);
+                        } else if (key == "top_k") {
+                            input.top_k = py::cast<int>(item.second);
+                        }
+                    }
+                }

@@ -109,16 +129,21 @@ inline void bind_infer_engine(py::module &m) {
             }),
             py::arg("input_ids") = std::nullopt,
             py::arg("position_ids") = std::nullopt,
-            py::arg("cache_lengths") = std::nullopt,
+            py::arg("past_sequence_lengths") = std::nullopt,
+            py::arg("total_sequence_lengths") = std::nullopt,
             py::arg("input_offsets") = std::nullopt,
             py::arg("block_tables") = std::nullopt,
             py::arg("slot_mapping") = std::nullopt)
         .def_readwrite("input_ids", &InferEngine::Input::input_ids)
         .def_readwrite("position_ids", &InferEngine::Input::position_ids)
-        .def_readwrite("cache_lengths", &InferEngine::Input::cache_lengths)
+        .def_readwrite("past_sequence_lengths", &InferEngine::Input::past_sequence_lengths)
+        .def_readwrite("total_sequence_lengths", &InferEngine::Input::total_sequence_lengths)
         .def_readwrite("input_offsets", &InferEngine::Input::input_offsets)
         .def_readwrite("block_tables", &InferEngine::Input::block_tables)
-        .def_readwrite("slot_mapping", &InferEngine::Input::slot_mapping);
+        .def_readwrite("slot_mapping", &InferEngine::Input::slot_mapping)
+        .def_readwrite("temperature", &InferEngine::Input::temperature)
+        .def_readwrite("top_k", &InferEngine::Input::top_k)
+        .def_readwrite("top_p", &InferEngine::Input::top_p);

     py::class_<InferEngine::Output>(infer_engine, "Output")
         .def_readwrite("output_ids", &InferEngine::Output::output_ids, "Output tensor");
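The rewritten binding whitelists the sampling keyword arguments (temperature, top_p, top_k) and raises for anything else instead of silently ignoring it. The same check, mirrored in plain Python for illustration only (this is not the binding itself, and apply_sampling_kwargs is a hypothetical helper name):

# Plain-Python mirror of the kwargs validation in bind_infer_engine (illustration only).
ALLOWED_KWARGS = {"temperature", "top_p", "top_k"}

def apply_sampling_kwargs(input_obj, **kwargs):
    for key, value in kwargs.items():
        if key not in ALLOWED_KWARGS:
            raise ValueError(
                f"InferEngine.Input got an unexpected keyword argument '{key}'"
            )
        setattr(input_obj, key, value)  # temperature/top_p as float, top_k as int
    return input_obj

With the new def_readwrite entries, the same three sampling fields are also exposed as readable and writable attributes on an existing Input object.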
python/infinilm/infer_engine.py

@@ -53,7 +53,8 @@ class InferEngine(_infinilm.InferEngine):
         input_ids,
         *,
         position_ids=None,
-        cache_lengths=None,
+        past_kv_lengths=None,
+        total_kv_lengths=None,
         input_offsets=None,
         block_tables=None,
         slot_mapping=None,

@@ -64,7 +65,12 @@ class InferEngine(_infinilm.InferEngine):
         # TODO: Remove `_underlying` and simplify the corresponding code.
         input_ids = input_ids._underlying if input_ids is not None else None
         position_ids = position_ids._underlying if position_ids is not None else None
-        cache_lengths = cache_lengths._underlying if cache_lengths is not None else None
+        past_kv_lengths = (
+            past_kv_lengths._underlying if past_kv_lengths is not None else None
+        )
+        total_kv_lengths = (
+            total_kv_lengths._underlying if past_kv_lengths is not None else None
+        )
         input_offsets = input_offsets._underlying if input_offsets is not None else None
         block_tables = block_tables._underlying if block_tables is not None else None
         slot_mapping = slot_mapping._underlying if slot_mapping is not None else None

@@ -75,7 +81,8 @@ class InferEngine(_infinilm.InferEngine):
         super().Input(
             input_ids,
             position_ids=position_ids,
-            cache_lengths=cache_lengths,
+            past_sequence_lengths=past_kv_lengths,
+            total_sequence_lengths=total_kv_lengths,
             input_offsets=input_offsets,
             block_tables=block_tables,
             slot_mapping=slot_mapping,

@@ -87,7 +94,14 @@ class InferEngine(_infinilm.InferEngine):
             .output_ids
         )

-    def generate(self, input_ids, generation_config, *, _measure_and_log_time=False):
+    def generate(
+        self,
+        input_ids,
+        generation_config,
+        *,
+        _measure_and_log_time=False,
+        paged_block_size=16,
+    ):
         if generation_config.eos_token_id is None:
             eos_token_id = self.config.eos_token_id
         else:

@@ -119,31 +133,30 @@ class InferEngine(_infinilm.InferEngine):
                 list(range(past_seq_len, past_seq_len + seq_len)) * batch_size,
                 dtype=infinicore.int64,
             )
-            cache_lengths = infinicore.from_list(
-                [past_seq_len] * batch_size, dtype=infinicore.int64
-            )
+            block_tables_list = [
+                [
+                    i * batch_size + b
+                    for i in range(
+                        (past_seq_len + seq_len + paged_block_size - 1) // paged_block_size
+                    )
+                ]
+                for b in range(batch_size)
+            ]
+            slot_mapping_list = [
+                (((past_seq_len + i) // paged_block_size) * batch_size + b) * paged_block_size
+                + (past_seq_len + i) % paged_block_size
+                for b in range(batch_size)
+                for i in range(seq_len)
+            ]
+            input_offsets = infinicore.from_list(
+                [seq_len * i for i in range(batch_size + 1)], dtype=infinicore.int64
+            )
             block_tables = infinicore.from_list(
-                [
-                    [
-                        i * batch_size + b
-                        for i in range((past_seq_len + seq_len + 15) // 16)
-                    ]
-                    for b in range(batch_size)
-                ],
+                block_tables_list,
                 dtype=infinicore.int64,
             )
             slot_mapping = infinicore.from_list(
-                [
-                    ((past_seq_len + i + 15) // 16) * batch_size
-                    + b
-                    + (past_seq_len + i + 15) % 16
-                    for i in range(seq_len)
-                    for b in range(batch_size)
-                ],
+                slot_mapping_list,
                 dtype=infinicore.int64,
             )
         else:

@@ -155,21 +168,25 @@ class InferEngine(_infinilm.InferEngine):
                 dtype=infinicore.int64,
             )
-            cache_lengths = infinicore.from_list([past_seq_len], dtype=infinicore.int64)
-            input_offsets = infinicore.from_list(
-                [seq_len * i for i in range(batch_size + 1)], dtype=infinicore.int64
-            )
             block_tables = None
             slot_mapping = None
+        past_kv_lengths = infinicore.from_list(
+            [past_seq_len] * batch_size, dtype=infinicore.int64
+        )
+        total_kv_lengths = infinicore.from_list(
+            [past_seq_len + seq_len] * batch_size, dtype=infinicore.int64
+        )
+        input_offsets = infinicore.from_list(
+            [seq_len * i for i in range(batch_size + 1)], dtype=infinicore.int64
+        )

         output_id = self(
             input_ids=input_ids,
             position_ids=position_ids,
-            cache_lengths=cache_lengths,
+            past_kv_lengths=past_kv_lengths,
+            total_kv_lengths=total_kv_lengths,
             input_offsets=input_offsets,
             block_tables=block_tables,
             slot_mapping=slot_mapping,
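The paged-path setup in generate() now derives the block table and slot mapping from the new paged_block_size argument (default 16) instead of the previously hard-coded block size. A small worked example of the layout these comprehensions produce, using illustrative numbers only (paged_block_size reduced to 4 for readability):

import math

# Illustrative only: mirrors the block/slot arithmetic added in generate().
paged_block_size = 4   # generate() defaults to 16; 4 keeps the example small
batch_size = 2
past_seq_len = 0
seq_len = 6            # tokens submitted per request in this prefill step

num_blocks = math.ceil((past_seq_len + seq_len) / paged_block_size)
block_tables = [
    [i * batch_size + b for i in range(num_blocks)] for b in range(batch_size)
]
# -> [[0, 2], [1, 3]]: request b owns every batch_size-th block of the pool.

slot_mapping = [
    (((past_seq_len + i) // paged_block_size) * batch_size + b) * paged_block_size
    + (past_seq_len + i) % paged_block_size
    for b in range(batch_size)
    for i in range(seq_len)
]
# -> [0, 1, 2, 3, 8, 9, 4, 5, 6, 7, 12, 13]: token i of request b lands in
#    slot (physical_block * paged_block_size + in-block offset).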