Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
2c925eb4
Commit
2c925eb4
authored
Jan 30, 2026
by
PanZezhong
Committed by
wooway777
Feb 10, 2026
Browse files
issue/143 add barrier for compilers
parent
429f54cd
Changes
13
Show whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
123 additions
and
18 deletions
+123
-18
csrc/engine/compiler/general_compiler.cpp
csrc/engine/compiler/general_compiler.cpp
+3
-3
csrc/engine/compiler/general_compiler.hpp
csrc/engine/compiler/general_compiler.hpp
+1
-1
csrc/engine/compiler/graph_compiler.hpp
csrc/engine/compiler/graph_compiler.hpp
+3
-1
csrc/engine/compiler/paged_compiler.cpp
csrc/engine/compiler/paged_compiler.cpp
+5
-2
csrc/engine/compiler/paged_compiler.hpp
csrc/engine/compiler/paged_compiler.hpp
+1
-1
csrc/engine/compiler/static_batching_compiler.cpp
csrc/engine/compiler/static_batching_compiler.cpp
+5
-2
csrc/engine/compiler/static_batching_compiler.hpp
csrc/engine/compiler/static_batching_compiler.hpp
+1
-1
csrc/engine/infer_engine.cpp
csrc/engine/infer_engine.cpp
+16
-2
csrc/engine/infer_engine.hpp
csrc/engine/infer_engine.hpp
+4
-0
csrc/engine/rank_barrier.cpp
csrc/engine/rank_barrier.cpp
+19
-0
csrc/engine/rank_barrier.hpp
csrc/engine/rank_barrier.hpp
+20
-0
csrc/engine/rank_worker.cpp
csrc/engine/rank_worker.cpp
+37
-5
csrc/engine/rank_worker.hpp
csrc/engine/rank_worker.hpp
+8
-0
No files found.
csrc/engine/compiler/general_compiler.cpp
View file @
2c925eb4
#include "general_compiler.hpp"

namespace infinilm::engine {

// Build the general compiler together with its two specialized
// sub-compilers. The rank barrier (non-owning) is forwarded to both
// so that graph recording can be synchronized across rank threads.
GeneralCompiler::GeneralCompiler(const std::shared_ptr<InfinilmModel> &model,
                                 RankBarrier *barrier)
    : GraphCompiler(model, barrier) {
    static_batching_compiler_ =
        std::make_unique<StaticBatchingCompiler>(model_, barrier);
    paged_compiler_ =
        std::make_unique<PagedCompiler>(model_, barrier);
}
void
GeneralCompiler
::
compile
()
{
...
...
csrc/engine/compiler/general_compiler.hpp
View file @
2c925eb4
...
...
@@ -6,7 +6,7 @@
namespace
infinilm
::
engine
{
class
GeneralCompiler
:
public
GraphCompiler
{
public:
GeneralCompiler
(
const
std
::
shared_ptr
<
InfinilmModel
>
&
model
);
GeneralCompiler
(
const
std
::
shared_ptr
<
InfinilmModel
>
&
model
,
RankBarrier
*
barrier
);
void
compile
()
override
;
...
...
csrc/engine/compiler/graph_compiler.hpp
View file @
2c925eb4
#pragma once
#include "../../models/infinilm_model.hpp"
#include "../rank_barrier.hpp"
namespace
infinilm
::
engine
{
...
...
@@ -10,7 +11,7 @@ public:
std
::
shared_ptr
<
infinicore
::
graph
::
Graph
>
,
std
::
shared_ptr
<
InfinilmModel
::
Output
>>
;
// Wraps the model whose forward pass will be recorded as a graph.
// `barrier` is a non-owning pointer used to rendezvous rank threads
// around the recording phase.
explicit GraphCompiler(const std::shared_ptr<InfinilmModel> &model,
                       RankBarrier *barrier)
    : model_(model), barrier_(barrier) {}

virtual ~GraphCompiler() = default;

// Record and store the execution graph(s); implemented per strategy.
virtual void compile() = 0;
...
...
@@ -18,6 +19,7 @@ public:
protected:
std
::
shared_ptr
<
InfinilmModel
>
model_
;
RankBarrier
*
barrier_
;
};
}
// namespace infinilm::engine
csrc/engine/compiler/paged_compiler.cpp
View file @
2c925eb4
#include "paged_compiler.hpp"
namespace
infinilm
::
engine
{
PagedCompiler
::
PagedCompiler
(
const
std
::
shared_ptr
<
InfinilmModel
>
&
model
)
:
GraphCompiler
(
model
)
{
PagedCompiler
::
PagedCompiler
(
const
std
::
shared_ptr
<
InfinilmModel
>
&
model
,
RankBarrier
*
barrier
)
:
GraphCompiler
(
model
,
barrier
)
{
for
(
size_t
b
=
1
;
b
<
32
;
b
++
)
{
decode_batch_sizes_
.
push_back
(
b
);
}
...
...
@@ -43,9 +43,12 @@ void PagedCompiler::compile() {
infinicore
::
context
::
memcpyH2D
(
input
.
input_offsets
.
value
()
->
data
(),
input_offsets_vec
.
data
(),
(
b
+
1
)
*
sizeof
(
int64_t
),
false
);
input
.
block_tables
=
block_tables_holder_
->
as_strided
({
b
,
block_per_req
},
{(
ptrdiff_t
)
block_per_req
,
1
});
input
.
slot_mapping
=
infinicore
::
Tensor
::
empty
({
b
},
infinicore
::
DataType
::
I64
,
infinicore
::
context
::
getDevice
());
barrier_
->
wait
();
infinicore
::
context
::
startGraphRecording
();
auto
output
=
model_
->
forward
(
input
);
auto
graph
=
infinicore
::
context
::
stopGraphRecording
();
barrier_
->
wait
();
auto
shared_output
=
std
::
shared_ptr
<
InfinilmModel
::
Output
>
(
new
InfinilmModel
::
Output
{
infinicore
::
graph
::
GraphTensor
(
output
.
logits
)});
...
...
csrc/engine/compiler/paged_compiler.hpp
View file @
2c925eb4
...
...
@@ -7,7 +7,7 @@
namespace
infinilm
::
engine
{
class
PagedCompiler
:
public
GraphCompiler
{
public:
PagedCompiler
(
const
std
::
shared_ptr
<
InfinilmModel
>
&
model
);
PagedCompiler
(
const
std
::
shared_ptr
<
InfinilmModel
>
&
model
,
RankBarrier
*
barrier
);
void
compile
()
override
;
...
...
csrc/engine/compiler/static_batching_compiler.cpp
View file @
2c925eb4
...
...
@@ -3,8 +3,8 @@
#include "../../cache/cache.hpp"
namespace
infinilm
::
engine
{
// Forward the model and the (non-owning) rank barrier to the base class;
// no additional state is initialized here.
StaticBatchingCompiler::StaticBatchingCompiler(
    const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier)
    : GraphCompiler(model, barrier) {}
void
StaticBatchingCompiler
::
compile
()
{
...
...
@@ -17,9 +17,12 @@ void StaticBatchingCompiler::compile() {
input
.
total_sequence_lengths
=
infinicore
::
Tensor
::
empty
({
b
},
infinicore
::
DataType
::
I64
,
infinicore
::
context
::
getDevice
());
std
::
vector
<
int64_t
>
total_sequence_lengths_vec
(
b
,
1
);
infinicore
::
context
::
memcpyH2D
(
input
.
total_sequence_lengths
.
value
()
->
data
(),
total_sequence_lengths_vec
.
data
(),
b
*
sizeof
(
int64_t
),
false
);
barrier_
->
wait
();
infinicore
::
context
::
startGraphRecording
();
auto
output
=
model_
->
forward
(
input
);
auto
graph
=
infinicore
::
context
::
stopGraphRecording
();
barrier_
->
wait
();
auto
shared_output
=
std
::
shared_ptr
<
InfinilmModel
::
Output
>
(
new
InfinilmModel
::
Output
{
infinicore
::
graph
::
GraphTensor
(
output
.
logits
)});
...
...
csrc/engine/compiler/static_batching_compiler.hpp
View file @
2c925eb4
...
...
@@ -7,7 +7,7 @@
namespace
infinilm
::
engine
{
class
StaticBatchingCompiler
:
public
GraphCompiler
{
public:
StaticBatchingCompiler
(
const
std
::
shared_ptr
<
InfinilmModel
>
&
model
);
StaticBatchingCompiler
(
const
std
::
shared_ptr
<
InfinilmModel
>
&
model
,
RankBarrier
*
barrier
);
void
compile
()
override
;
...
...
csrc/engine/infer_engine.cpp
View file @
2c925eb4
...
...
@@ -20,12 +20,14 @@ InferEngine::InferEngine(
}
// Create one RankWorker per rank
int
world_size
=
communication_group_
.
get_world_size
();
barrier_
=
std
::
make_unique
<
RankBarrier
>
((
size_t
)
world_size
);
workers_
.
reserve
(
world_size
);
for
(
int
r
=
0
;
r
<
world_size
;
++
r
)
{
workers_
.
emplace_back
(
std
::
make_unique
<
RankWorker
>
(
model_config_
,
communication_group_
.
get_rank_info
(
r
),
cache_config_
!=
nullptr
?
cache_config_
.
get
()
:
nullptr
,
barrier_
.
get
(),
enable_graph_compiling
));
}
}
...
...
@@ -67,9 +69,9 @@ InferEngine::Input::to_model_input(infinicore::Device device) const {
};
return
{
input_ids
,
// @todo: on device in the future
to_device
(
input_ids
)
,
// @todo: on device in the future
to_device
(
position_ids
),
past_sequence_lengths
,
// @todo: on device in the future
to_device
(
past_sequence_lengths
)
,
// @todo: on device in the future
to_device
(
total_sequence_lengths
),
to_device
(
input_offsets
),
to_device
(
block_tables
),
...
...
@@ -90,6 +92,16 @@ InferEngine::Output InferEngine::forward(const InferEngine::Input &input) {
return
workers_
[
0
]
->
get_output
();
}
void
InferEngine
::
compile
()
{
for
(
auto
&
worker
:
workers_
)
{
worker
->
compile
();
}
// Wait for all workers
for
(
auto
&
worker
:
workers_
)
{
worker
->
wait
();
}
}
//------------------------------------------------------
// Destructor
//------------------------------------------------------
...
...
@@ -114,6 +126,8 @@ void InferEngine::reset_cache(const cache::CacheConfig *new_config) {
for
(
auto
&
worker
:
workers_
)
{
worker
->
wait
();
}
this
->
compile
();
}
}
// namespace infinilm::engine
csrc/engine/infer_engine.hpp
View file @
2c925eb4
...
...
@@ -4,6 +4,7 @@
#include "../models/llama/llama_config.hpp"
#include "distributed/distributed.hpp"
#include "infinicore/tensor.hpp"
#include "rank_barrier.hpp"
#include "rank_worker.hpp"
#include <optional>
...
...
@@ -34,6 +35,8 @@ public:
// Run a single forward pass on all workers and return the outputs from all ranks
Output
forward
(
const
Input
&
input
);
void
compile
();
void
reset_cache
(
const
cache
::
CacheConfig
*
new_config
);
~
InferEngine
();
...
...
@@ -45,6 +48,7 @@ public:
protected:
std
::
vector
<
std
::
unique_ptr
<
RankWorker
>>
workers_
;
std
::
unique_ptr
<
RankBarrier
>
barrier_
;
distributed
::
CommunicationGroup
communication_group_
;
const
InfinilmModel
::
Config
&
model_config_
;
std
::
unique_ptr
<
cache
::
CacheConfig
>
cache_config_
;
...
...
csrc/engine/rank_barrier.cpp
0 → 100644
View file @
2c925eb4
#include "rank_barrier.hpp"

namespace infinilm::engine {

/// @brief Create a reusable barrier for `num_ranks` participating threads.
/// Initializer order matches the declaration order in rank_barrier.hpp
/// (thread_count_, arrived_, generation_) — the original listed generation_
/// before arrived_, triggering -Wreorder.
RankBarrier::RankBarrier(size_t num_ranks)
    : thread_count_(num_ranks), arrived_(0), generation_(0) {}

/// @brief Block until all `thread_count_` threads have called wait().
///
/// Generation-counting barrier: the last thread to arrive increments the
/// generation, resets the arrival count, and wakes all waiters; earlier
/// arrivals sleep until they observe the generation change, which also
/// guards against spurious wakeups. The barrier is reusable across rounds.
void RankBarrier::wait() {
    std::unique_lock<std::mutex> lock(mutex_);
    // Snapshot the current generation. Use size_t to match generation_'s
    // type — the original stored it in an int, which narrows and causes a
    // signed/unsigned comparison in the predicate below.
    const size_t gen = generation_;
    if (++arrived_ == thread_count_) {
        // Last thread of this round: open the barrier for everyone.
        ++generation_;
        arrived_ = 0;
        cv_.notify_all();
    } else {
        cv_.wait(lock, [&] { return gen != generation_; });
    }
}

} // namespace infinilm::engine
csrc/engine/rank_barrier.hpp
0 → 100644
View file @
2c925eb4
#pragma once

#include <condition_variable>
#include <cstddef> // size_t — previously relied on a transitive include
#include <mutex>

namespace infinilm::engine {

/// A reusable (cyclic) barrier synchronizing a fixed number of rank
/// threads: each wait() blocks until all `nranks` threads have arrived,
/// then the barrier resets itself for the next round.
///
/// Thread-safety: wait() may be called concurrently from all ranks.
class RankBarrier {
public:
    /// @param nranks Number of threads that must arrive to release the barrier.
    explicit RankBarrier(size_t nranks);

    /// Block until all participating threads have arrived.
    void wait();

private:
    const size_t thread_count_; // number of participating threads (immutable)
    size_t arrived_;            // arrivals in the current round, guarded by mutex_
    size_t generation_;         // round counter; changes when the barrier opens
    std::mutex mutex_;
    std::condition_variable cv_;
};

} // namespace infinilm::engine
csrc/engine/rank_worker.cpp
View file @
2c925eb4
...
...
@@ -13,6 +13,7 @@ namespace infinilm::engine {
RankWorker
::
RankWorker
(
const
InfinilmModel
::
Config
&
model_config
,
const
distributed
::
RankInfo
&
rank_info
,
const
cache
::
CacheConfig
*
cache_config
,
RankBarrier
*
barrier
,
bool
enable_graph_compiling
)
:
model_config_
(
model_config
),
rank_info_
(
rank_info
),
...
...
@@ -22,7 +23,8 @@ RankWorker::RankWorker(const InfinilmModel::Config &model_config,
job_done_
(
false
),
should_exit_
(
false
),
init_done_
(
false
),
rng_
(
std
::
random_device
{}())
{
rng_
(
std
::
random_device
{}()),
barrier_
(
barrier
)
{
if
(
cache_config
!=
nullptr
)
{
pending_cache_config_
=
cache_config
->
unique_copy
();
}
...
...
@@ -115,6 +117,21 @@ void RankWorker::run(const Input &args) {
cv_
.
notify_all
();
}
//------------------------------------------------------
// compile -- asynchronous
//------------------------------------------------------
/// Submit a COMPILE job to the worker thread and return immediately;
/// pair with wait() to block until compilation finishes.
/// @throws std::runtime_error if the worker is shutting down.
void RankWorker::compile() {
    std::lock_guard<std::mutex> lock(mutex_);
    if (should_exit_) {
        // Message corrected: it previously read "cannot run" (copied from
        // run()); this path rejects a compile request.
        throw std::runtime_error("RankWorker is closing; cannot compile");
    }
    job_cmd_ = Command::COMPILE;
    has_job_ = true;
    job_done_ = false;
    cv_.notify_all();
}
//------------------------------------------------------
// wait -- asynchronous
//------------------------------------------------------
...
...
@@ -183,8 +200,7 @@ void RankWorker::thread_loop() {
throw
std
::
runtime_error
(
"Failed to create model"
);
}
if
(
enable_graph_compiling_
)
{
compiler_
=
std
::
make_unique
<
GeneralCompiler
>
(
model_
);
compiler_
->
compile
();
compiler_
=
std
::
make_unique
<
GeneralCompiler
>
(
model_
,
barrier_
);
}
init_done_
=
true
;
...
...
@@ -315,10 +331,25 @@ void RankWorker::thread_loop() {
}
else
if
(
local_cmd
==
Command
::
RESET_CACHE
)
{
try
{
model_
->
reset_cache
(
local_cache_config
!=
nullptr
?
local_cache_config
.
get
()
:
nullptr
);
{
std
::
lock_guard
<
std
::
mutex
>
lk
(
mutex_
);
job_done_
=
true
;
}
cv_
.
notify_all
();
}
catch
(
const
std
::
exception
&
e
)
{
std
::
lock_guard
<
std
::
mutex
>
lk
(
mutex_
);
should_exit_
=
true
;
job_done_
=
true
;
cv_
.
notify_all
();
spdlog
::
error
(
"[{}] exception during reset_cache: {}
\n
"
,
info
(),
e
.
what
());
break
;
}
}
else
if
(
local_cmd
==
Command
::
COMPILE
)
{
try
{
if
(
compiler_
!=
nullptr
)
{
compiler_
->
compile
();
}
{
std
::
lock_guard
<
std
::
mutex
>
lk
(
mutex_
);
job_done_
=
true
;
...
...
@@ -330,9 +361,10 @@ void RankWorker::thread_loop() {
should_exit_
=
true
;
job_done_
=
true
;
cv_
.
notify_all
();
spdlog
::
error
(
"[{}] exception during
reset_cach
e: {}
\n
"
,
info
(),
e
.
what
());
spdlog
::
error
(
"[{}] exception during
compil
e: {}
\n
"
,
info
(),
e
.
what
());
break
;
}
}
else
{
// Shouldn't reach here (no-op)
}
...
...
csrc/engine/rank_worker.hpp
View file @
2c925eb4
...
...
@@ -4,6 +4,7 @@
#include "../models/model_factory.hpp"
#include "compiler/general_compiler.hpp"
#include "distributed/distributed.hpp"
#include "rank_barrier.hpp"
#include <any>
#include <condition_variable>
...
...
@@ -21,6 +22,7 @@ class RankWorker {
LOAD
,
RUN
,
RESET_CACHE
,
COMPILE
,
STOP
};
...
...
@@ -57,6 +59,7 @@ public:
RankWorker
(
const
InfinilmModel
::
Config
&
model_config
,
const
distributed
::
RankInfo
&
rank_info
,
const
cache
::
CacheConfig
*
cache_config
,
RankBarrier
*
barrier
,
bool
enable_graph_compiling
);
// Submit a parameter load job and wait until the load completes on the worker thread.
...
...
@@ -72,6 +75,9 @@ public:
// Reset the internal cache with a new configuration
void
reset_cache
(
const
cache
::
CacheConfig
*
new_config
);
// Compile the model graph if enabled.
void
compile
();
// Wait until run job completes. The result can be retrieved with get_output().
void
wait
();
...
...
@@ -122,6 +128,8 @@ private:
// Random
std
::
mt19937
rng_
;
RankBarrier
*
barrier_
;
};
}
// namespace infinilm::engine
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment