Project: jerrrrry/infinilm

Commit 21274f33

issue/143 feat: static and paged graph compilers

Authored Jan 30, 2026 by PanZezhong; committed Feb 10, 2026 by wooway777.
Parent: 96ecf490
Showing 19 changed files with 376 additions and 22 deletions (+376, -22).
.gitignore                                          +2   -0
csrc/engine/compiler/general_compiler.cpp           +26  -0
csrc/engine/compiler/general_compiler.hpp           +19  -0
csrc/engine/compiler/graph_compiler.hpp             +23  -0
csrc/engine/compiler/paged_compiler.cpp             +83  -0
csrc/engine/compiler/paged_compiler.hpp             +31  -0
csrc/engine/compiler/static_batching_compiler.cpp   +51  -0
csrc/engine/compiler/static_batching_compiler.hpp   +36  -0
csrc/engine/infer_engine.cpp                        +4   -2
csrc/engine/infer_engine.hpp                        +2   -1
csrc/engine/rank_worker.cpp                         +26  -4
csrc/engine/rank_worker.hpp                         +7   -1
csrc/models/infinilm_model.hpp                      +1   -0
csrc/models/llama/llama_for_causal_lm.cpp           +6   -1
csrc/models/llama/llama_for_causal_lm.hpp           +4   -0
csrc/pybind11/engine/engine.hpp                     +6   -3
examples/bench.py                                   +37  -9
examples/jiuge.py                                   +10  -1
python/infinilm/infer_engine.py                     +2   -0
.gitignore

```diff
@@ -29,3 +29,5 @@ __pycache__/
 *.txt
+*.http
+*.nsys-rep
```
csrc/engine/compiler/general_compiler.cpp (new file, mode 100644)

```cpp
#include "general_compiler.hpp"

namespace infinilm::engine {

GeneralCompiler::GeneralCompiler(const std::shared_ptr<InfinilmModel> &model)
    : GraphCompiler(model) {
    static_batching_compiler_ = std::make_unique<StaticBatchingCompiler>(model_);
    paged_compiler_ = std::make_unique<PagedCompiler>(model_);
}

void GeneralCompiler::compile() {
    static_batching_compiler_->compile();
    paged_compiler_->compile();
}

GeneralCompiler::Compiled GeneralCompiler::get_compiled(const InfinilmModel::Input &input) {
    GeneralCompiler::Compiled result = {nullptr, nullptr};
    // try each compiler, return the first valid result
    result = static_batching_compiler_.get()->get_compiled(input);
    if (std::get<0>(result) != nullptr && std::get<1>(result) != nullptr) {
        return result;
    }
    result = paged_compiler_.get()->get_compiled(input);
    return result;
}

} // namespace infinilm::engine
```
csrc/engine/compiler/general_compiler.hpp (new file, mode 100644)

```cpp
#pragma once

#include "paged_compiler.hpp"
#include "static_batching_compiler.hpp"

namespace infinilm::engine {

class GeneralCompiler : public GraphCompiler {
public:
    GeneralCompiler(const std::shared_ptr<InfinilmModel> &model);

    void compile() override;

    Compiled get_compiled(const InfinilmModel::Input &input) override;

private:
    std::unique_ptr<StaticBatchingCompiler> static_batching_compiler_;
    std::unique_ptr<PagedCompiler> paged_compiler_;
};

} // namespace infinilm::engine
```
csrc/engine/compiler/graph_compiler.hpp (new file, mode 100644)

```cpp
#pragma once

#include "../../models/infinilm_model.hpp"

namespace infinilm::engine {

class GraphCompiler {
public:
    using Compiled = std::tuple<std::shared_ptr<infinicore::graph::Graph>,
                                std::shared_ptr<InfinilmModel::Output>>;

    explicit GraphCompiler(const std::shared_ptr<InfinilmModel> &model)
        : model_(model) {}

    virtual ~GraphCompiler() = default;

    virtual void compile() = 0;

    virtual Compiled get_compiled(const InfinilmModel::Input &input) = 0;

protected:
    std::shared_ptr<InfinilmModel> model_;
};

} // namespace infinilm::engine
```
csrc/engine/compiler/paged_compiler.cpp (new file, mode 100644)

```cpp
#include "paged_compiler.hpp"

namespace infinilm::engine {

PagedCompiler::PagedCompiler(const std::shared_ptr<InfinilmModel> &model)
    : GraphCompiler(model) {
    for (size_t b = 1; b < 32; b++) {
        decode_batch_sizes_.push_back(b);
    }
    for (size_t b = 32; b < 64; b += 8) {
        decode_batch_sizes_.push_back(b);
    }
    for (size_t b = 64; b < 128; b += 16) {
        decode_batch_sizes_.push_back(b);
    }
    for (size_t b = 128; b < 256; b += 32) {
        decode_batch_sizes_.push_back(b);
    }
    for (size_t b = 256; b <= 512; b += 64) {
        decode_batch_sizes_.push_back(b);
    }
}

void PagedCompiler::compile() {
    if (model_->get_cache_config() != nullptr
        && dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config())) {
        size_t nblocks = dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config())->num_blocks();
        size_t max_batch_size = *std::max_element(decode_batch_sizes_.begin(), decode_batch_sizes_.end());
        compiled_map_decode_.clear();
        block_tables_holder_ = infinicore::Tensor::empty({nblocks}, infinicore::DataType::I64, infinicore::context::getDevice());
        for (size_t b : decode_batch_sizes_) {
            size_t block_per_req = nblocks / b;
            InfinilmModel::Input input;
            input.input_ids = infinicore::Tensor::empty({1, b}, infinicore::DataType::I64, infinicore::context::getDevice());
            input.position_ids = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
            input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
            input.input_offsets = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I64, infinicore::context::getDevice());
            input.block_tables = block_tables_holder_->as_strided({b, block_per_req}, {(ptrdiff_t)block_per_req, 1});
            input.slot_mapping = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
            infinicore::context::startGraphRecording();
            auto output = model_->forward(input);
            auto graph = infinicore::context::stopGraphRecording();
            auto shared_output = std::shared_ptr<InfinilmModel::Output>(
                new InfinilmModel::Output{infinicore::graph::GraphTensor(output.logits)});
            compiled_map_decode_[b] = CompiledResult{std::move(input), std::make_tuple(graph, shared_output)};
        }
    }
}

PagedCompiler::Compiled PagedCompiler::get_compiled(const InfinilmModel::Input &input) {
    if (model_->get_cache_config() != nullptr
        && dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config())) {
        size_t batch_size = input.block_tables.value()->size(0);
        size_t block_per_req = input.block_tables.value()->size(1);
        // only support decode only batch
        if (batch_size != input.input_ids.value()->size(1)) {
            return {nullptr, nullptr};
        } else {
            auto result = compiled_map_decode_.find(batch_size);
            if (result == compiled_map_decode_.end()) {
                return {nullptr, nullptr};
            }
            auto &graph_input = result->second.input;
            graph_input.input_ids.value()->copy_from(input.input_ids.value());
            graph_input.position_ids.value()->copy_from(input.position_ids.value());
            graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value());
            graph_input.input_offsets.value()->copy_from(input.input_offsets.value());
            graph_input.block_tables.value()->narrow({{1, 0, block_per_req}})->copy_from(input.block_tables.value());
            graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value());
            auto graph = std::get<0>(result->second.compiled);
            auto shared_output = std::shared_ptr<InfinilmModel::Output>(
                new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()});
            return std::make_tuple(graph, shared_output);
        }
    } else {
        return {nullptr, nullptr};
    }
}

} // namespace infinilm::engine
```
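For reference, the decode batch sizes that the constructor above pre-records can be enumerated with a short sketch (Python is used here purely for illustration; the names are not part of the source). A decode batch whose size is not exactly one of these entries gets `{nullptr, nullptr}` back from `get_compiled` and the caller falls back to eager execution.

```python
# Decode batch sizes pre-compiled by PagedCompiler (mirrors the constructor loops above).
decode_batch_sizes = (
    list(range(1, 32))           # 1..31, every size
    + list(range(32, 64, 8))     # 32, 40, 48, 56
    + list(range(64, 128, 16))   # 64, 80, 96, 112
    + list(range(128, 256, 32))  # 128, 160, 192, 224
    + list(range(256, 513, 64))  # 256, 320, 384, 448, 512
)
print(len(decode_batch_sizes))   # 48 pre-recorded graph variants
```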
csrc/engine/compiler/paged_compiler.hpp (new file, mode 100644)

```cpp
#pragma once

#include "graph_compiler.hpp"

#include <unordered_map>

namespace infinilm::engine {

class PagedCompiler : public GraphCompiler {
public:
    PagedCompiler(const std::shared_ptr<InfinilmModel> &model);

    void compile() override;

    Compiled get_compiled(const InfinilmModel::Input &input) override;

private:
    std::vector<size_t> decode_batch_sizes_;
    infinicore::Tensor block_tables_holder_;

    struct CompiledResult {
        InfinilmModel::Input input;
        Compiled compiled;
    };

    std::unordered_map<size_t, // num_requests
                       CompiledResult>
        compiled_map_decode_;
};

} // namespace infinilm::engine
```
csrc/engine/compiler/static_batching_compiler.cpp (new file, mode 100644)

```cpp
#include "static_batching_compiler.hpp"

#include "../../cache/cache.hpp"

namespace infinilm::engine {

StaticBatchingCompiler::StaticBatchingCompiler(const std::shared_ptr<InfinilmModel> &model)
    : GraphCompiler(model) {
}

void StaticBatchingCompiler::compile() {
    if (model_->get_cache_config() != nullptr
        && dynamic_cast<const cache::StaticKVCacheConfig *>(model_->get_cache_config())) {
        size_t b = dynamic_cast<const cache::StaticKVCacheConfig *>(model_->get_cache_config())->max_batch_size();
        InfinilmModel::Input input;
        input.input_ids = infinicore::Tensor::empty({b, 1}, infinicore::DataType::I64, infinicore::context::getDevice());
        input.position_ids = infinicore::Tensor::empty({b, 1}, infinicore::DataType::I64, infinicore::context::getDevice());
        input.past_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
        input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
        infinicore::context::startGraphRecording();
        auto output = model_->forward(input);
        auto graph = infinicore::context::stopGraphRecording();
        auto shared_output = std::shared_ptr<InfinilmModel::Output>(
            new InfinilmModel::Output{infinicore::graph::GraphTensor(output.logits)});
        compiled_map_[std::make_tuple(b, 1)] = CompiledResult{std::move(input), std::make_tuple(graph, shared_output)};
    }
}

StaticBatchingCompiler::Compiled StaticBatchingCompiler::get_compiled(const InfinilmModel::Input &input) {
    if (model_->get_cache_config() != nullptr
        && dynamic_cast<const cache::StaticKVCacheConfig *>(model_->get_cache_config())) {
        size_t batch_size = input.input_ids.value()->size(0);
        size_t seqlen = input.input_ids.value()->size(1);
        auto result = compiled_map_.find(std::make_tuple(batch_size, seqlen));
        if (result == compiled_map_.end()) {
            return std::make_tuple(nullptr, nullptr);
        } else {
            auto &graph_input = result->second.input;
            graph_input.input_ids.value()->copy_from(input.input_ids.value());
            graph_input.position_ids.value()->copy_from(input.position_ids.value());
            graph_input.past_sequence_lengths.value()->copy_from(input.past_sequence_lengths.value());
            graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value());
            auto graph = std::get<0>(result->second.compiled);
            auto shared_output = std::shared_ptr<InfinilmModel::Output>(
                new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()});
            return std::make_tuple(graph, shared_output);
        }
    } else {
        return std::make_tuple(nullptr, nullptr);
    }
}

} // namespace infinilm::engine
```
csrc/engine/compiler/static_batching_compiler.hpp (new file, mode 100644)

```cpp
#pragma once

#include "graph_compiler.hpp"

#include <unordered_map>

namespace infinilm::engine {

class StaticBatchingCompiler : public GraphCompiler {
public:
    StaticBatchingCompiler(const std::shared_ptr<InfinilmModel> &model);

    void compile() override;

    Compiled get_compiled(const InfinilmModel::Input &input) override;

private:
    struct TupleHash {
        size_t operator()(const std::tuple<size_t, size_t> &t) const noexcept {
            auto h1 = std::hash<size_t>{}(std::get<0>(t));
            auto h2 = std::hash<size_t>{}(std::get<1>(t));
            return h1 ^ (h2 + 0x9e3779b97f4a7c15ULL + (h1 << 6) + (h1 >> 2));
        }
    };

    struct CompiledResult {
        InfinilmModel::Input input;
        Compiled compiled;
    };

    std::unordered_map<std::tuple<size_t, size_t>, // (batch_size, seq_len)
                       CompiledResult,
                       TupleHash>
        compiled_map_;
};

} // namespace infinilm::engine
```
csrc/engine/infer_engine.cpp

```diff
@@ -10,7 +10,8 @@ InferEngine::InferEngine(
     const InfinilmModel::Config &config,
     const distributed::DistConfig &distributed_config,
     infinicore::Device::Type device_type,
-    const cache::CacheConfig *cache_config) // Changed parameter
+    const cache::CacheConfig *cache_config,
+    bool enable_graph_compiling) // Changed parameter
     : communication_group_(distributed_config, device_type),
       model_config_(config) {
@@ -24,7 +25,8 @@ InferEngine::InferEngine(
         workers_.emplace_back(std::make_unique<RankWorker>(
             model_config_,
             communication_group_.get_rank_info(r),
-            cache_config_ != nullptr ? cache_config_.get() : nullptr));
+            cache_config_ != nullptr ? cache_config_.get() : nullptr,
+            enable_graph_compiling));
     }
 }
```
csrc/engine/infer_engine.hpp

```diff
@@ -22,7 +22,8 @@ public:
         const InfinilmModel::Config &config,
         const distributed::DistConfig &distributed_config = distributed::DistConfig(),
         infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
-        const cache::CacheConfig *cache_config = nullptr);
+        const cache::CacheConfig *cache_config = nullptr,
+        bool enable_graph_compiling = false);

     // Load a parameter to all workers (each can extract its shard inside RankWorker)
     void load_param(const std::string &name, const infinicore::Tensor &param);
```
csrc/engine/rank_worker.cpp

```diff
@@ -12,9 +12,11 @@ namespace infinilm::engine {
 RankWorker::RankWorker(const InfinilmModel::Config &model_config,
                        const distributed::RankInfo &rank_info,
-                       const cache::CacheConfig *cache_config)
+                       const cache::CacheConfig *cache_config,
+                       bool enable_graph_compiling)
     : model_config_(model_config),
       rank_info_(rank_info),
+      enable_graph_compiling_(enable_graph_compiling),
       job_cmd_(Command::INIT),
       has_job_(false),
       job_done_(false),
@@ -180,6 +182,11 @@ void RankWorker::thread_loop() {
             if (!model_) {
                 throw std::runtime_error("Failed to create model");
             }
+
+            if (enable_graph_compiling_) {
+                compiler_ = std::make_unique<GeneralCompiler>(model_);
+                compiler_->compile();
+            }
             init_done_ = true;
         }
         cv_.notify_all();
@@ -245,9 +252,21 @@ void RankWorker::thread_loop() {
             {
                 std::lock_guard<std::mutex> lk(mutex_);
-                auto model_args = local_args.to_model_input(rank_info_.device);
-                // Forward calculation
-                auto logits{model_->forward(model_args).logits};
+                infinicore::Tensor logits;
+                // Try to get compiled graph
+                if (compiler_ != nullptr) {
+                    auto [graph, output] = compiler_->get_compiled(local_args.to_model_input(infinicore::Device::cpu()));
+                    if (graph != nullptr && output != nullptr) {
+                        graph->run();
+                        logits = output->logits;
+                    }
+                }
+                // Fall back to eager mode
+                if (!logits) {
+                    auto model_args = local_args.to_model_input(rank_info_.device);
+                    logits = model_->forward(model_args).logits;
+                }
                 // Random sampling (rank 0 only)
                 if (rank_info_.tp_rank == 0) {
                     auto temperature{local_args.temperature};
@@ -296,6 +315,9 @@ void RankWorker::thread_loop() {
         } else if (local_cmd == Command::RESET_CACHE) {
             try {
                 model_->reset_cache(local_cache_config != nullptr ? local_cache_config.get() : nullptr);
+                if (compiler_ != nullptr) {
+                    compiler_->compile();
+                }
                 {
                     std::lock_guard<std::mutex> lk(mutex_);
```
csrc/engine/rank_worker.hpp

```diff
@@ -2,6 +2,7 @@
 #include "../cache/cache.hpp"
 #include "../models/model_factory.hpp"
+#include "compiler/general_compiler.hpp"
 #include "distributed/distributed.hpp"
 #include <any>
@@ -55,7 +56,8 @@ public:
     RankWorker(const InfinilmModel::Config &model_config,
                const distributed::RankInfo &rank_info,
-               const cache::CacheConfig *cache_config);
+               const cache::CacheConfig *cache_config,
+               bool enable_graph_compiling);

     // Submit a parameter load job and wait until the load completes on the worker thread.
     void load_param(const std::string &name,
@@ -91,6 +93,10 @@ private:
     std::shared_ptr<InfinilmModel> model_;
     std::shared_ptr<cache::Cache> cache_;
+
+    // Graph Compiling
+    bool enable_graph_compiling_;
+    std::unique_ptr<GraphCompiler> compiler_;

     // Command for the pending job (protected by mutex_)
     Command job_cmd_;
```
csrc/models/infinilm_model.hpp

```diff
@@ -43,5 +43,6 @@ public:
     virtual Output forward(const Input &input) const = 0;

     virtual void reset_cache(const cache::CacheConfig *cache_config) = 0;
+    virtual const cache::CacheConfig *get_cache_config() const = 0;
 };
 } // namespace infinilm
```
csrc/models/llama/llama_for_causal_lm.cpp

```diff
@@ -45,7 +45,12 @@ LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
 }

 void LlamaForCausalLM::reset_cache(const cache::CacheConfig *cache_config) {
-    model_->reset_cache(cache_config);
+    cache_config_ = cache_config->unique_copy();
+    model_->reset_cache(cache_config_.get());
+}
+
+const cache::CacheConfig *LlamaForCausalLM::get_cache_config() const {
+    return cache_config_.get();
 }

 } // namespace infinilm::models::llama
```
csrc/models/llama/llama_for_causal_lm.hpp

```diff
@@ -42,6 +42,8 @@ public:
     void reset_cache(const cache::CacheConfig *cache_config) override;

+    const cache::CacheConfig *get_cache_config() const override;
+
     // Module information
     const LlamaConfig &config() const { return model_->config(); }
     LlamaModel &model() { return *model_; }
@@ -53,6 +55,8 @@ protected:
     // Language modeling head
     INFINICORE_NN_MODULE(infinicore::nn::Linear, lm_head);
+
+    std::unique_ptr<cache::CacheConfig> cache_config_;
 };

 } // namespace infinilm::models::llama
```
csrc/pybind11/engine/engine.hpp

```diff
@@ -35,17 +35,20 @@ inline void bind_infer_engine(py::module &m) {
                 const InfinilmModel::Config &cfg,
                 const distributed::DistConfig &dist,
                 infinicore::Device::Type dev,
-                std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg) {
+                std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg,
+                bool enable_graph_compiling) {
                 return std::make_shared<InferEngine>(
                     cfg, dist, dev,
-                    cache_cfg ? cache_cfg.get() : nullptr);
+                    cache_cfg ? cache_cfg.get() : nullptr,
+                    enable_graph_compiling);
             }),
             py::arg("config"),
             py::arg("distributed_config") = distributed::DistConfig(),
             py::arg("device_type") = infinicore::context::getDevice().getType(),
-            py::arg("cache_config") = py::none())
+            py::arg("cache_config") = py::none(),
+            py::arg("enable_graph_compiling") = false)
         .def("load_param", &InferEngine::load_param,
              py::arg("name"), py::arg("param"),
              "Load a parameter tensor into all workers (each worker picks its shard)")
```
examples/bench.py

```diff
@@ -3,7 +3,7 @@ from transformers import AutoTokenizer
 from infinilm.modeling_utils import load_model_state_dict_by_file
 from infinilm.distributed import DistConfig
 from infinilm.infer_engine import GenerationConfig, InferEngine
-from infinilm.cache import StaticKVCacheConfig
+from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig
 import argparse
 import sys
 import time
@@ -199,7 +199,16 @@ def get_args():
         default=1.0,
         help="sampling temperature",
     )
+    parser.add_argument(
+        "--enable-paged-attn",
+        action="store_true",
+        help="use paged cache",
+    )
+    parser.add_argument(
+        "--enable-graph",
+        action="store_true",
+        help="enable graph compiling",
+    )
     return parser.parse_args()
@@ -223,6 +232,8 @@ class TestModel:
         infini_device=infinicore.device("cpu", 0),
         tp=1,
         skip_load=False,
+        cache_config=None,
+        enable_graph=False,
     ) -> None:
         model_path = os.path.expanduser(model_path)
         # ---------------------------------------------------------------------------- #
@@ -232,6 +243,8 @@ class TestModel:
             model_path,
             device=infini_device,
             distributed_config=DistConfig(tp),
+            cache_config=cache_config,
+            enable_graph_compiling=enable_graph,
         )
         # ---------------------------------------------------------------------------- #
@@ -336,6 +349,8 @@ if __name__ == "__main__":
     batch_size = args.batch_size
     input_len = args.input_len
     output_len = args.output_len
+    enable_paged_attn = args.enable_paged_attn
+    enable_graph = args.enable_graph

     if isinstance(batch_size, int):
         batch_size = [batch_size]
@@ -350,13 +365,25 @@ if __name__ == "__main__":
     # -------------------------------------------------------- #
     # 测试 (test)
     # -------------------------------------------------------- #
     # print("=================== start test ====================", type(batch_size))
+    if enable_paged_attn:
+        paged_kv_block_size = 16
+        max_num_blocks = max(
+            [
+                ((c_["input_len"] + c_["output_len"] + 15) // 16) * c_["batch_size"]
+                for _, c_ in cases_dict.items()
+            ]
+        )
+        cache_config = PagedKVCacheConfig(max_num_blocks, paged_kv_block_size)
+    else:
+        cache_config = None
     test = TestModel(
         model_path,
         infini_device=infini_device,
         tp=tp,
         skip_load=skip_load,
+        cache_config=cache_config,
+        enable_graph=enable_graph,
     )
     for idx, case in tqdm(cases_dict.items(), desc="Processing cases"):
@@ -366,13 +393,14 @@ if __name__ == "__main__":
         input_len = case["input_len"]
         output_len = case["output_len"]
-        # reset cache for each case
-        initial_capacity = input_len + output_len
-        test.model.reset_cache(
-            StaticKVCacheConfig(
-                max_batch_size=batch_size, max_cache_len=initial_capacity
-            )
-        )
+        if not enable_paged_attn:
+            # reset cache if static kvcache is used
+            initial_capacity = input_len + output_len
+            test.model.reset_cache(
+                StaticKVCacheConfig(
+                    max_batch_size=batch_size, max_cache_len=initial_capacity
+                )
+            )
         # run test one case
         test.run(
```
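Usage note: with the two new flags wired through `get_args` and `TestModel`, the benchmark can exercise the compilers end to end, e.g. `python examples/bench.py --enable-paged-attn --enable-graph` with the remaining arguments unchanged. Without `--enable-paged-attn` the static KV cache path is used instead, and the cache (together with its compiled graph) is rebuilt per benchmark case via `reset_cache`.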
examples/jiuge.py

```diff
@@ -93,6 +93,11 @@ def get_args():
         action="store_true",
         help="use paged cache",
     )
+    parser.add_argument(
+        "--enable-graph",
+        action="store_true",
+        help="enable graph compiling",
+    )
     parser.add_argument(
         "--top-k",
@@ -125,6 +130,7 @@ def test(
     infini_device=infinicore.device("cpu", 0),
     tp=1,
     enable_paged_attn=False,
+    enable_graph=False,
     top_k=1,
     top_p=1.0,
     temperature=1.0,
@@ -137,6 +143,7 @@ def test(
         model_path,
         device=infini_device,
         distributed_config=DistConfig(tp),
+        enable_graph_compiling=enable_graph,
     )
     # ---------------------------------------------------------------------------- #
@@ -193,7 +200,7 @@ def test(
         batch_size = 1 if prompts is str else len(prompts)
         max_total_tokens = max_new_tokens + len(input_ids_list[0])
         cache_config = PagedKVCacheConfig(
-            num_blocks=(max_total_tokens // 16 + 1) * batch_size, block_size=16
+            num_blocks=((max_total_tokens + 15) // 16) * batch_size, block_size=16
         )
     else:
         batch_size = 1 if prompts is str else len(prompts)
@@ -265,6 +272,7 @@ if __name__ == "__main__":
     backend = args.backend
     tp = args.tp
     enable_paged_attn = args.enable_paged_attn
+    enable_graph = args.enable_graph

     if backend != "cpp":
         raise ValueError(f"Unsupported backend: {backend}.")
@@ -277,6 +285,7 @@ if __name__ == "__main__":
         infini_device=infini_device,
         tp=tp,
         enable_paged_attn=enable_paged_attn,
+        enable_graph=enable_graph,
         top_k=args.top_k,
         top_p=args.top_p,
         temperature=args.temperature,
```
python/infinilm/infer_engine.py

```diff
@@ -28,6 +28,7 @@ class InferEngine(_infinilm.InferEngine):
         device=None,
         distributed_config=DistConfig(1),
         cache_config=None,
+        enable_graph_compiling=False,
     ):
         self.config = AutoConfig.from_pretrained(model_path)
@@ -39,6 +40,7 @@ class InferEngine(_infinilm.InferEngine):
             distributed_config._underlying,
             device._underlying.type,
             cache_config,
+            enable_graph_compiling,
         )
         self.use_cache = False
```
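On the Python side the new keyword is exposed directly on the public `InferEngine` wrapper. A minimal sketch of turning it on, with a placeholder model path and placeholder cache sizing (everything else keeps its default):

```python
from infinilm.infer_engine import InferEngine
from infinilm.cache import PagedKVCacheConfig

# Hypothetical model path; enable_graph_compiling is the keyword added in this commit.
engine = InferEngine(
    "/path/to/model",
    cache_config=PagedKVCacheConfig(1024, 16),  # placeholder num_blocks / block_size
    enable_graph_compiling=True,                # record decode graphs when workers initialize
)
```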