Unverified commit 8311f083, authored by Li, Jiang and committed by GitHub
Browse files

[Bugfix][CPU] Fix thread num for shared memory communication (#33317)


Signed-off-by: jiang1.li <jiang1.li@intel.com>
Signed-off-by: Li, Jiang <bigpyj64@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 40c35038
...@@ -237,10 +237,10 @@ struct ThreadSHMContext { ...@@ -237,10 +237,10 @@ struct ThreadSHMContext {
class SHMManager { class SHMManager {
public: public:
explicit SHMManager(const std::string& name, const int rank, explicit SHMManager(const std::string& name, const int rank,
const int group_size) const int group_size, const int thread_num)
: _rank(rank), : _rank(rank),
_group_size(group_size), _group_size(group_size),
_thread_num(omp_get_max_threads()), _thread_num(thread_num),
_shm_names({""}), _shm_names({""}),
_shared_mem_ptrs({nullptr}), _shared_mem_ptrs({nullptr}),
_shm_ctx(nullptr) { _shm_ctx(nullptr) {
...@@ -282,11 +282,11 @@ class SHMManager { ...@@ -282,11 +282,11 @@ class SHMManager {
} }
static int64_t create_singleton_instance(const std::string& name, static int64_t create_singleton_instance(const std::string& name,
const int group_size, const int group_size, const int rank,
const int rank) { const int thread_num) {
std::lock_guard<std::mutex> guard(SingletonInstancesLock); std::lock_guard<std::mutex> guard(SingletonInstancesLock);
SingletonInstances.emplace_back( SingletonInstances.emplace_back(
std::make_unique<SHMManager>(name, rank, group_size)); std::make_unique<SHMManager>(name, rank, group_size, thread_num));
return static_cast<int64_t>(SingletonInstances.size() - 1); return static_cast<int64_t>(SingletonInstances.size() - 1);
} }
...@@ -854,8 +854,9 @@ std::vector<torch::Tensor> shm_recv_tensor_list(int64_t handle, int64_t src) { ...@@ -854,8 +854,9 @@ std::vector<torch::Tensor> shm_recv_tensor_list(int64_t handle, int64_t src) {
} }
int64_t init_shm_manager(const std::string& name, const int64_t group_size, int64_t init_shm_manager(const std::string& name, const int64_t group_size,
const int64_t rank) { const int64_t rank, const int64_t thread_num) {
return SHMManager::create_singleton_instance(name, group_size, rank); return SHMManager::create_singleton_instance(name, group_size, rank,
thread_num);
} }
std::string join_shm_manager(int64_t handle, const std::string& name) { std::string join_shm_manager(int64_t handle, const std::string& name) {
......
...@@ -35,7 +35,7 @@ void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query, ...@@ -35,7 +35,7 @@ void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query,
torch::Tensor& block_tables, torch::Tensor& seq_lens); torch::Tensor& block_tables, torch::Tensor& seq_lens);
int64_t init_shm_manager(const std::string& name, const int64_t group_size, int64_t init_shm_manager(const std::string& name, const int64_t group_size,
const int64_t rank); const int64_t rank, const int64_t thread_num);
std::string join_shm_manager(int64_t handle, const std::string& name); std::string join_shm_manager(int64_t handle, const std::string& name);
...@@ -232,8 +232,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -232,8 +232,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// SHM CCL // SHM CCL
#if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__)) #if defined(__AVX512F__) || (defined(__aarch64__) && !defined(__APPLE__))
ops.def("init_shm_manager(str name, int group_size, int rank) -> int", ops.def(
&init_shm_manager); "init_shm_manager(str name, int group_size, int rank, int thread_num) -> "
"int",
&init_shm_manager);
ops.def("join_shm_manager(int handle, str name) -> str", &join_shm_manager); ops.def("join_shm_manager(int handle, str name) -> str", &join_shm_manager);
ops.def("shm_allreduce(int handle, Tensor! data) -> ()"); ops.def("shm_allreduce(int handle, Tensor! data) -> ()");
ops.impl("shm_allreduce", torch::kCPU, &shm_allreduce); ops.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
......
...@@ -205,10 +205,22 @@ class _CPUSHMDistributed: ...@@ -205,10 +205,22 @@ class _CPUSHMDistributed:
self.handle = self._init_cpu_shm() self.handle = self._init_cpu_shm()
def _init_cpu_shm(self) -> int: def _init_cpu_shm(self) -> int:
thread_num_tensor = torch.tensor(
[torch.get_num_threads()],
dtype=torch.int64,
)
torch.distributed.all_reduce(
thread_num_tensor,
op=torch.distributed.ReduceOp.MIN,
group=self.communicator.device_group,
)
thread_num = thread_num_tensor.item()
handle = torch.ops._C.init_shm_manager( handle = torch.ops._C.init_shm_manager(
self.group_name, self.group_name,
self.communicator.world_size, self.communicator.world_size,
self.communicator.rank, self.communicator.rank,
thread_num,
) )
torch.distributed.barrier(self.communicator.device_group) torch.distributed.barrier(self.communicator.device_group)
torch.ops._C.join_shm_manager( torch.ops._C.join_shm_manager(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment