"src/vscode:/vscode.git/clone" did not exist on "f25b1a064d82e159b24d5f0a6965cb6d4401b913"
Commit 43af1522 authored by Rick Ho's avatar Rick Ho
Browse files

manually broadcast nccl unique id

parent 4d48209d
...@@ -10,3 +10,4 @@ a.out ...@@ -10,3 +10,4 @@ a.out
build build
*swp *swp
logs logs
dist
...@@ -117,15 +117,29 @@ public: ...@@ -117,15 +117,29 @@ public:
ncclComm_t getcomm(at::Device dev) { ncclComm_t getcomm(at::Device dev) {
auto key = std::to_string(dev.index()); auto key = std::to_string(dev.index());
#ifdef ENABLE_NCCL_P2P_SUPPORT #ifdef ENABLE_NCCL_P2P_SUPPORT
auto v = getNCCLComm(key, {dev}, c10d::OpType::ALLTOALL); ncclUniqueId ncclID;
int rank = getRank();
if (rank == 0) {
ncclGetUniqueId(&ncclID);
}
broadcastUniqueNCCLID(&ncclID,
c10d::OpType::SEND,
"fastmoe_nccl_comm",
rank);
ncclComm_t comm;
ncclCommInitRank(&comm, getSize(), ncclID, rank);
return comm;
#else #else
auto v = getNCCLComm(key, {dev}); auto v = getNCCLComm(key, {dev});
#endif
if (v.size() == 0) { if (v.size() == 0) {
std::cerr << "PyTorch has nothing\n"; std::cerr << "PyTorch has nothing\n";
return 0; return 0;
} }
int count;
ncclCommCount(v[0]->getNcclComm(), &count);
std::cerr << "PyTorch has " << v.size() << " comms, comm 0 size " << count << "\n";
return v[0]->getNcclComm(); return v[0]->getNcclComm();
#endif
} }
}; };
......
import torch
import torch.nn as nn
from fmoe import FMoETransformerMLP from fmoe import FMoETransformerMLP
from fmoe.gates import NaiveGate from fmoe.gates import NaiveGate
from moe import BruteForceMoELinear from moe import BruteForceMoELinear
import torch
import torch.nn as nn
import time import time
import sys import sys
import os import os
......
...@@ -25,7 +25,10 @@ PYTHON_VERSION=$($PYTHON_EXEC --version) ...@@ -25,7 +25,10 @@ PYTHON_VERSION=$($PYTHON_EXEC --version)
PYTHON_REVISION=${PYTHON_VERSION:7:3} PYTHON_REVISION=${PYTHON_VERSION:7:3}
SCRIPT_PATH=$(dirname $(dirname $(realpath $0))) SCRIPT_PATH=$(dirname $(dirname $(realpath $0)))
source ~/scripts/torch.env
export PYTHONPATH=$SCRIPT_PATH:$SCRIPT_PATH/build/lib.linux-x86_64-$PYTHON_REVISION:$PYTHONPATH export PYTHONPATH=$SCRIPT_PATH:$SCRIPT_PATH/build/lib.linux-x86_64-$PYTHON_REVISION:$PYTHONPATH
export LD_LIBRARY_PATH=/home/laekov/.local/lib/python$PYTHON_REVISION/site-packages/torch/lib:$LD_LIBRARY_PATH
exec $PYTHON_EXEC $@ 2>logs/$RANK.log core0=$(expr $OMPI_COMM_WORLD_LOCAL_RANK \* 4)
cores=$core0-$(expr $core0 + 3)
exec numactl -C $cores $PYTHON_EXEC $@ 2>logs/$RANK.log
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment