#include "global_exchange.h"
#include "utils/fmoe_utils.h"
#include <torch/extension.h>
#include <iostream>

#ifdef FMOE_USE_NCCL
#include <nccl.h>

/*
 * Exchange per-expert token counts with every other worker.
 *
 * local_expert_count  : flat [world_size * n_expert] array; the slice for
 *                       worker i starts at n_expert * i.
 * global_expert_count : output, same layout, filled with the counts the
 *                       other workers are sending to us.
 *
 * All sends/recvs are batched inside one NCCL group so the pairwise
 * exchange cannot deadlock, then stream 0 is synchronized.
 */
void fmoe_cuda_expert_exchange_impl(
        const long* local_expert_count,
        long* global_expert_count,
        int n_expert, int world_size,
        CudaStreamManager* smgr) {
    NCCL_SAFE_CALL(ncclGroupStart());
    for (int i = 0; i < world_size; ++i) {
        NCCL_SAFE_CALL(ncclSend(
                local_expert_count + n_expert * i,
                n_expert,
                ncclInt64,
                i,
                smgr->ncclcomm,
                smgr->stream(0)));
        NCCL_SAFE_CALL(ncclRecv(
                global_expert_count + n_expert * i,
                n_expert,
                ncclInt64,
                i,
                smgr->ncclcomm,
                smgr->stream(0)));
    }
    NCCL_SAFE_CALL(ncclGroupEnd());
    smgr->sync(1);
}

/*
 * Python-facing wrapper: all-to-all exchange of expert counts.
 * Returns a tensor shaped like local_expert_count holding the counts
 * received from every worker.
 */
torch::Tensor _expert_exchange(
        torch::Tensor local_expert_count,
        long n_expert, long n_workers) {
    auto global_expert_count = torch::empty_like(local_expert_count);
    auto smgr = getCudaStreamManager(local_expert_count.device().index());

    fmoe_cuda_expert_exchange_impl(
            // data_ptr must be instantiated for long: counts are int64
            local_expert_count.data_ptr<long>(),
            global_expert_count.data_ptr<long>(),
            n_expert, n_workers, smgr);
    return global_expert_count;
}

/*
 * Scatter local feature rows to the workers that own the experts they
 * were routed to.
 *
 * input_buf           : [n_local_rows, in_feat] features on this worker.
 * local_expert_count  : rows we send to each (worker, expert) pair.
 * global_expert_count : rows we will receive for each pair.
 * batch_size          : total number of rows this worker receives.
 * Returns the gathered [batch_size, in_feat] buffer.
 */
torch::Tensor _global_scatter(
        torch::Tensor input_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long batch_size, long n_workers) {
    CHECK_INPUT(input_buf);

    auto n_expert = local_expert_count.size(0) / n_workers;
    auto in_feat = input_buf.size(1);
    auto global_input_buf = input_buf.new_empty({batch_size, in_feat});
    auto smgr = getCudaStreamManager(input_buf.device().index());

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input_buf.scalar_type(),
            "fmoe_cuda_global_scatter", ([&] {
        fmoe_cuda_global_scatter_impl<scalar_t>(
            input_buf.data_ptr<scalar_t>(),
            local_expert_count.data_ptr<long>(),
            global_expert_count.data_ptr<long>(),
            global_input_buf.data_ptr<scalar_t>(),
            in_feat, n_expert, n_workers,
            smgr
        );
    }));
    return global_input_buf;
}

/*
 * Inverse of _global_scatter: return the expert outputs computed on
 * remote workers back to the workers that own the corresponding tokens.
 * Count tensors play the mirrored roles relative to _global_scatter.
 * Returns the [batch_size, out_feat] local output buffer.
 */
torch::Tensor _global_gather(
        torch::Tensor output_buf,
        torch::Tensor local_expert_count,
        torch::Tensor global_expert_count,
        long batch_size, long n_workers) {
    CHECK_INPUT(output_buf);

    auto n_expert = local_expert_count.size(0) / n_workers;
    auto out_feat = output_buf.size(1);
    auto local_output_buf = output_buf.new_empty({batch_size, out_feat});
    auto smgr = getCudaStreamManager(output_buf.device().index());

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(output_buf.scalar_type(),
            "fmoe_cuda_global_gather", ([&] {
        fmoe_cuda_global_gather_impl<scalar_t>(
            output_buf.data_ptr<scalar_t>(),
            local_expert_count.data_ptr<long>(),
            global_expert_count.data_ptr<long>(),
            local_output_buf.data_ptr<scalar_t>(),
            out_feat, n_expert, n_workers,
            smgr
        );
    }));
    return local_output_buf;
}

#include <c10d/ProcessGroupNCCL.hpp>

/*
 * Subclass trick to reach ProcessGroupNCCL's protected members
 * (getRank / getSize / broadcastUniqueNCCLID) so we can bootstrap a
 * private NCCL communicator that shares the process group's rendezvous.
 */
class HackNCCLGroup: public c10d::ProcessGroupNCCL {
public:
    ncclComm_t getcomm(at::Device dev) {
        ncclUniqueId ncclID;
        int rank = getRank();
        if (rank == 0) {
            ncclGetUniqueId(&ncclID);
        }
        // broadcastUniqueNCCLID grew extra parameters in torch 1.8.
#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR > 1 || \
        (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 8))
        broadcastUniqueNCCLID(&ncclID,
                c10d::OpType::SEND,
                "fastmoe_nccl_comm",
                rank);
#else
        broadcastUniqueNCCLID(&ncclID);
#endif
        ncclComm_t comm;
        NCCL_SAFE_CALL(ncclCommInitRank(&comm, getSize(), ncclID, rank));
        return comm;
    }
};

/*
 * Lazily create (once per device) the NCCL communicator used by the
 * fastmoe kernels, piggybacking on an existing torch ProcessGroupNCCL.
 */
void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
    auto smgr = getCudaStreamManager(t.device().index());
    if (smgr->ncclgood) {
        return;
    }
    // Deliberate reinterpretation: HackNCCLGroup adds no state, only a
    // member function, so viewing p through it is layout-compatible.
    HackNCCLGroup* h = (HackNCCLGroup*)(void*)&p;
    smgr->ncclcomm = h->getcomm(t.device());
    if (smgr->ncclcomm != 0) {
        smgr->ncclgood = 1;
    } else {
        std::cerr << "Nccl initialization failed\n";
    }
}

#endif  // FMOE_USE_NCCL