Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
lishen01
Sccl
Commits
a4ac3320
Commit
a4ac3320
authored
Jul 07, 2025
by
lishen
Browse files
通过线程池实现ipcsocket,满足节点内通信
parent
d9d23f34
Changes
132
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
393 additions
and
71 deletions
+393
-71
src/include/base.h
src/include/base.h
+7
-5
src/include/check.h
src/include/check.h
+46
-9
src/include/debug.h
src/include/debug.h
+57
-39
src/utils/align.h
src/utils/align.h
+3
-0
src/utils/alloc.h
src/utils/alloc.h
+33
-6
src/utils/archinfo.cpp
src/utils/archinfo.cpp
+0
-0
src/utils/archinfo.h
src/utils/archinfo.h
+1
-4
src/utils/asm_ops.h
src/utils/asm_ops.h
+102
-0
src/utils/thread_pool.cpp
src/utils/thread_pool.cpp
+83
-0
src/utils/thread_pool.h
src/utils/thread_pool.h
+53
-0
src/utils/utils.cpp
src/utils/utils.cpp
+7
-7
src/utils/utils.h
src/utils/utils.h
+1
-1
No files found.
src/include/base.h
View file @
a4ac3320
...
...
@@ -12,7 +12,6 @@
外部环境变量设置:
src/debug.h:
SCCL_DEBUG_LEVEL、SCCL_DEBUG_POS
*/
namespace
sccl
{
#define WARP_SIZE warpSize
...
...
@@ -47,9 +46,12 @@ typedef enum : uint8_t {
}
scclProtocolType_t
;
// 每个进程的唯一ID
struct
scclUniqueId
{
int
rank
;
// 当前节点的全局排名
int
nRanks
;
// 总的节点数量
}
struct
scclRankInfo
{
int
rank
;
// 当前节点的全局排名
int
nRanks
;
// 总的节点数量
int
localRank
;
// 当前节点的本地rank
int
localRanks
;
// 当前节点的本地rank数量
int
hipDev
;
// CUDA 设备 ID
};
}
// namespace sccl
src/include/check.h
View file @
a4ac3320
...
...
@@ -27,15 +27,15 @@ namespace sccl {
* @note 根据代码作用域(如公开API或内部实现)编写适当的文档注释
*/
typedef
enum
{
scclSuccess
=
0
,
/*!<
No error
*/
scclUnhandledHipError
=
1
,
/*!<
Unhandled HIP error
*/
scclSystemError
=
2
,
/*!<
Unhandled system error
*/
scclInternalError
=
3
,
/*!<
Internal Error - Please report to RCCL developers
*/
scclInvalidArgument
=
4
,
/*!<
Invalid argument
*/
scclInvalidUsage
=
5
,
/*!<
Invalid usage
*/
scclRemoteError
=
6
,
/*!<
Remote process exited or there was a network error
*/
scclInProgress
=
7
,
/*!< RCCL
operation in progress
*/
scclNumResults
=
8
/*!<
Number of result types
*/
scclSuccess
=
0
,
/*!<
无错误
*/
scclUnhandledHipError
=
1
,
/*!<
未处理的 HIP 错误
*/
scclSystemError
=
2
,
/*!<
未处理的系统错误
*/
scclInternalError
=
3
,
/*!<
内部错误 - 请报告给 RCCL 开发者
*/
scclInvalidArgument
=
4
,
/*!<
无效参数
*/
scclInvalidUsage
=
5
,
/*!<
无效使用
*/
scclRemoteError
=
6
,
/*!<
远程进程退出或发生网络错误
*/
scclInProgress
=
7
,
/*!< RCCL
操作正在进行中
*/
scclNumResults
=
8
/*!<
结果类型数量
*/
}
scclResult_t
;
typedef
enum
{
...
...
@@ -125,6 +125,43 @@ static const char* scclGetErrorString(scclResult_t code) {
} \
} while(0);
#define NEQCHECK(statement, value) \
do { \
if((statement) != value) { \
/* Print the back trace*/
\
INFO(SCCL_LOG_CODEALL, "%s:%d -> %d (%s)", __FILE__, __LINE__, scclSystemError, strerror(errno)); \
return scclSystemError; \
} \
} while(0);
#define NEQCHECKGOTO(statement, value, RES, label) \
do { \
if((statement) != value) { \
/* Print the back trace*/
\
RES = scclSystemError; \
INFO(SCCL_LOG_CODEALL, "%s:%d -> %d (%s)", __FILE__, __LINE__, RES, strerror(errno)); \
goto label; \
} \
} while(0);
#define LECHECK(statement, value) \
do { \
if((statement) <= value) { \
/* Print the back trace*/
\
INFO(SCCL_LOG_CODEALL, "%s:%d -> %d (%s)", __FILE__, __LINE__, scclSystemError, strerror(errno)); \
return scclSystemError; \
} \
} while(0);
#define LTCHECK(statement, value) \
do { \
if((statement) < value) { \
/* Print the back trace*/
\
INFO(SCCL_LOG_CODEALL, "%s:%d -> %d (%s)", __FILE__, __LINE__, scclSystemError, strerror(errno)); \
return scclSystemError; \
} \
} while(0);
////////////////////////////// SYS //////////////////////////////
// Check system calls
...
...
src/include/debug.h
View file @
a4ac3320
...
...
@@ -26,18 +26,18 @@ typedef enum : uint8_t {
SCCL_LOG_ABORT
=
4
}
scclDebugLogLevel_t
;
typedef
enum
:
u
int
8
_t
{
SCCL_LOG_CODEALL
=
0
,
SCCL_LOG_NET
=
1
,
SCCL_LOG_TOPO
=
2
,
SCCL_LOG_BOOTSTRAP
=
3
,
SCCL_LOG_TRANSPORT
=
4
,
SCCL_LOG_GRAPH
=
5
,
SCCL_LOG_CONNECT
=
6
,
SCCL_LOG_P2P
=
7
,
SCCL_LOG_COLLECTIVE
=
8
,
SCCL_LOG_ALLOC
=
9
}
scclDebugLog
Po
s_t
;
typedef
enum
:
int
64
_t
{
SCCL_LOG_CODEALL
=
~
0
,
SCCL_LOG_NET
=
0x000
1
,
SCCL_LOG_TOPO
=
0x000
2
,
SCCL_LOG_BOOTSTRAP
=
0x0004
,
SCCL_LOG_TRANSPORT
=
0x0008
,
SCCL_LOG_GRAPH
=
0x0010
,
SCCL_LOG_CONNECT
=
0x0020
,
SCCL_LOG_P2P
=
0x0040
,
SCCL_LOG_COLLECTIVE
=
0x0080
,
SCCL_LOG_ALLOC
=
0x0100
}
scclDebugLog
SubSy
s_t
;
namespace
debug
{
...
...
@@ -48,7 +48,8 @@ static __thread int tid = -1; // 线程局
static
int
pid
=
-
1
;
// 存储当前进程的ID,默认值为-1
static
FILE
*
scclDebugFile
=
stdout
;
// 指向调试输出流的文件指针,默认指向标准输出(stdout
static
int
scclDebugLevel
=
-
1
;
// 初始化为 -1,表示未设置
static
uint64_t
scclDebugMask
=
SCCL_LOG_TOPO
|
SCCL_LOG_BOOTSTRAP
;
// Default debug sub-system mask is INIT and ENV
static
int
scclDebugLevel
=
-
1
;
// 初始化为 -1,表示未设置
// 在文件顶部或适当位置定义变量
static
int
scclDebugPos
=
-
1
;
// 初始化为 -1,表示未设置
...
...
@@ -116,30 +117,48 @@ static void scclDebugInit() {
}
//// 按照代码位置划分
int
tempScclDebugPos
=
-
1
;
{
const
char
*
sccl_debug
=
getenv
(
"SCCL_DEBUG_POS"
);
if
(
sccl_debug
==
NULL
)
{
tempScclDebugPos
=
SCCL_LOG_CODEALL
;
}
else
if
(
strcasecmp
(
sccl_debug
,
"NET"
)
==
0
)
{
tempScclDebugPos
=
SCCL_LOG_NET
;
}
else
if
(
strcasecmp
(
sccl_debug
,
"TOPO"
)
==
0
)
{
tempScclDebugPos
=
SCCL_LOG_TOPO
;
}
else
if
(
strcasecmp
(
sccl_debug
,
"BOOTSTRAP"
)
==
0
)
{
tempScclDebugPos
=
SCCL_LOG_BOOTSTRAP
;
}
else
if
(
strcasecmp
(
sccl_debug
,
"TRANSPORT"
)
==
0
)
{
tempScclDebugPos
=
SCCL_LOG_TRANSPORT
;
}
else
if
(
strcasecmp
(
sccl_debug
,
"GRAPH"
)
==
0
)
{
tempScclDebugPos
=
SCCL_LOG_GRAPH
;
}
else
if
(
strcasecmp
(
sccl_debug
,
"CONNECT"
)
==
0
)
{
tempScclDebugPos
=
SCCL_LOG_CONNECT
;
}
else
if
(
strcasecmp
(
sccl_debug
,
"P2P"
)
==
0
)
{
tempScclDebugPos
=
SCCL_LOG_P2P
;
}
else
if
(
strcasecmp
(
sccl_debug
,
"COLLECTIVE"
)
==
0
)
{
tempScclDebugPos
=
SCCL_LOG_COLLECTIVE
;
}
else
if
(
strcasecmp
(
sccl_debug
,
"ALLOC"
)
==
0
)
{
tempScclDebugPos
=
SCCL_LOG_ALLOC
;
char
*
scclDebugSubsysEnv
=
getenv
(
"SCCL_DEBUG_SUBSYS"
);
if
(
scclDebugSubsysEnv
!=
NULL
)
{
int
invert
=
0
;
if
(
scclDebugSubsysEnv
[
0
]
==
'^'
)
{
invert
=
1
;
scclDebugSubsysEnv
++
;
}
scclDebugMask
=
invert
?
~
0ULL
:
0ULL
;
char
*
scclDebugSubsys
=
strdup
(
scclDebugSubsysEnv
);
char
*
subsys
=
strtok
(
scclDebugSubsys
,
","
);
while
(
subsys
!=
NULL
)
{
uint64_t
mask
=
0
;
if
(
strcasecmp
(
subsys
,
"NET"
)
==
0
)
{
mask
=
SCCL_LOG_NET
;
}
else
if
(
strcasecmp
(
subsys
,
"TOPO"
)
==
0
)
{
mask
=
SCCL_LOG_TOPO
;
}
else
if
(
strcasecmp
(
subsys
,
"BOOTSTRAP"
)
==
0
)
{
mask
=
SCCL_LOG_BOOTSTRAP
;
}
else
if
(
strcasecmp
(
subsys
,
"TRANSPORT"
)
==
0
)
{
mask
=
SCCL_LOG_TRANSPORT
;
}
else
if
(
strcasecmp
(
subsys
,
"GRAPH"
)
==
0
)
{
mask
=
SCCL_LOG_GRAPH
;
}
else
if
(
strcasecmp
(
subsys
,
"CONNECT"
)
==
0
)
{
mask
=
SCCL_LOG_CONNECT
;
}
else
if
(
strcasecmp
(
subsys
,
"P2P"
)
==
0
)
{
mask
=
SCCL_LOG_P2P
;
}
else
if
(
strcasecmp
(
subsys
,
"COLLECTIVE"
)
==
0
)
{
mask
=
SCCL_LOG_COLLECTIVE
;
}
else
if
(
strcasecmp
(
subsys
,
"ALLOC"
)
==
0
)
{
mask
=
SCCL_LOG_ALLOC
;
}
else
if
(
strcasecmp
(
subsys
,
"ALL"
)
==
0
)
{
mask
=
SCCL_LOG_CODEALL
;
}
if
(
mask
)
{
if
(
invert
)
scclDebugMask
&=
~
mask
;
else
scclDebugMask
|=
mask
;
}
subsys
=
strtok
(
NULL
,
","
);
}
free
(
scclDebugSubsys
);
}
// Cache pid and hostname
...
...
@@ -187,7 +206,6 @@ static void scclDebugInit() {
}
__atomic_store_n
(
&
scclDebugLevel
,
tempScclDebugLevel
,
__ATOMIC_RELEASE
);
__atomic_store_n
(
&
scclDebugPos
,
tempScclDebugPos
,
__ATOMIC_RELEASE
);
pthread_mutex_unlock
(
&
scclDebugLock
);
}
...
...
@@ -195,7 +213,7 @@ static void scclDebugInit() {
////////////////////////////// 打印DEBUG信息 //////////////////////////////
template
<
scclDebugLogLevel_t
level
>
void
scclDebugLog
(
scclDebugLog
Po
s_t
pos_flags
,
const
char
*
filepath
,
int
line
,
const
char
*
fmt
,
...)
{
void
scclDebugLog
(
scclDebugLog
SubSy
s_t
pos_flags
,
const
char
*
filepath
,
int
line
,
const
char
*
fmt
,
...)
{
if
(
__atomic_load_n
(
&
scclDebugLevel
,
__ATOMIC_ACQUIRE
)
==
-
1
)
scclDebugInit
();
...
...
@@ -204,7 +222,7 @@ void scclDebugLog(scclDebugLogPos_t pos_flags, const char* filepath, int line, c
// 检查调试级别和位置标志
bool
isDebugLevelSufficient
=
(
scclDebugLevel
>=
level
);
bool
isDebugPositionMatch
=
(
scclDebugPos
==
SCCL_LOG_CODEALL
||
scclDebugPos
==
pos_flags
)
;
bool
isDebugPositionMatch
=
(
pos_flags
&
scclDebugMask
)
!=
0
;
// 如果调试级别不足或位置标志不匹配,则不执行后续操作
if
(
!
isDebugLevelSufficient
||
!
isDebugPositionMatch
)
{
return
;
...
...
src/utils/align.h
View file @
a4ac3320
#pragma once
namespace
sccl
{
#define DIVUP(x, y) (((x) + (y) - 1) / (y))
#define ROUNDUP(x, y) (DIVUP((x), (y)) * (y))
...
...
@@ -67,3 +68,5 @@ template <typename X, typename Z = decltype(X() + int())>
__host__
__device__
constexpr
Z
alignUp
(
X
x
,
int
a
)
{
return
(
x
+
a
-
1
)
&
Z
(
-
a
);
}
}
// namespace sccl
src/utils/alloc.h
View file @
a4ac3320
...
...
@@ -4,11 +4,11 @@
#include <unistd.h>
#include <stdlib.h>
#include <string.h>
#include "align.h"
#include "check.h"
#include "align.h"
#include "asm_ops.h"
namespace
sccl
{
namespace
alloc
{
template
<
typename
T
>
...
...
@@ -47,7 +47,30 @@ inline scclResult_t scclHipHostFree(void* ptr) {
return
scclSuccess
;
}
/**
* @brief 分配调试内存
*
* 为类型T分配指定数量的元素内存,并记录调试信息。
*
* @param[out] ptr 指向分配内存的指针的指针
* @param[in] nelem 要分配的元素数量
* @param[in] filefunc 调用位置的文件/函数信息
* @param[in] line 调用位置的行号
*
* @return scclResult_t 返回操作结果,成功返回scclSuccess,失败返回scclSystemError
*/
template
<
typename
T
>
scclResult_t
scclMallocDebug
(
T
**
ptr
,
size_t
nelem
,
const
char
*
filefunc
,
int
line
)
{
void
*
p
=
malloc
(
nelem
*
sizeof
(
T
));
if
(
p
==
NULL
)
{
WARN
(
"Failed to malloc %ld bytes"
,
nelem
*
sizeof
(
T
));
return
scclSystemError
;
}
INFO
(
SCCL_LOG_ALLOC
,
"%s:%d malloc Size %ld pointer %p"
,
filefunc
,
line
,
nelem
*
sizeof
(
T
),
p
);
*
ptr
=
(
T
*
)
p
;
return
scclSuccess
;
}
/**
* @brief 分配并清零指定数量的元素内存(调试版本)
*
...
...
@@ -60,6 +83,7 @@ template <typename T>
*
* @note 此函数会记录内存分配日志,并在失败时返回错误
*/
template
<
typename
T
>
scclResult_t
scclCallocDebug
(
T
**
ptr
,
size_t
nelem
,
const
char
*
filefunc
,
int
line
)
{
void
*
p
=
malloc
(
nelem
*
sizeof
(
T
));
if
(
p
==
NULL
)
{
...
...
@@ -197,8 +221,8 @@ scclResult_t scclHipCallocDebug(const char* filefunc, int line, T** ptr, size_t
int
dev
;
HIPCHECK
(
hipGetDevice
(
&
dev
));
if
(
dev
<
MAX_ALLOC_TRACK_NGPU
)
{
__atomic_fetch_ad
d
(
&
allocTracker
[
dev
].
totalAlloc
,
1
,
__ATOMIC_RELAXED
);
__atomic_fetch_ad
d
(
&
allocTracker
[
dev
].
totalAllocSize
,
nelem
*
sizeof
(
T
)
,
__ATOMIC_RELAXED
);
asm_ops
::
add_ref_count_relaxe
d
(
&
allocTracker
[
dev
].
totalAlloc
,
1
);
asm_ops
::
add_ref_count_relaxe
d
(
&
allocTracker
[
dev
].
totalAllocSize
,
nelem
*
sizeof
(
T
));
}
finish:
HIPCHECK
(
hipThreadExchangeStreamCaptureMode
(
&
mode
));
...
...
@@ -244,8 +268,8 @@ scclResult_t scclHipCallocAsyncDebug(const char* filefunc, int line, T** ptr, si
int
dev
;
HIPCHECK
(
hipGetDevice
(
&
dev
));
if
(
dev
<
MAX_ALLOC_TRACK_NGPU
)
{
__atomic_fetch_ad
d
(
&
allocTracker
[
dev
].
totalAlloc
,
1
,
__ATOMIC_RELAXED
);
__atomic_fetch_ad
d
(
&
allocTracker
[
dev
].
totalAllocSize
,
nelem
*
sizeof
(
T
)
,
__ATOMIC_RELAXED
);
asm_ops
::
add_ref_count_relaxe
d
(
&
allocTracker
[
dev
].
totalAlloc
,
1
);
asm_ops
::
add_ref_count_relaxe
d
(
&
allocTracker
[
dev
].
totalAllocSize
,
nelem
*
sizeof
(
T
));
}
finish:
HIPCHECK
(
hipThreadExchangeStreamCaptureMode
(
&
mode
));
...
...
@@ -366,6 +390,9 @@ inline scclResult_t scclIbMallocDebug(void** ptr, size_t size, const char* filef
// 定义宏 scclHipHostCalloc,用于调试版本的主机端内存分配,自动添加文件名和行号信息
#define scclHipHostCalloc(...) alloc::scclHipHostCallocDebug(__VA_ARGS__, __FILE__, __LINE__)
// 定义宏 scclCalloc,用于调试版本的常规内存分配,自动添加文件名和行号信息
#define scclMalloc(...) alloc::scclMallocDebug(__VA_ARGS__, __FILE__, __LINE__)
// 定义宏 scclCalloc,用于调试版本的常规内存分配,自动添加文件名和行号信息
#define scclCalloc(...) alloc::scclCallocDebug(__VA_ARGS__, __FILE__, __LINE__)
...
...
src/utils/archinfo.c
c
→
src/utils/archinfo.c
pp
View file @
a4ac3320
File moved
src/utils/archinfo.h
View file @
a4ac3320
#ifndef ARCHINFO_H_
#define ARCHINFO_H_
#pragma once
#include <string.h>
...
...
@@ -25,5 +24,3 @@ double GetDeviceWallClockRateInKhz(int deviceId);
// 判断指定的架构名称是否与目标架构匹配
bool
IsArchMatch
(
char
const
*
arch
,
char
const
*
target
);
}
// namespace sccl
#endif // ARCHINFO_H_
src/utils/asm_ops.h
0 → 100644
View file @
a4ac3320
#pragma once
#include <sys/mman.h>
#include <unistd.h>
#include <stdlib.h>
#include <string.h>
#include "check.h"
namespace
sccl
{
namespace
asm_ops
{
/*
标志名称描述与用途适用场景
__ATOMIC_RELAXED:最弱的内存顺序,无同步约束,仅保证原子性,适用于无数据依赖的场景
__ATOMIC_ACQUIRE:确保后续操作读取的共享数据可见,用于同步读取操作
__ATOMIC_RELEASE:确保当前操作对共享数据的修改对后续操作可见,用于同步
__ATOMIC_ACQ_REL:同时具备ACQUIRE和RELEASE语义,用于读写同步。确保在该原子操作之前的所有操作对其他线程可见,同时确保在该原子操作之后的所有操作对其他线程可见
__ATOMIC_SEQ_CST:顺序一致性约束,确保所有线程的操作按全局顺序执行
*/
/**
* 以宽松内存序对引用计数进行原子加1操作
* @param refs 指向引用计数的指针
*/
template
<
typename
Int
>
__host__
__device__
__forceinline__
void
add_ref_count_increment_relaxed
(
Int
*
refs
)
{
__atomic_fetch_add
(
refs
,
1
,
__ATOMIC_RELAXED
);
}
/**
* 以顺序一致性内存顺序对引用计数进行原子加1操作
* @param refs 指向引用计数的指针
*/
template
<
typename
Int
>
__host__
__device__
__forceinline__
void
add_ref_count_increment_seq_cst
(
Int
*
refs
)
{
__atomic_fetch_add
(
refs
,
1
,
__ATOMIC_SEQ_CST
);
}
/**
* 以宽松内存序原子地增加引用计数
* @param refs 指向引用计数变量的指针
* @param nbytes 要增加的字节数
*/
template
<
typename
Int
>
__host__
__device__
__forceinline__
void
add_ref_count_relaxed
(
Int
*
refs
,
int
nbytes
)
{
__atomic_fetch_add
(
refs
,
nbytes
,
__ATOMIC_RELAXED
);
}
/**
* 原子地减少引用计数并获取修改后的值(使用获取-释放内存序)
* @param refs 指向引用计数的指针
* @return 减少后的引用计数值
*/
template
<
typename
Int
>
__host__
__device__
__forceinline__
Int
sub_ref_count_decrement_acq_rel
(
Int
*
refs
)
{
return
__atomic_sub_fetch
(
refs
,
1
,
__ATOMIC_ACQ_REL
);
}
////////////////////////////////////////////////////////////////////////////////////////////////
/*出发VMFault异常*/
__device__
__forceinline__
void
trap
()
{
// asm("trap;");
__builtin_trap
();
}
__device__
__forceinline__
void
memory_fence
()
{
// __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "");
__threadfence_system
();
}
__device__
__forceinline__
void
memory_fence_gpu
()
{
// __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "agent");
__threadfence
();
}
__device__
__forceinline__
void
memory_fence_cta
()
{
// __builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");
__threadfence_block
();
}
template
<
typename
Int
>
__host__
__device__
__forceinline__
void
st_relaxed_sys_global
(
Int
*
ptr
,
Int
val
)
{
__atomic_store_n
(
ptr
,
val
,
__ATOMIC_RELAXED
);
// asm volatile("st.relaxed.sys.global.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory");
}
template
<
typename
Int
>
__host__
__device__
__forceinline__
void
st_release_sys_global
(
Int
*
ptr
,
Int
val
)
{
__atomic_store_n
(
ptr
,
val
,
__ATOMIC_RELEASE
);
// asm volatile("st.release.sys.global.s32 [%0], %1;"::"l"(ptr), "r"(val) : "memory");
}
template
<
typename
Int
>
__host__
__device__
__forceinline__
Int
ld_acquire_sys_global
(
const
Int
*
ptr
)
{
Int
ret
;
ret
=
__atomic_load_n
(
ptr
,
__ATOMIC_ACQUIRE
);
// asm volatile("ld.acquire.sys.global.s32 %0, [%1];" : "=r"(ret) : "l"(ptr));
return
ret
;
}
}
// namespace asm_ops
}
// namespace sccl
src/utils/thread_pool.cpp
0 → 100644
View file @
a4ac3320
#include "thread_pool.h"
namespace
sccl
{
ThreadPool
::
ThreadPool
(
size_t
threads_num
)
:
stop
(
false
)
{
pthread_mutex_init
(
&
queue_mutex
,
nullptr
);
pthread_cond_init
(
&
condition
,
nullptr
);
for
(
size_t
i
=
0
;
i
<
threads_num
;
++
i
)
{
pthread_t
worker
;
pthread_create
(
&
worker
,
nullptr
,
ThreadPool
::
run
,
this
);
workers
.
push_back
(
worker
);
}
}
ThreadPool
::~
ThreadPool
()
{
{
pthread_mutex_lock
(
&
queue_mutex
);
stop
=
true
;
pthread_mutex_unlock
(
&
queue_mutex
);
pthread_cond_broadcast
(
&
condition
);
}
for
(
size_t
i
=
0
;
i
<
workers
.
size
();
++
i
)
{
pthread_join
(
workers
[
i
],
nullptr
);
}
pthread_mutex_destroy
(
&
queue_mutex
);
pthread_cond_destroy
(
&
condition
);
}
/**
* @brief 线程池中工作线程的执行函数
*
* 该函数作为线程池中每个工作线程的入口点,不断从任务队列中获取并执行任务。
* 使用互斥锁和条件变量实现线程安全的任务队列访问。
* 当线程池停止且任务队列为空时,线程退出。
*
* @param arg 指向ThreadPool实例的指针
* @return void* 总是返回nullptr
*/
void
*
ThreadPool
::
run
(
void
*
arg
)
{
ThreadPool
*
pool
=
static_cast
<
ThreadPool
*>
(
arg
);
while
(
true
)
{
std
::
function
<
void
()
>
task
;
{
pthread_mutex_lock
(
&
pool
->
queue_mutex
);
while
(
pool
->
tasks
.
empty
()
&&
!
pool
->
stop
)
{
pthread_cond_wait
(
&
pool
->
condition
,
&
pool
->
queue_mutex
);
}
if
(
pool
->
stop
&&
pool
->
tasks
.
empty
())
{
pthread_mutex_unlock
(
&
pool
->
queue_mutex
);
return
nullptr
;
}
task
=
pool
->
tasks
.
front
();
pool
->
tasks
.
pop
();
pthread_mutex_unlock
(
&
pool
->
queue_mutex
);
}
task
();
// 执行任务
{
pthread_mutex_lock
(
&
pool
->
queue_mutex
);
pool
->
active_tasks
--
;
// 任务完成减少活动任务计数
pthread_mutex_unlock
(
&
pool
->
queue_mutex
);
}
}
}
/**
* 检查线程池中所有任务是否已完成
*
* @return 如果活动任务数为0且任务队列为空则返回true,否则返回false
* @note 此操作是线程安全的,通过互斥锁保护共享数据
*/
bool
ThreadPool
::
allTasksCompleted
()
{
pthread_mutex_lock
(
&
queue_mutex
);
bool
completed
=
(
active_tasks
==
0
)
&&
tasks
.
empty
();
pthread_mutex_unlock
(
&
queue_mutex
);
return
completed
;
}
}
// namespace sccl
src/utils/thread_pool.h
0 → 100644
View file @
a4ac3320
#pragma once
#include <iostream>
#include <vector>
#include <queue>
#include <pthread.h>
#include <functional>
#include <future>
#include <memory>
namespace
sccl
{
class
ThreadPool
{
public:
ThreadPool
(
size_t
);
~
ThreadPool
();
// 将任务加入线程池队列并返回关联的future
template
<
class
F
,
class
...
Args
>
auto
enqueue
(
F
&&
f
,
Args
&&
...
args
)
->
std
::
future
<
typename
std
::
result_of
<
F
(
Args
...)
>::
type
>
{
using
return_type
=
typename
std
::
result_of
<
F
(
Args
...)
>::
type
;
auto
task
=
std
::
make_shared
<
std
::
packaged_task
<
return_type
()
>>
(
std
::
bind
(
std
::
forward
<
F
>
(
f
),
std
::
forward
<
Args
>
(
args
)...));
std
::
future
<
return_type
>
res
=
task
->
get_future
();
{
pthread_mutex_lock
(
&
queue_mutex
);
tasks
.
push
([
task
]()
{
(
*
task
)();
});
active_tasks
++
;
// 新任务增加活动任务计数
pthread_mutex_unlock
(
&
queue_mutex
);
pthread_cond_signal
(
&
condition
);
}
return
res
;
}
// 检查是否所有任务都已完成
bool
allTasksCompleted
();
private:
std
::
vector
<
pthread_t
>
workers
;
// 工作线程列表
std
::
queue
<
std
::
function
<
void
()
>>
tasks
;
// 任务队列
pthread_mutex_t
queue_mutex
;
// 保护任务队列的互斥锁
pthread_cond_t
condition
;
// 用于线程间通信的条件变量
bool
stop
;
// 标志位,指示线程池是否应该停止
int
active_tasks
;
// 追踪活动任务的数量
static
void
*
run
(
void
*
arg
);
};
}
// namespace sccl
src/utils/utils.c
c
→
src/utils/utils.c
pp
View file @
a4ac3320
...
...
@@ -16,13 +16,13 @@ namespace sccl {
// // Get current Compute Capability
// int scclCudaCompCap() {
// int
cuda
Dev;
// if(cudaGetDevice(&
cuda
Dev) != cudaSuccess)
// int
hip
Dev;
// if(cudaGetDevice(&
hip
Dev) != cudaSuccess)
// return 0;
// int ccMajor, ccMinor;
// if(cudaDeviceGetAttribute(&ccMajor, cudaDevAttrComputeCapabilityMajor,
cuda
Dev) != cudaSuccess)
// if(cudaDeviceGetAttribute(&ccMajor, cudaDevAttrComputeCapabilityMajor,
hip
Dev) != cudaSuccess)
// return 0;
// if(cudaDeviceGetAttribute(&ccMinor, cudaDevAttrComputeCapabilityMinor,
cuda
Dev) != cudaSuccess)
// if(cudaDeviceGetAttribute(&ccMinor, cudaDevAttrComputeCapabilityMinor,
hip
Dev) != cudaSuccess)
// return 0;
// return ccMajor * 10 + ccMinor;
// }
...
...
@@ -49,13 +49,13 @@ namespace sccl {
// return scclSuccess;
// }
// // Convert a logical
cuda
Dev index to the NVML device minor number
// scclResult_t getBusId(int
cuda
Dev, int64_t* busId) {
// // Convert a logical
hip
Dev index to the NVML device minor number
// scclResult_t getBusId(int
hip
Dev, int64_t* busId) {
// // On most systems, the PCI bus ID comes back as in the 0000:00:00.0
// // format. Still need to allocate proper space in case PCI domain goes
// // higher.
// char busIdStr[] = "00000000:00:00.0";
// CUDACHECK(cudaDeviceGetPCIBusId(busIdStr, sizeof(busIdStr),
cuda
Dev));
// CUDACHECK(cudaDeviceGetPCIBusId(busIdStr, sizeof(busIdStr),
hip
Dev));
// NCCLCHECK(busIdToInt64(busIdStr, busId));
// return scclSuccess;
// }
...
...
src/utils/utils.h
View file @
a4ac3320
...
...
@@ -20,7 +20,7 @@ namespace sccl {
// scclResult_t int64ToBusId(int64_t id, char* busId);
// scclResult_t busIdToInt64(const char* busId, int64_t* id);
// ncclResult_t getBusId(int
cuda
Dev, int64_t* busId);
// ncclResult_t getBusId(int
hip
Dev, int64_t* busId);
// ncclResult_t getHostName(char* hostname, int maxlen, const char delim);
// uint64_t getHash(const char* string, int n);
...
...
Prev
1
…
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment