Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
1f5ab1c5
Commit
1f5ab1c5
authored
Dec 04, 2025
by
PanZezhong
Browse files
issue/103 新版分布式推理基建
parent
36f8eab7
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
267 additions
and
3 deletions
+267
-3
csrc/engine/distributed/communication_group.cpp
csrc/engine/distributed/communication_group.cpp
+37
-0
csrc/engine/distributed/communication_group.hpp
csrc/engine/distributed/communication_group.hpp
+34
-0
csrc/engine/distributed/dist_config.cpp
csrc/engine/distributed/dist_config.cpp
+35
-0
csrc/engine/distributed/dist_config.hpp
csrc/engine/distributed/dist_config.hpp
+31
-0
csrc/engine/distributed/distributed.hpp
csrc/engine/distributed/distributed.hpp
+4
-0
csrc/utils.hpp
csrc/utils.hpp
+124
-0
xmake.lua
xmake.lua
+2
-3
No files found.
csrc/engine/distributed/communication_group.cpp
0 → 100644
View file @
1f5ab1c5
#include "communication_group.hpp"
#include "../../utils.hpp"
namespace
infinilm
::
engine
::
distributed
{
// Stores a copy of `dist_config` and allocates one communicator slot per
// tensor-parallel device, every slot starting out as nullptr. The slots are
// only populated (through infinicclCommInitAll) when more than one device is
// configured; RUN_INFINI terminates the process if that call fails.
CommunicationGroup::CommunicationGroup(const DistConfig &dist_config)
    : dist_config_(dist_config),
      communicators_(dist_config.tp_device_ids.size(), nullptr) {
    const bool multi_device = dist_config_.tp_device_ids.size() > 1;
    if (!multi_device) {
        // Zero or one device: leave the slots as nullptr.
        return;
    }

    RUN_INFINI(infinicclCommInitAll(
        (infiniDevice_t)infinicore::context::getDevice().getType(),
        communicators_.data(),
        dist_config.tp_device_ids.size(),
        dist_config.tp_device_ids.data()));
}
// Read-only access to the distributed configuration copied in at construction.
const DistConfig &CommunicationGroup::getDistConfig() const {
    return this->dist_config_;
}
// Returns the bundle a single rank works with: its RankInfo (taken from the
// stored DistConfig) and the communicator handle stored at index `rank`.
//
// `rank` must satisfy 0 <= rank < number of communicator slots. An
// out-of-range rank aborts the process through ASSERT instead of reading
// past the end of the underlying vectors.
RankCommunicator CommunicationGroup::getRankCommunicator(int rank) const {
    // Validate before either vector is indexed.
    ASSERT(rank >= 0 && static_cast<size_t>(rank) < communicators_.size());

    RankCommunicator rc;
    rc.info = dist_config_.getRankInfo(rank);
    rc.comm = communicators_[rank];
    return rc;
}
// Destroys every communicator held by the group. Groups with at most one slot
// never had communicators created by the constructor, so there is nothing to
// release for them. RUN_INFINI terminates the process if a destroy call fails.
CommunicationGroup::~CommunicationGroup() {
    if (communicators_.size() <= 1) {
        return;
    }

    for (auto &comm : communicators_) {
        RUN_INFINI(infinicclCommDestroy(comm));
    }
}
}
// namespace infinilm::engine::distributed
csrc/engine/distributed/communication_group.hpp
0 → 100644
View file @
1f5ab1c5
#pragma once
#include "dist_config.hpp"
#include <infiniccl.h>
#include <infinicore/context/context.hpp>
#include <vector>
namespace
infinilm
::
engine
::
distributed
{
// Communicator each rank will hold.
struct RankCommunicator {
    // Tensor-parallel layout information for this rank.
    RankInfo info;
    // Communicator handle for this rank. Defaults to nullptr so that a
    // default-constructed RankCommunicator never carries an indeterminate handle.
    infinicclComm_t comm = nullptr;
};
// The communication group managed by the model infer engine.
//
// Owns one communicator handle per tensor-parallel device listed in the
// DistConfig and destroys them in its destructor. Because the handles are raw
// and released exactly once, the class is not copyable: a copy would hold the
// same handles and destroy them a second time.
class CommunicationGroup {
public:
    // Copies `dist_config` and sets up the communicators for its devices.
    explicit CommunicationGroup(const DistConfig &dist_config);

    // Non-copyable: see the class comment.
    CommunicationGroup(const CommunicationGroup &) = delete;
    CommunicationGroup &operator=(const CommunicationGroup &) = delete;

    // Returns the distributed configuration this group was built from.
    const DistConfig &getDistConfig() const;

    // Returns the rank information and communicator handle for `rank`.
    RankCommunicator getRankCommunicator(int rank) const;

    // Releases the communicators owned by this group.
    ~CommunicationGroup();

protected:
    // Configuration the group was created with.
    DistConfig dist_config_;
    // One handle per entry of dist_config_.tp_device_ids, indexed by rank.
    std::vector<infinicclComm_t> communicators_;
};
}
// namespace infinilm::engine::distributed
csrc/engine/distributed/dist_config.cpp
0 → 100644
View file @
1f5ab1c5
#include "dist_config.hpp"
namespace
infinilm
::
engine
::
distributed
{
// ---------------- RankInfo ----------------
// Default layout: a single rank (tp_size 1, tp_rank 0) on device 0.
RankInfo::RankInfo() : RankInfo(1, 0, 0) {}
RankInfo
::
RankInfo
(
int
tp_size_
,
int
tp_rank_
,
int
device_id_
)
:
tp_size
(
tp_size_
),
tp_rank
(
tp_rank_
),
device_id
(
device_id_
)
{}
RankInfo
::
RankInfo
(
int
tp_size_
,
int
tp_rank_
)
:
RankInfo
(
tp_size_
,
tp_rank_
,
tp_rank_
)
{}
// ---------------- DistConfig ----------------
// Default configuration: exactly one rank, placed on device 0.
DistConfig::DistConfig() : tp_device_ids(1, 0) {}
// Creates `tp_size` ranks and assigns them device IDs 0, 1, ..., tp_size - 1,
// so that rank i runs on device i.
DistConfig::DistConfig(int tp_size) : tp_device_ids(tp_size, 0) {
    int next_id = 0;
    for (int &device_id : tp_device_ids) {
        device_id = next_id++;
    }
}
// Uses the caller-supplied device list as-is: entry i is the device for rank i.
// The list is copied; it is not checked for duplicates or emptiness.
DistConfig::DistConfig(const std::vector<int> &tp_device_ids_)
    : tp_device_ids(tp_device_ids_) {
}
// Builds the RankInfo for `rank`: the tensor-parallel size is the number of
// configured devices and the device ID is the entry stored at index `rank`.
// `rank` is used as a vector index without a bounds check.
RankInfo DistConfig::getRankInfo(int rank) const {
    const int tp_size = static_cast<int>(tp_device_ids.size());
    const int device_id = tp_device_ids[rank];
    return RankInfo(tp_size, rank, device_id);
}
}
// namespace infinilm::engine::distributed
csrc/engine/distributed/dist_config.hpp
0 → 100644
View file @
1f5ab1c5
#pragma once
#include <vector>
namespace
infinilm
::
engine
::
distributed
{
struct
RankInfo
{
// Tensor parallelism size
int
tp_size
;
// Tensor parallelism rank number of this rank
int
tp_rank
;
// Device ID assigned to this rank
int
device_id
;
RankInfo
();
RankInfo
(
int
tp_size_
,
int
tp_rank_
,
int
device_id_
);
RankInfo
(
int
tp_size_
,
int
tp_rank_
);
};
// Distributed (tensor-parallel) configuration: which device each rank uses.
struct DistConfig {
    // Device IDs for each rank in tensor parallelism
    std::vector<int> tp_device_ids;

    // Default: a single rank on device 0.
    DistConfig();
    // `tp_size` ranks on devices 0 .. tp_size - 1 (rank i on device i).
    explicit DistConfig(int tp_size);
    // Uses a copy of the given list; entry i is the device for rank i.
    explicit DistConfig(const std::vector<int> &tp_device_ids_);

    // Returns RankInfo(tp_device_ids.size(), rank, tp_device_ids[rank]).
    // `rank` is not bounds-checked.
    RankInfo getRankInfo(int rank) const;
};
}
// namespace infinilm::engine::distributed
csrc/engine/distributed/distributed.hpp
0 → 100644
View file @
1f5ab1c5
#pragma once
#include "communication_group.hpp"
#include "dist_config.hpp"
csrc/utils.hpp
0 → 100644
View file @
1f5ab1c5
#pragma once
#include <infinirt.h>
#include <cstring>
#include <iostream>
#include <stdio.h>
#include <stdlib.h>
// Backend for the ASSERT* macros. Does nothing when `expr` is non-zero;
// otherwise prints `msg` together with the reporting `file` and `line` to
// stderr (with a red "Assertion failed:" prefix) and terminates the process
// with EXIT_FAILURE.
inline void assertTrue(int expr, const char *msg, const char *file, int line) {
    if (expr) {
        return;
    }
    fprintf(stderr, "\033[31mAssertion failed:\033[0m %s at file %s, line %d\n", msg, file, line);
    exit(EXIT_FAILURE);
}
// Assertion helpers built on assertTrue(). Each one records the call site via
// __FILE__/__LINE__ and terminates the process when its check fails.
//   ASSERT(expr)         - fails when `expr` evaluates to false.
//   ASSERT_EQ(a, b)      - fails when `(a) == (b)` is false.
//   ASSERT_VALID_PTR(a)  - fails when `a` compares equal to nullptr.
// The arguments are stringified into the failure message.
#define ASSERT(expr) assertTrue((expr), #expr " is false", __FILE__, __LINE__)
#define ASSERT_EQ(a, b) assertTrue((a) == (b), #a " != " #b, __FILE__, __LINE__)
#define ASSERT_VALID_PTR(a) assertTrue((a) != nullptr, #a " is nullptr", __FILE__, __LINE__)
// Prints "Error at <file>:<line> - <EXPR>" to stdout and terminates the
// process with EXIT_FAILURE. EXPR is only stringified, never evaluated.
//
// The body is wrapped in do { ... } while (0) so the macro expands to a single
// statement: `if (cond) PANIC(x);` exits only when `cond` holds, and the macro
// composes correctly with if/else. Callers still supply the trailing semicolon.
#define PANIC(EXPR)                                                 \
    do {                                                            \
        printf("Error at %s:%d - %s\n", __FILE__, __LINE__, #EXPR); \
        exit(EXIT_FAILURE);                                         \
    } while (0)
// Evaluates the call expression `API` exactly once and compares its result
// against INFINI_STATUS_SUCCESS. On any other result it writes the error code,
// the stringified call, the enclosing function name and the file/line to
// std::cerr, then terminates the process with EXIT_FAILURE.
// Wrapped in do { ... } while (0) so it behaves as a single statement.
#define RUN_INFINI(API) \
do { \
auto api_result_ = (API); \
if (api_result_ != INFINI_STATUS_SUCCESS) { \
std::cerr << "Error Code " << api_result_ << " in `" << #API << "`" \
<< " from " << __func__ \
<< " at " << __FILE__ << ":" << __LINE__ << std::endl; \
exit(EXIT_FAILURE); \
} \
} while (0)
// Converts an IEEE 754 half-precision (binary16) bit pattern to a float.
//
// @param h  Raw 16-bit pattern: 1 sign bit, 5 exponent bits, 10 mantissa bits.
// @return   The float with the same value. Infinities, NaNs (payload kept in
//           the top mantissa bits), signed zeros and subnormals are handled.
//
// The result bits are copied into the float with std::memcpy, matching the
// other converters in this header, instead of dereferencing a casted pointer
// (which violates the strict-aliasing rule).
inline float f16_to_f32(uint16_t h) {
    const uint32_t sign = (static_cast<uint32_t>(h) & 0x8000u) << 16; // Sign bit moved to bit 31
    int32_t exponent = (h >> 10) & 0x1F;                              // Biased 5-bit exponent
    uint32_t mantissa = h & 0x3FFu;                                   // 10-bit fraction

    uint32_t f32;
    if (exponent == 31) {
        // Inf (mantissa == 0) or NaN (mantissa != 0): all-ones float32 exponent.
        f32 = sign | 0x7F800000u | (mantissa << 13);
    } else if (exponent == 0) {
        if (mantissa == 0) {
            // Zero (positive or negative): only the sign bit survives.
            f32 = sign;
        } else {
            // Subnormal half: shift the fraction left until the implicit
            // leading 1 appears at bit 10, lowering the exponent each step.
            exponent = -14;
            while ((mantissa & 0x400u) == 0) {
                mantissa <<= 1;
                exponent--;
            }
            mantissa &= 0x3FFu; // Drop the now-implicit leading 1
            f32 = sign | (static_cast<uint32_t>(exponent + 127) << 23) | (mantissa << 13);
        }
    } else {
        // Normalized half: re-bias the exponent from 15 to 127.
        f32 = sign | (static_cast<uint32_t>(exponent + 127 - 15) << 23) | (mantissa << 13);
    }

    float result;
    std::memcpy(&result, &f32, sizeof(result));
    return result;
}
// Converts a float to an IEEE 754 half-precision (binary16) bit pattern.
//
// @param val  Value to convert.
// @return     Raw 16-bit pattern. Mantissa bits that do not fit are truncated
//             (no rounding). NaN maps to a quiet NaN (sign | 0x7E00); values
//             whose magnitude is too large for binary16 map to signed
//             infinity; values too small for a half subnormal map to signed zero.
//
// binary16's largest finite unbiased exponent is 15, so every input with an
// unbiased exponent of 16 or more must saturate to infinity. Letting such
// exponents reach the normalized branch would overflow the 5-bit exponent
// field and corrupt the sign bit.
inline uint16_t f32_to_f16(float val) {
    uint32_t f32;
    std::memcpy(&f32, &val, sizeof(f32)); // Read the bits of the float32

    const uint16_t sign = static_cast<uint16_t>((f32 >> 16) & 0x8000u);       // Sign bit
    const int32_t exponent = static_cast<int32_t>((f32 >> 23) & 0xFFu) - 127; // De-biased exponent
    uint32_t mantissa = f32 & 0x7FFFFFu;                                      // 23-bit fraction

    if (exponent >= 16) {
        // NaN: float32 exponent field is all ones (128 after de-biasing) with a non-zero fraction.
        if (exponent == 128 && mantissa != 0) {
            return static_cast<uint16_t>(sign | 0x7E00);
        }
        // Infinity, or a finite value too large to represent.
        return static_cast<uint16_t>(sign | 0x7C00);
    }
    if (exponent >= -14) {
        // Normalized half: re-bias to 15 and keep the top 10 fraction bits.
        return static_cast<uint16_t>(sign | ((exponent + 15) << 10) | (mantissa >> 13));
    }
    if (exponent >= -24) {
        // Subnormal half: make the leading 1 explicit, then shift it into place.
        mantissa |= 0x800000u;
        mantissa >>= (-14 - exponent);
        return static_cast<uint16_t>(sign | (mantissa >> 13));
    }
    // Too small for a subnormal: signed zero.
    return sign;
}
// Converts a bfloat16 bit pattern to a float.
// A bf16 value is simply the upper 16 bits of a float32, so the pattern is
// placed in the high half of a 32-bit word and the low 16 bits are left zero.
inline float bf16_to_f32(uint16_t val) {
    const uint32_t widened = uint32_t{val} << 16;
    float result;
    std::memcpy(&result, &widened, sizeof(float));
    return result;
}
// Converts a float to a bfloat16 bit pattern using round-to-nearest-even.
//
// @param val  Value to convert.
// @return     The upper 16 bits of the rounded float32 pattern. NaN inputs
//             yield a NaN with the same sign.
inline uint16_t f32_to_bf16(float val) {
    uint32_t bits32;
    std::memcpy(&bits32, &val, sizeof(bits32));

    // NaN (exponent all ones, non-zero fraction) must bypass the rounding add:
    // with a large payload the add would carry out of the fraction and turn
    // the NaN into a zero or an infinity. Keep the top bits and force one
    // fraction bit on so the truncated result is still a NaN.
    const bool is_nan = (bits32 & 0x7F800000u) == 0x7F800000u && (bits32 & 0x007FFFFFu) != 0;
    if (is_nan) {
        return static_cast<uint16_t>((bits32 >> 16) | 0x0040u);
    }

    // Before truncating, add 0x7FFF plus the lowest bit that will be kept
    // (bit 16). When that bit is odd the extra 1 pushes exact ties upward,
    // which yields round-to-nearest-even.
    const uint32_t rounding_bias = 0x00007FFFu + ((bits32 >> 16) & 1u);
    const uint16_t bf16_bits = static_cast<uint16_t>((bits32 + rounding_bias) >> 16);
    return bf16_bits;
}
// Folds `value` into the running hash `seed`, in the manner of
// boost::hash_combine: the value is offset by the constant 0x9e3779b9 and by
// two shifted copies of the current seed, then XOR-ed into the seed.
inline void hash_combine(size_t &seed, size_t value) {
    const size_t mixed = value + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    seed = seed ^ mixed;
}
xmake.lua
View file @
1f5ab1c5
...
...
@@ -50,9 +50,8 @@ target("_infinilm")
add_linkdirs
(
INFINI_ROOT
..
"/lib"
)
add_links
(
"infinicore_cpp_api"
,
"infiniop"
,
"infinirt"
,
"infiniccl"
)
-- Add Llama model files
add_files
(
"csrc/models/*/*.cpp"
)
add_files
(
"csrc/pybind11/bindings.cc"
)
-- Add src files
add_files
(
"csrc/**.cpp"
)
set_installdir
(
"python/infinilm"
)
target_end
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment