Unverified Commit 8311f083 authored by Li, Jiang's avatar Li, Jiang Committed by GitHub
Browse files

[Bugfix][CPU] Fix thread num for shared memory communication (#33317)


Signed-off-by: default avatarjiang1.li <jiang1.li@intel.com>
Signed-off-by: default avatarLi, Jiang <bigpyj64@gmail.com>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 40c35038
......@@ -237,10 +237,10 @@ struct ThreadSHMContext {
class SHMManager {
public:
explicit SHMManager(const std::string& name, const int rank,
const int group_size)
const int group_size, const int thread_num)
: _rank(rank),
_group_size(group_size),
_thread_num(omp_get_max_threads()),
_thread_num(thread_num),
_shm_names({""}),
_shared_mem_ptrs({nullptr}),
_shm_ctx(nullptr) {
......@@ -282,11 +282,11 @@ class SHMManager {
}
static int64_t create_singleton_instance(const std::string& name,
const int group_size,
const int rank) {
const int group_size, const int rank,
const int thread_num) {
std::lock_guard<std::mutex> guard(SingletonInstancesLock);
SingletonInstances.emplace_back(
std::make_unique<SHMManager>(name, rank, group_size));
std::make_unique<SHMManager>(name, rank, group_size, thread_num));
return static_cast<int64_t>(SingletonInstances.size() - 1);
}
......@@ -854,8 +854,9 @@ std::vector<torch::Tensor> shm_recv_tensor_list(int64_t handle, int64_t src) {
}
int64_t init_shm_manager(const std::string& name, const int64_t group_size,
const int64_t rank) {
return SHMManager::create_singleton_instance(name, group_size, rank);
const int64_t rank, const int64_t thread_num) {
return SHMManager::create_singleton_instance(name, group_size, rank,
thread_num);
}
std::string join_shm_manager(int64_t handle, const std::string& name) {
......
......@@ -35,7 +35,7 @@ void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
torch::Tensor& block_tables, torch::Tensor& seq_lens);
int64_t init_shm_manager(const std::string& name, const int64_t group_size,
const int64_t rank);
const int64_t rank, const int64_t thread_num);
std::string join_shm_manager(int64_t handle, const std::string& name);
......@@ -232,8 +232,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// SHM CCL
#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__))
ops.def("init_shm_manager(str name, int group_size, int rank) -> int",
&init_shm_manager);
ops.def(
"init_shm_manager(str name, int group_size, int rank, int thread_num) -> "
"int",
&init_shm_manager);
ops.def("join_shm_manager(int handle, str name) -> str", &join_shm_manager);
ops.def("shm_allreduce(int handle, Tensor! data) -> ()");
ops.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
......
......@@ -205,10 +205,22 @@ class _CPUSHMDistributed:
self.handle = self._init_cpu_shm()
def _init_cpu_shm(self) -> int:
thread_num_tensor = torch.tensor(
[torch.get_num_threads()],
dtype=torch.int64,
)
torch.distributed.all_reduce(
thread_num_tensor,
op=torch.distributed.ReduceOp.MIN,
group=self.communicator.device_group,
)
thread_num = thread_num_tensor.item()
handle = torch.ops._C.init_shm_manager(
self.group_name,
self.communicator.world_size,
self.communicator.rank,
thread_num,
)
torch.distributed.barrier(self.communicator.device_group)
torch.ops._C.join_shm_manager(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment