Unverified Commit 1e8e455a authored by Rick Ho's avatar Rick Ho Committed by GitHub
Browse files

Merge pull request #157 from laekov/pytorch2-compat

Fix ProcessGroupNCCL mismatch in pytorch2
parents 9fe65ec0 267eb9cc
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR > 1 || \ #if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR > 1 || \
(TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13)) (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13))
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp> #include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#else #else
#include <c10d/ProcessGroupNCCL.hpp> #include <c10d/ProcessGroupNCCL.hpp>
...@@ -26,7 +27,12 @@ torch::Tensor _global_gather( ...@@ -26,7 +27,12 @@ torch::Tensor _global_gather(
torch::Tensor local_expert_count, torch::Tensor local_expert_count,
torch::Tensor global_expert_count, torch::Tensor global_expert_count,
long batch_size, long n_workers); long batch_size, long n_workers);
#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR >= 2)
void _ensure_nccl(c10d::ProcessGroup& p, torch::Tensor t);
#else
void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t); void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t);
#endif // TORCH_VERSION
#endif // FMOE_USE_NCCL #endif // FMOE_USE_NCCL
// local_exchange // local_exchange
......
...@@ -100,6 +100,7 @@ torch::Tensor _global_gather( ...@@ -100,6 +100,7 @@ torch::Tensor _global_gather(
#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR > 1 || \ #if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR > 1 || \
(TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13)) (TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR >= 13))
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp> #include <torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp>
#else #else
#include <c10d/ProcessGroupNCCL.hpp> #include <c10d/ProcessGroupNCCL.hpp>
...@@ -134,12 +135,21 @@ public: ...@@ -134,12 +135,21 @@ public:
} }
}; };
#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR >= 2)
void _ensure_nccl(c10d::ProcessGroup& p, torch::Tensor t) {
#else
void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) { void _ensure_nccl(c10d::ProcessGroupNCCL& p, torch::Tensor t) {
#endif // TORCH_VERSION
auto smgr = getCudaStreamManager(t.device().index()); auto smgr = getCudaStreamManager(t.device().index());
if (smgr->ncclgood) { if (smgr->ncclgood) {
return; return;
} }
#if defined(TORCH_VERSION_MAJOR) && (TORCH_VERSION_MAJOR >= 2)
HackNCCLGroup* h = (HackNCCLGroup*)(void*)
(p.getBackend(c10d::ProcessGroup::NCCL).get());
#else
HackNCCLGroup* h = (HackNCCLGroup*)(void*)&p; HackNCCLGroup* h = (HackNCCLGroup*)(void*)&p;
#endif // TORCH_VERSION
smgr->ncclcomm = h->getcomm(t.device()); smgr->ncclcomm = h->getcomm(t.device());
if (smgr->ncclcomm != 0) { if (smgr->ncclcomm != 0) {
smgr->ncclgood = 1; smgr->ncclgood = 1;
......
...@@ -41,7 +41,7 @@ else: ...@@ -41,7 +41,7 @@ else:
if __name__ == '__main__': if __name__ == '__main__':
setuptools.setup( setuptools.setup(
name='fastmoe', name='fastmoe',
version='1.0.0', version='1.0.1',
description='An efficient Mixture-of-Experts system for PyTorch', description='An efficient Mixture-of-Experts system for PyTorch',
author=', '.join(authors), author=', '.join(authors),
author_email='hja20@mails.tsinghua.edu.cn', author_email='hja20@mails.tsinghua.edu.cn',
......
import pytest
import os
import sys
import json
import math
import torch
import torch.distributed as dist
import torch.nn.functional as F
from fmoe.functions import ensure_comm
from test_ddp import _ensure_initialized, _run_distributed
@pytest.mark.parametrize("n", [1, 2])
def test_ensure(n):
_run_distributed('_test_ensure',
n, dict(),
script=__file__
)
def _test_ensure():
_ensure_initialized()
rank = torch.distributed.get_rank()
x = torch.rand(10).cuda()
ensure_comm(x, None)
if __name__ == '__main__':
if len(sys.argv) >= 3:
args = json.loads(sys.argv[2])
locals()[sys.argv[1]](**args)
else:
_ensure_initialized()
_test_ensure()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment