Unverified Commit f8eddcf9 authored by Nicolas Castet's avatar Nicolas Castet Committed by GitHub
Browse files

Add support for UB MNNVL (#1470)



* Add support for UB MNNVL
Signed-off-by: default avatarNicolas Castet <ncastet@nvidia.com>

* Address review comments
Signed-off-by: default avatarNicolas Castet <ncastet@nvidia.com>

* Fix lint
Signed-off-by: default avatarNicolas Castet <ncastet@nvidia.com>

* Dlopen nvml lib since it comes with the cuda driver
Signed-off-by: default avatarNicolas Castet <ncastet@nvidia.com>

* Add initial copyright date
Signed-off-by: default avatarNicolas Castet <ncastet@nvidia.com>

---------
Signed-off-by: default avatarNicolas Castet <ncastet@nvidia.com>
parent e85d1806
...@@ -77,7 +77,16 @@ def _parse_args(argv=None, namespace=None): ...@@ -77,7 +77,16 @@ def _parse_args(argv=None, namespace=None):
help="Disable the comm+GEMM overlap.", help="Disable the comm+GEMM overlap.",
) )
parser.add_argument( parser.add_argument(
"--num-replicas", type=int, default=1, help="Number of data-parallel model replicas." "--num-replicas",
type=int,
default=1,
help="Number of data-parallel model replicas per node.",
)
parser.add_argument(
"--use-global-replica-count",
action="store_true",
default=False,
help="Treat '--num-replicas' as the total number of replicas.",
) )
parser.add_argument( parser.add_argument(
"--tcp-init", "--tcp-init",
...@@ -173,13 +182,12 @@ def _train(opts): ...@@ -173,13 +182,12 @@ def _train(opts):
opts.tcp_init = True opts.tcp_init = True
opts.bind_to_device = True opts.bind_to_device = True
opts.bootstrap_backend = "mpi" opts.bootstrap_backend = "mpi"
elif "TORCHELASTIC_RUN_ID" in os.environ: else: # TORCHELASTIC, SLURM, etc...
WORLD_RANK = int(os.getenv("RANK", "0")) WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", str(torch.cuda.device_count())))
else:
raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!")
NUM_NODES = WORLD_SIZE // LOCAL_SIZE NUM_NODES = WORLD_SIZE // LOCAL_SIZE
# Initialize torch.distributed global process group and get DP/TP groups # Initialize torch.distributed global process group and get DP/TP groups
...@@ -214,90 +222,24 @@ def _train(opts): ...@@ -214,90 +222,24 @@ def _train(opts):
dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs") dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs")
# Figure out process groups for tensor- and data-parallelism (if any) total_replicas = (
if NUM_NODES > 1: opts.num_replicas if opts.use_global_replica_count else opts.num_replicas * NUM_NODES
# Create a list of world ranks on this node )
hostname = socket.gethostname() tp_size = WORLD_SIZE // total_replicas
ifname = os.getenv(
"NVTE_UB_SOCKET_IFNAME",
os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")),
)
if ifname is not None:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
hostname = socket.inet_ntoa(
fcntl.ioctl(
s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8"))
)[20:24]
)
except OSError as err:
raise OSError(f"Invalid network interface: {ifname}") from err
hostnames = [None for _ in range(WORLD_SIZE)]
dist.all_gather_object(hostnames, hostname)
unique_hosts = []
for host in hostnames:
if host not in unique_hosts:
unique_hosts.append(host)
assert len(unique_hosts) == NUM_NODES
ranks_per_node_list = [[] for _ in range(NUM_NODES)]
self_node_idx = -1
for i, host in enumerate(hostnames):
node_idx = unique_hosts.index(host)
ranks_per_node_list[node_idx].append(i)
if host == hostname:
self_node_idx = node_idx
assert self_node_idx >= 0
self_node_ranks = ranks_per_node_list[self_node_idx]
if opts.num_replicas > 1:
# Split node ranks into multiple replicas
assert len(self_node_ranks) % opts.num_replicas == 0
tp_size = len(self_node_ranks) // opts.num_replicas
ranks_per_replica_list = []
for node_ranks in ranks_per_node_list:
for i in range(opts.num_replicas):
start = i * tp_size
end = start + tp_size
ranks_per_replica_list.append(node_ranks[start:end])
self_replica_idx = -1
for i, replica_ranks in enumerate(ranks_per_replica_list):
if WORLD_RANK in replica_ranks:
self_replica_idx = i
break
assert self_replica_idx >= 0
else: if total_replicas > 1:
# The entire node is the tensor-parallel group ranks_per_replica_list = [
ranks_per_replica_list = ranks_per_node_list [i * tp_size + t for t in range(tp_size)] for i in range(total_replicas)
self_replica_idx = self_node_idx ]
tp_group, _ = dist.new_subgroups_by_enumeration(ranks_per_replica_list, backend="nccl") tp_group, _ = dist.new_subgroups_by_enumeration(ranks_per_replica_list, backend="nccl")
ranks_per_replica_tensor = torch.tensor(ranks_per_replica_list, dtype=torch.int32) ranks_per_replica_tensor = torch.tensor(ranks_per_replica_list, dtype=torch.int32)
dp_group, _ = dist.new_subgroups_by_enumeration( dp_group, _ = dist.new_subgroups_by_enumeration(
ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl" ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl"
) )
else: else:
if opts.num_replicas > 1: dp_group = None
# Mixed data- and tensor-parallelism on a single node tp_group = nccl_world
# NOTE: Avoid dist.init_device_mesh() to support older PyTorch versions
all_ranks = torch.tensor(list(range(LOCAL_SIZE)), dtype=torch.uint8, device="cpu")
ranks_per_replica_tensor = all_ranks.reshape(
(opts.num_replicas, LOCAL_SIZE // opts.num_replicas)
)
tp_group, _ = dist.new_subgroups_by_enumeration(
ranks_per_replica_tensor.tolist(), backend="nccl"
)
dp_group, _ = dist.new_subgroups_by_enumeration(
ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl"
)
else:
dp_group = None
tp_group = nccl_world
tp_rank = dist.get_rank(tp_group) tp_rank = dist.get_rank(tp_group)
tp_size = dist.get_world_size(tp_group) tp_size = dist.get_world_size(tp_group)
......
...@@ -180,15 +180,22 @@ def _main(opts): ...@@ -180,15 +180,22 @@ def _main(opts):
LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1")) LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1"))
opts.tcp_init = True opts.tcp_init = True
opts.bootstrap_backend = "mpi" opts.bootstrap_backend = "mpi"
elif "TORCHELASTIC_RUN_ID" in os.environ: else: # TORCHELASTIC, SLURM, etc...
WORLD_RANK = int(os.getenv("RANK", "0")) WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", str(torch.cuda.device_count())))
else:
raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!") result = subprocess.run(
assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node "nvidia-smi -q | grep -m1 CliqueId | awk '{printf $3}'",
assert LOCAL_SIZE <= torch.cuda.device_count() capture_output=True,
text=True,
shell=True,
)
if result.stdout == "0": # Extra checks for non-MNNVL platforms
assert WORLD_SIZE == LOCAL_SIZE
assert LOCAL_SIZE <= torch.cuda.device_count()
# Fix clock speed # Fix clock speed
torch.cuda.set_device(LOCAL_RANK) torch.cuda.set_device(LOCAL_RANK)
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
import os import os
import sys import sys
import socket import socket
import subprocess
import argparse import argparse
import warnings import warnings
import pprint import pprint
...@@ -209,14 +210,21 @@ def _train(opts): ...@@ -209,14 +210,21 @@ def _train(opts):
opts.tcp_init = True opts.tcp_init = True
opts.bind_to_device = True opts.bind_to_device = True
opts.bootstrap_backend = "mpi" opts.bootstrap_backend = "mpi"
elif "TORCHELASTIC_RUN_ID" in os.environ: else:
WORLD_RANK = int(os.getenv("RANK", "0")) WORLD_RANK = int(os.getenv("RANK", "0"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", str(torch.cuda.device_count())))
else:
raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!") result = subprocess.run(
assert LOCAL_SIZE == WORLD_SIZE "nvidia-smi -q | grep -m1 CliqueId | awk '{printf $3}'",
capture_output=True,
text=True,
shell=True,
)
if result.stdout == "0": # Extra checks for non-MNNVL platforms
assert WORLD_SIZE == LOCAL_SIZE
def dist_print(msg, src=None, end="\n", debug=False, error=False): def dist_print(msg, src=None, end="\n", debug=False, error=False):
if debug and not opts.debug: if debug and not opts.debug:
...@@ -227,7 +235,7 @@ def _train(opts): ...@@ -227,7 +235,7 @@ def _train(opts):
dist.barrier() dist.barrier()
# Set device and initialize RNG states # Set device and initialize RNG states
torch.cuda.set_device(WORLD_RANK) torch.cuda.set_device(LOCAL_RANK)
torch.manual_seed(opts.seed) torch.manual_seed(opts.seed)
torch.cuda.manual_seed(opts.seed) torch.cuda.manual_seed(opts.seed)
...@@ -312,7 +320,7 @@ def _train(opts): ...@@ -312,7 +320,7 @@ def _train(opts):
return out return out
torch_rng_state = torch.get_rng_state() torch_rng_state = torch.get_rng_state()
cuda_rng_state = torch.cuda.get_rng_state(torch.device(f"cuda:{WORLD_RANK}")) cuda_rng_state = torch.cuda.get_rng_state(torch.device(f"cuda:{LOCAL_RANK}"))
if opts.use_cuda_graphs: if opts.use_cuda_graphs:
test_graph = torch.cuda.CUDAGraph() test_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(test_graph): with torch.cuda.graph(test_graph):
...@@ -329,7 +337,7 @@ def _train(opts): ...@@ -329,7 +337,7 @@ def _train(opts):
names.append(test_name + ".grad") names.append(test_name + ".grad")
torch.set_rng_state(torch_rng_state) torch.set_rng_state(torch_rng_state)
torch.cuda.set_rng_state(cuda_rng_state, torch.device(f"cuda:{WORLD_RANK}")) torch.cuda.set_rng_state(cuda_rng_state, torch.device(f"cuda:{LOCAL_RANK}"))
if opts.use_cuda_graphs: if opts.use_cuda_graphs:
ref_graph = torch.cuda.CUDAGraph() ref_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(ref_graph): with torch.cuda.graph(ref_graph):
......
...@@ -78,6 +78,7 @@ list(APPEND transformer_engine_SOURCES ...@@ -78,6 +78,7 @@ list(APPEND transformer_engine_SOURCES
util/cast.cu util/cast.cu
util/padding.cu util/padding.cu
util/cuda_driver.cpp util/cuda_driver.cpp
util/cuda_nvml.cpp
util/cuda_runtime.cpp util/cuda_runtime.cpp
util/rtc.cpp util/rtc.cpp
swizzle/swizzle.cu swizzle/swizzle.cu
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <utility> #include <utility>
#include "common/util/cuda_driver.h" #include "common/util/cuda_driver.h"
#include "common/util/cuda_nvml.h"
#include "common/util/cuda_runtime.h" #include "common/util/cuda_runtime.h"
#include "common/util/logging.h" #include "common/util/logging.h"
#include "common/util/system.h" #include "common/util/system.h"
...@@ -29,7 +30,6 @@ ...@@ -29,7 +30,6 @@
#ifdef NVTE_UB_WITH_MPI #ifdef NVTE_UB_WITH_MPI
static MPI_Comm EXT_COMM_WORLD = MPI_COMM_WORLD; static MPI_Comm EXT_COMM_WORLD = MPI_COMM_WORLD;
static MPI_Comm EXT_COMM_INTRA; static MPI_Comm EXT_COMM_INTRA;
static MPI_Comm EXT_COMM_INTER;
#define UB_MPI_CHECK(expr) \ #define UB_MPI_CHECK(expr) \
do { \ do { \
...@@ -58,11 +58,20 @@ void ub_mpi_barrier(ExtComm comm) { UB_MPI_CHECK(MPI_Barrier(comm)); } ...@@ -58,11 +58,20 @@ void ub_mpi_barrier(ExtComm comm) { UB_MPI_CHECK(MPI_Barrier(comm)); }
#else #else
#define EXT_COMM_WORLD "world" #define EXT_COMM_WORLD "world"
#define EXT_COMM_INTRA "intra" #define EXT_COMM_INTRA "intra"
#define EXT_COMM_INTER "inter"
#endif #endif
#define MULTICAST_GB_TOTAL 512 #define MULTICAST_GB_TOTAL 512
#if CUDART_VERSION < 12030
// MNNVL: FABRIC handle support lifted from CUDA 12.3
#define CU_MEM_HANDLE_TYPE_FABRIC ((CUmemAllocationHandleType)0x8ULL)
#define CU_IPC_HANDLE_SIZE 64
typedef struct CUmemFabricHandle_st {
unsigned char data[CU_IPC_HANDLE_SIZE];
} CUmemFabricHandle_v1;
typedef CUmemFabricHandle_v1 CUmemFabricHandle;
#endif
int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (const char *)b); } int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (const char *)b); }
#define IPCCHECK(cmd) \ #define IPCCHECK(cmd) \
...@@ -82,18 +91,43 @@ int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (co ...@@ -82,18 +91,43 @@ int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (co
} \ } \
} while (0); } while (0);
int pipe_rank(communicator *comm, int step) { bool has_mnnvl_fabric(int device_id) {
int mynode = comm->myrank / comm->nvsize; #if CUDA_VERSION < 12040
int mylocal = comm->nvrank; if (getenv("NVTE_UBDEBUG")) {
int numlocal = comm->nvsize; printf(
"TransformerEngine does not support multi-node NVLINK "
int newlocal1 = mylocal + step * comm->ar_nvsize * comm->ar2_nvsize; "since it was not built with CUDA version >= 12.4.\n");
int newlocal = (numlocal + (newlocal1 % numlocal)) % numlocal; }
int newnode = mynode; return false;
newnode += (newlocal1 - newlocal) / numlocal * comm->num_nodes * comm->num2_nodes; #else
int allnodes = comm->nranks / comm->nvsize; bool mnnvl_fabric_support = false;
newnode = (allnodes + (newnode % allnodes)) % allnodes; CUdevice dev;
return newnode * numlocal + newlocal; NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &dev, device_id);
int fabric_handle_supported = 0;
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGetAttribute, &fabric_handle_supported,
CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, dev);
if (fabric_handle_supported) {
NVTE_CALL_CHECK_CUDA_NVML(nvmlInit_v2);
nvmlDevice_t local_device;
NVTE_CALL_CHECK_CUDA_NVML(nvmlDeviceGetHandleByIndex_v2, device_id, &local_device);
nvmlGpuFabricInfoV_t fabricInfo = {};
fabricInfo.version = nvmlGpuFabricInfo_v2;
fabricInfo.clusterUuid[0] = '\0';
NVTE_CALL_CHECK_CUDA_NVML(nvmlDeviceGetGpuFabricInfoV, local_device, &fabricInfo);
NVTE_CALL_CHECK_CUDA_NVML(nvmlShutdown);
if (fabricInfo.state >= NVML_GPU_FABRIC_STATE_COMPLETED && fabricInfo.clusterUuid[0] != '\0') {
mnnvl_fabric_support = true;
}
}
if (getenv("NVTE_UBDEBUG")) {
if (mnnvl_fabric_support) {
printf("MNNVL NVLINK is supported on this platform.\n");
} else {
printf("MNNVL NVLINK is not supported on this platform.\n");
}
}
return mnnvl_fabric_support;
#endif
} }
int create_communicator_grouped2(communicator **comm, int myrank, int numranks, int mylocal, int create_communicator_grouped2(communicator **comm, int myrank, int numranks, int mylocal,
...@@ -122,10 +156,6 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, ...@@ -122,10 +156,6 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks,
(*comm)->use_ce = 0; (*comm)->use_ce = 0;
(*comm)->cga_size = 2; (*comm)->cga_size = 2;
for (int i = 0; i < userbuffers_op_types; i++) (*comm)->basecounter[i] = 0; for (int i = 0; i < userbuffers_op_types; i++) (*comm)->basecounter[i] = 0;
(*comm)->head = 0;
(*comm)->tail = 0;
(*comm)->active_nreqs = 0;
for (int i = 0; i < userbuffers_op_types; i++) (*comm)->active_req[i].active = -1;
int device_clock = 0; int device_clock = 0;
// 110 sec wait time by default // 110 sec wait time by default
...@@ -182,29 +212,14 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, ...@@ -182,29 +212,14 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks,
// ar2 has step equal to ar_nvsize // ar2 has step equal to ar_nvsize
int allnodes = numranks / numlocal; int allnodes = numranks / numlocal;
int nodeid = myrank / numlocal; int nodeid = myrank / numlocal;
int datanodes = allnodes / pipenodes / tensornodes;
int pipenodegroup_id = myrank / numlocal / (datanodes * tensornodes);
(*comm)->pipe_id = pipegpus * pipenodegroup_id + mylocal / (datagpus * tensorgpus);
(*comm)->comm_inter = EXT_COMM_INTER;
(*comm)->first_node = nodeid - mynode;
(*comm)->num_nodes = numnodes; (*comm)->num_nodes = numnodes;
(*comm)->my_node = mynode; (*comm)->my_node = mynode;
(*comm)->num2_nodes = tensornodes;
(*comm)->my2_node = (mynode / datanodes) % tensornodes;
(*comm)->first2_node = mynode - (*comm)->my2_node * datanodes;
(*comm)->fifo = reinterpret_cast<ub_request *>(malloc(sizeof(ub_request) * NVTE_MAX_REQUESTS));
(*comm)->nblocks = 8;
(*comm)->alignblock = 1024 * 512;
(*comm)->minblock = 1024 * 2 * 1024;
(*comm)->asyncblocks = 16;
#define NBUF 2 #define NBUF 2
#if CUDART_VERSION >= 12010 #if CUDART_VERSION >= 12010
bool mnnvl_fabric = has_mnnvl_fabric(cur_dev);
if (!transformer_engine::getenv<bool>("UB_SKIPMC") && if (!transformer_engine::getenv<bool>("UB_SKIPMC") &&
transformer_engine::cuda::supports_multicast() && (*comm)->ar2_nvsize > 1) { transformer_engine::cuda::supports_multicast() && (*comm)->ar2_nvsize > 1) {
// multicast init only for TP ops (____2 operations) // multicast init only for TP ops (____2 operations)
...@@ -215,7 +230,8 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, ...@@ -215,7 +230,8 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks,
CUmulticastObjectProp mcProp = {}; CUmulticastObjectProp mcProp = {};
mcProp.numDevices = (*comm)->ar2_nvsize; mcProp.numDevices = (*comm)->ar2_nvsize;
mcProp.size = (*comm)->mc_maxsize; mcProp.size = (*comm)->mc_maxsize;
mcProp.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; mcProp.handleTypes =
mnnvl_fabric ? CU_MEM_HANDLE_TYPE_FABRIC : CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
NVTE_CALL_CHECK_CUDA_DRIVER( NVTE_CALL_CHECK_CUDA_DRIVER(
cuMulticastGetGranularity, &gran, &mcProp, cuMulticastGetGranularity, &gran, &mcProp,
...@@ -223,46 +239,78 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, ...@@ -223,46 +239,78 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks,
mc_maxsize = ((mc_maxsize + gran - 1) / gran) * gran; mc_maxsize = ((mc_maxsize + gran - 1) / gran) * gran;
mcProp.size = mc_maxsize; mcProp.size = mc_maxsize;
(*comm)->mc_maxsize = mc_maxsize; (*comm)->mc_maxsize = mc_maxsize;
if ((*comm)->ar2_nvrank == 0)
// Broadcast the a POSIX file descriptor from the local root rank to other local ranks.
// NOTE: This cannot be done via MPI_Bcast or other external comm libraries. They mangle the
// file descriptor and prevent cuMemImportFromShareableHandle() from correctly
// interpreting the file. Instead, we use Unix domain sockets for the kernel to
// recreate the correct file descriptor on every receiving rank.
int fd;
volatile uint32_t abortFlag = 0;
IpcSocketHandle ipcSock = {0};
uint64_t opId = 0xdeadcafeb000 + (*comm)->ar2_firstgpu;
ipcSocketResult_t ret = ipcSocketSuccess;
IPCCHECK(ipcSocketInit(&ipcSock, (*comm)->ar2_nvrank, (uint64_t)opId, &abortFlag));
(*comm)->_barrier((*comm)->comm_world);
if ((*comm)->ar2_nvrank == 0) {
NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastCreate, &(*comm)->mc_handle, &mcProp); NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastCreate, &(*comm)->mc_handle, &mcProp);
NVTE_CALL_CHECK_CUDA_DRIVER(
cuMemExportToShareableHandle, reinterpret_cast<void *>(&fd), (*comm)->mc_handle,
static_cast<CUmemAllocationHandleType>(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR),
(uint64_t)0);
for (int p = 1; p < (*comm)->ar2_nvsize; p++) { if (mnnvl_fabric) {
(*comm)->_barrier((*comm)->comm_intra); CUmemFabricHandle *exphndl =
IPCCHECKGOTO(ipcSocketSendFd(&ipcSock, fd, p, (uint64_t)opId), ret, error); reinterpret_cast<CUmemFabricHandle *>(malloc(sizeof(CUmemFabricHandle)));
CUmemFabricHandle *tmphndl =
reinterpret_cast<CUmemFabricHandle *>(malloc(sizeof(CUmemFabricHandle)));
CUmemFabricHandle *exphndls;
NVTE_CHECK_CUDA(cudaMallocHost(&exphndls, (*comm)->nvsize * sizeof(CUmemFabricHandle)));
if ((*comm)->ar2_nvrank == 0)
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemExportToShareableHandle, static_cast<void *>(tmphndl),
(*comm)->mc_handle, CU_MEM_HANDLE_TYPE_FABRIC, 0);
for (int grp = 0; grp < (*comm)->ar_nvsize;
grp++) { // we do N broadcasts for N TP groups in NVL domain
int root = grp * (*comm)->ar2_nvsize;
// It just needs to be a bcast but reuse existing allgather comm
(*comm)->_allgather(
reinterpret_cast<void *>(exphndls), (*comm)->nvsize * sizeof(CUmemFabricHandle),
reinterpret_cast<void *>(tmphndl), sizeof(CUmemFabricHandle), (*comm)->comm_intra);
//save data if brodcast was from rank 0 in our group
if ((*comm)->ar2_firstgpu == root)
memcpy(exphndl, exphndls + root, sizeof(CUmemFabricHandle));
} }
if ((*comm)->ar2_nvrank != 0)
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemImportFromShareableHandle, &(*comm)->mc_handle,
reinterpret_cast<void *>(exphndl), CU_MEM_HANDLE_TYPE_FABRIC);
free(exphndl);
free(tmphndl);
NVTE_CHECK_CUDA(cudaFreeHost(exphndls));
} else { } else {
for (int p = 1; p < (*comm)->ar2_nvsize; p++) { // Broadcast the a POSIX file descriptor from the local root rank to other local ranks.
(*comm)->_barrier((*comm)->comm_intra); // NOTE: This cannot be done via MPI_Bcast or other external comm libraries. They mangle the
if ((*comm)->ar2_nvrank == p) IPCCHECKGOTO(ipcSocketRecvFd(&ipcSock, &fd), ret, error); // file descriptor and prevent cuMemImportFromShareableHandle() from correctly
// interpreting the file. Instead, we use Unix domain sockets for the kernel to
// recreate the correct file descriptor on every receiving rank.
int fd;
volatile uint32_t abortFlag = 0;
IpcSocketHandle ipcSock = {0};
uint64_t opId = 0xdeadcafeb000 + (*comm)->my_node + (*comm)->ar2_firstgpu;
ipcSocketResult_t ret = ipcSocketSuccess;
IPCCHECK(ipcSocketInit(&ipcSock, (*comm)->ar2_nvrank, (uint64_t)opId, &abortFlag));
(*comm)->_barrier((*comm)->comm_world);
if ((*comm)->ar2_nvrank == 0) {
NVTE_CALL_CHECK_CUDA_DRIVER(
cuMemExportToShareableHandle, reinterpret_cast<void *>(&fd), (*comm)->mc_handle,
static_cast<CUmemAllocationHandleType>(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR),
(uint64_t)0);
for (int p = 1; p < (*comm)->ar2_nvsize; p++) {
(*comm)->_barrier((*comm)->comm_intra);
IPCCHECKGOTO(ipcSocketSendFd(&ipcSock, fd, p, (uint64_t)opId), ret, error);
}
} else {
for (int p = 1; p < (*comm)->ar2_nvsize; p++) {
(*comm)->_barrier((*comm)->comm_intra);
if ((*comm)->ar2_nvrank == p) IPCCHECKGOTO(ipcSocketRecvFd(&ipcSock, &fd), ret, error);
}
} }
}
error: error:
if ((*comm)->ar2_nvrank != 0) { if ((*comm)->ar2_nvrank != 0) {
NVTE_CALL_CHECK_CUDA_DRIVER( NVTE_CALL_CHECK_CUDA_DRIVER(
cuMemImportFromShareableHandle, &(*comm)->mc_handle, reinterpret_cast<void *>(fd), cuMemImportFromShareableHandle, &(*comm)->mc_handle, reinterpret_cast<void *>(fd),
static_cast<CUmemAllocationHandleType>(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR)); static_cast<CUmemAllocationHandleType>(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
}
IPCCHECK(ipcSocketClose(&ipcSock));
close(fd);
} }
IPCCHECK(ipcSocketClose(&ipcSock));
close(fd);
NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastAddDevice, (*comm)->mc_handle, NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastAddDevice, (*comm)->mc_handle,
(CUdeviceptr)(*comm)->mydev); (CUdeviceptr)(*comm)->mydev);
...@@ -327,12 +375,11 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks, ...@@ -327,12 +375,11 @@ int create_communicator_grouped2(communicator **comm, int myrank, int numranks,
if (getenv("NVTE_UBDEBUG")) if (getenv("NVTE_UBDEBUG"))
printf( printf(
"%d/%d:(%d x %d): DP %d x %d TP %d x %d, DPGROUP %dx%d TPGROUP " "%d/%d:(%d x %d): DP %d x %d TP %d x %d, DPGROUP x%d TPGROUP "
"%dx%d PIPE_ID %d/%d\n", "%dx%d\n",
myrank, numranks, myrank / numlocal, myrank % numlocal, (*comm)->my_node, myrank, numranks, myrank / numlocal, myrank % numlocal, (*comm)->my_node,
(*comm)->ar_nvrank, (*comm)->my2_node, (*comm)->ar2_nvrank, (*comm)->num_nodes, (*comm)->ar_nvrank, (*comm)->my_node, (*comm)->ar2_nvrank, (*comm)->ar_nvsize,
(*comm)->ar_nvsize, (*comm)->num2_nodes, (*comm)->ar2_nvsize, (*comm)->pipe_id, (*comm)->num_nodes, (*comm)->ar2_nvsize);
pipegpus * pipenodes);
fflush(NULL); fflush(NULL);
return 0; return 0;
...@@ -361,43 +408,14 @@ int create_communicator_grouped2_mpi(communicator **comm, int pipegpus, int pipe ...@@ -361,43 +408,14 @@ int create_communicator_grouped2_mpi(communicator **comm, int pipegpus, int pipe
UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_WORLD, &myrank)); UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_WORLD, &myrank));
UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_WORLD, &numranks)); UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_WORLD, &numranks));
// find intranode numbers and make internode communicator
char hostname[MPI_MAX_PROCESSOR_NAME];
int namelen;
UB_MPI_CHECK(MPI_Get_processor_name(hostname, &namelen));
char(*hostnames)[MPI_MAX_PROCESSOR_NAME] =
static_cast<char(*)[MPI_MAX_PROCESSOR_NAME]>(malloc(numranks * MPI_MAX_PROCESSOR_NAME));
strcpy(hostnames[myrank], hostname); // NOLINT(*)
for (int n = 0; n < numranks; n++)
UB_MPI_CHECK(MPI_Bcast(&(hostnames[n]), MPI_MAX_PROCESSOR_NAME, MPI_CHAR, n, EXT_COMM_WORLD));
qsort(hostnames, numranks, MPI_MAX_PROCESSOR_NAME, stringCmp);
int color = 0;
for (int n = 0; n < numranks; n++) {
if (n > 0 && strcmp(hostnames[n - 1], hostnames[n])) color++;
if (strcmp(hostname, hostnames[n]) == 0) break;
}
free(hostnames);
int mylocal, numlocal; int mylocal, numlocal;
UB_MPI_CHECK(MPI_Comm_split(EXT_COMM_WORLD, color, myrank, &EXT_COMM_INTRA)); UB_MPI_CHECK(MPI_Comm_split(EXT_COMM_WORLD, myrank / tensorgpus, myrank, &EXT_COMM_INTRA));
UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_INTRA, &mylocal)); UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_INTRA, &mylocal));
UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_INTRA, &numlocal)); UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_INTRA, &numlocal));
// find internode numbers and make internode communicator // find internode numbers and make internode communicator
NVTE_CHECK_CUDA(cudaFree(0)); NVTE_CHECK_CUDA(cudaFree(0));
int allnodes = numranks / numlocal;
int datanodes = allnodes / pipenodes / tensornodes;
// data reduction group node belongs, equals 0 for all if both pipenodes=1 and tensornodes=1
int datanodegroup_id = myrank / numlocal / datanodes;
// mpi communicator only needed for SHARP which is always allreduce1/data-parallel
UB_MPI_CHECK(MPI_Comm_split(EXT_COMM_WORLD, mylocal + numlocal * datanodegroup_id, myrank,
&EXT_COMM_INTER));
// different rails from same group are in different subcommunicators
int mynode, numnodes; int mynode, numnodes;
UB_MPI_CHECK(MPI_Comm_size(EXT_COMM_INTER, &numnodes));
UB_MPI_CHECK(MPI_Comm_rank(EXT_COMM_INTER, &mynode));
// finally call the abstracted constructor with MPI info // finally call the abstracted constructor with MPI info
return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes, return create_communicator_grouped2(comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
...@@ -447,13 +465,11 @@ void destroy_communicator(communicator *comm) { ...@@ -447,13 +465,11 @@ void destroy_communicator(communicator *comm) {
if (comm->use_mc) { if (comm->use_mc) {
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->mc_handle); NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->mc_handle);
} }
free(comm->fifo);
delete comm; delete comm;
} }
void destroy_communicator_mpi(communicator *comm) { void destroy_communicator_mpi(communicator *comm) {
#ifdef NVTE_UB_WITH_MPI #ifdef NVTE_UB_WITH_MPI
MPI_Comm_free(static_cast<MPI_Comm *>(&(comm->comm_inter)));
MPI_Comm_free(static_cast<MPI_Comm *>(&(comm->comm_intra))); MPI_Comm_free(static_cast<MPI_Comm *>(&(comm->comm_intra)));
destroy_communicator(comm); destroy_communicator(comm);
#else #else
...@@ -472,6 +488,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * ...@@ -472,6 +488,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
#if CUDART_VERSION >= 12010 #if CUDART_VERSION >= 12010
if (comm->use_mc && alloc) { if (comm->use_mc && alloc) {
bool mnnvl_fabric = has_mnnvl_fabric(comm->mydev);
int nranks = comm->nvsize; // total GPUs in NVLINK domain int nranks = comm->nvsize; // total GPUs in NVLINK domain
int myrank = comm->nvrank; int myrank = comm->nvrank;
void **remptrs = reinterpret_cast<void **>(malloc(nranks * sizeof(void *))); void **remptrs = reinterpret_cast<void **>(malloc(nranks * sizeof(void *)));
...@@ -481,7 +498,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * ...@@ -481,7 +498,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
prop.location.id = comm->mydev; prop.location.id = comm->mydev;
prop.requestedHandleTypes = prop.requestedHandleTypes =
CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; // CU_MEM_HANDLE_TYPE_FABRIC; mnnvl_fabric ? CU_MEM_HANDLE_TYPE_FABRIC : CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
size_t granularity = 0; size_t granularity = 0;
NVTE_CALL_CHECK_CUDA_DRIVER( NVTE_CALL_CHECK_CUDA_DRIVER(
...@@ -507,41 +524,58 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * ...@@ -507,41 +524,58 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemCreate, &(comm->uchandles[hndl][myrank]), aligned_size, &prop, NVTE_CALL_CHECK_CUDA_DRIVER(cuMemCreate, &(comm->uchandles[hndl][myrank]), aligned_size, &prop,
(uint64_t)0); (uint64_t)0);
int *peerfd = reinterpret_cast<int *>(malloc(nranks * sizeof(int))); if (mnnvl_fabric) {
NVTE_CALL_CHECK_CUDA_DRIVER( CUmemFabricHandle *exphndl;
cuMemExportToShareableHandle, reinterpret_cast<void *>(&peerfd[myrank]), CUmemFabricHandle myhndl;
comm->uchandles[hndl][myrank], NVTE_CALL_CHECK_CUDA_DRIVER(cuMemExportToShareableHandle, &myhndl,
static_cast<CUmemAllocationHandleType>(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR), comm->uchandles[hndl][myrank], CU_MEM_HANDLE_TYPE_FABRIC, 0);
(uint64_t)0); NVTE_CHECK_CUDA(cudaMallocHost(&exphndl, comm->nvsize * sizeof(CUmemFabricHandle)));
comm->_allgather(reinterpret_cast<void *>(exphndl), comm->nvsize * sizeof(CUmemFabricHandle),
volatile uint32_t abortFlag = 0; reinterpret_cast<void *>(&myhndl), sizeof(CUmemFabricHandle),
IpcSocketHandle ipcSock = {0}; comm->comm_intra);
uint64_t opId = 0xdeadcafebeef; for (int p = 0; p < nranks; p++)
ipcSocketResult_t ret = ipcSocketSuccess; if (p != myrank)
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemImportFromShareableHandle, &comm->uchandles[hndl][p],
// All-gather POSIX file descriptors across local ranks reinterpret_cast<void *>(&exphndl[p]),
IPCCHECK(ipcSocketInit(&ipcSock, myrank, (uint64_t)opId, &abortFlag)); CU_MEM_HANDLE_TYPE_FABRIC);
for (int p = 1; p < nranks; p++) { NVTE_CHECK_CUDA(cudaFreeHost(exphndl));
int send_to = (myrank + p) % nranks; } else {
int recv_from = (myrank + nranks - p) % nranks; int *peerfd = reinterpret_cast<int *>(malloc(nranks * sizeof(int)));
comm->_barrier(comm->comm_intra); NVTE_CALL_CHECK_CUDA_DRIVER(
IPCCHECKGOTO(ipcSocketSendFd(&ipcSock, peerfd[myrank], send_to, (uint64_t)opId), ret, error); cuMemExportToShareableHandle, reinterpret_cast<void *>(&peerfd[myrank]),
IPCCHECKGOTO(ipcSocketRecvFd(&ipcSock, &peerfd[recv_from]), ret, error); comm->uchandles[hndl][myrank],
} static_cast<CUmemAllocationHandleType>(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR),
(uint64_t)0);
error: volatile uint32_t abortFlag = 0;
IPCCHECK(ipcSocketClose(&ipcSock)); IpcSocketHandle ipcSock = {0};
uint64_t opId = 0xdeadcafebeef + comm->my_node;
ipcSocketResult_t ret = ipcSocketSuccess;
// All-gather POSIX file descriptors across local ranks
IPCCHECK(ipcSocketInit(&ipcSock, myrank, (uint64_t)opId, &abortFlag));
for (int p = 1; p < nranks; p++) {
int send_to = (myrank + p) % nranks;
int recv_from = (myrank + nranks - p) % nranks;
comm->_barrier(comm->comm_intra);
IPCCHECKGOTO(ipcSocketSendFd(&ipcSock, peerfd[myrank], send_to, (uint64_t)opId), ret,
error);
IPCCHECKGOTO(ipcSocketRecvFd(&ipcSock, &peerfd[recv_from]), ret, error);
}
for (int p = 0; p < nranks; p++) { error:
if (p != myrank) IPCCHECK(ipcSocketClose(&ipcSock));
NVTE_CALL_CHECK_CUDA_DRIVER(
cuMemImportFromShareableHandle, &comm->uchandles[hndl][p],
reinterpret_cast<void *>(peerfd[p]),
static_cast<CUmemAllocationHandleType>(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
close(peerfd[p]);
}
free(peerfd);
for (int p = 0; p < nranks; p++) {
if (p != myrank)
NVTE_CALL_CHECK_CUDA_DRIVER(
cuMemImportFromShareableHandle, &comm->uchandles[hndl][p],
reinterpret_cast<void *>(peerfd[p]),
static_cast<CUmemAllocationHandleType>(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
close(peerfd[p]);
}
free(peerfd);
}
CUdeviceptr ptr; CUdeviceptr ptr;
NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressReserve, &ptr, (size_t)(aligned_size * nranks), NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressReserve, &ptr, (size_t)(aligned_size * nranks),
(size_t)0, (CUdeviceptr)0, (uint64_t)0); (size_t)0, (CUdeviceptr)0, (uint64_t)0);
...@@ -571,13 +605,13 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * ...@@ -571,13 +605,13 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
cudaMemcpy((reinterpret_cast<char *>(comm->gpu_ptrs)) + (hndl * nranks * sizeof(void *)), cudaMemcpy((reinterpret_cast<char *>(comm->gpu_ptrs)) + (hndl * nranks * sizeof(void *)),
remptrs, nranks * sizeof(void *), cudaMemcpyHostToDevice)); remptrs, nranks * sizeof(void *), cudaMemcpyHostToDevice));
free(remptrs); free(remptrs);
comm->memflags[hndl] = UB_MEM_UC_CONTIG | UB_MEM_ALLOCATED; comm->memflags[hndl] = NVTE_UB_MEM_UC_CONTIG | NVTE_UB_MEM_ALLOCATED;
if (comm->use_mc && comm->mc_maxsize >= comm->mc_offset + aligned_size) { if (comm->use_mc && comm->mc_maxsize >= comm->mc_offset + aligned_size) {
NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastBindMem, comm->mc_handle, comm->mc_offset, NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastBindMem, comm->mc_handle, comm->mc_offset,
comm->uchandles[hndl][myrank], (size_t)0 /*memOffset*/, comm->uchandles[hndl][myrank], (size_t)0 /*memOffset*/,
aligned_size, (uint64_t)0); aligned_size, (uint64_t)0);
comm->memflags[hndl] |= UB_MEM_MC_CREATED; comm->memflags[hndl] |= NVTE_UB_MEM_MC_CREATED;
comm->mc_ptr[hndl] = reinterpret_cast<char *>(comm->mc_baseptr) + comm->mc_offset; comm->mc_ptr[hndl] = reinterpret_cast<char *>(comm->mc_baseptr) + comm->mc_offset;
comm->mc_offset += aligned_size; comm->mc_offset += aligned_size;
} else if (!comm->myrank) { } else if (!comm->myrank) {
......
...@@ -1682,6 +1682,7 @@ void reducescatter2_userbuff_strided(void *output, const int handler, const int ...@@ -1682,6 +1682,7 @@ void reducescatter2_userbuff_strided(void *output, const int handler, const int
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
callranks_rs_oop_stride(2) callranks_rs_oop_stride(4) callranks_rs_oop_stride(8) callranks_rs_oop_stride(2) callranks_rs_oop_stride(4) callranks_rs_oop_stride(8)
callranks_rs_oop_stride(16) callranks_rs_oop_stride(32)
} }
void reducescatter2_userbuff_strided_atomic(void *output, const int handler, const int offset, void reducescatter2_userbuff_strided_atomic(void *output, const int handler, const int offset,
const int rowelements, const int colelements, const int rowelements, const int colelements,
...@@ -1703,7 +1704,8 @@ void reducescatter2_userbuff_strided_atomic(void *output, const int handler, con ...@@ -1703,7 +1704,8 @@ void reducescatter2_userbuff_strided_atomic(void *output, const int handler, con
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
callranks_rs_oop_stride_atomic(2) callranks_rs_oop_stride_atomic(4) callranks_rs_oop_stride_atomic(2) callranks_rs_oop_stride_atomic(4)
callranks_rs_oop_stride_atomic(8) callranks_rs_oop_stride_atomic(8) callranks_rs_oop_stride_atomic(16)
callranks_rs_oop_stride_atomic(32)
} }
template <typename fp8type> template <typename fp8type>
...@@ -1729,6 +1731,7 @@ void reducescatter2_userbuff_strided_universal_fp8(void *output, float *scale, c ...@@ -1729,6 +1731,7 @@ void reducescatter2_userbuff_strided_universal_fp8(void *output, float *scale, c
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
callranks_rs_oop_atomic_fp8(2) callranks_rs_oop_atomic_fp8(4) callranks_rs_oop_atomic_fp8(8) callranks_rs_oop_atomic_fp8(2) callranks_rs_oop_atomic_fp8(4) callranks_rs_oop_atomic_fp8(8)
callranks_rs_oop_atomic_fp8(16) callranks_rs_oop_atomic_fp8(32)
} }
template <typename fp8type> template <typename fp8type>
...@@ -1773,7 +1776,8 @@ void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler ...@@ -1773,7 +1776,8 @@ void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
callranks_rs_oop_stride_multiatomic(2) callranks_rs_oop_stride_multiatomic(4) callranks_rs_oop_stride_multiatomic(2) callranks_rs_oop_stride_multiatomic(4)
callranks_rs_oop_stride_multiatomic(8) callranks_rs_oop_stride_multiatomic(8) callranks_rs_oop_stride_multiatomic(16)
callranks_rs_oop_stride_multiatomic(32)
} }
void allgather2_userbuff_inplace(const int handler, const int offset, const int elements, void allgather2_userbuff_inplace(const int handler, const int offset, const int elements,
...@@ -1793,17 +1797,17 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int ...@@ -1793,17 +1797,17 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int
if (comm_launch_event) { if (comm_launch_event) {
SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) callranks_agMC(16) callranks_agMC(32)
} else { } else {
callranks_ag(2) callranks_ag(4) callranks_ag(8) callranks_ag(2) callranks_ag(4) callranks_ag(8) callranks_ag(16) callranks_ag(32)
} }
} else { } else {
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) callranks_agMC(16) callranks_agMC(32)
} else { } else {
callranks_ag(2) callranks_ag(4) callranks_ag(8) callranks_ag(2) callranks_ag(4) callranks_ag(8) callranks_ag(16) callranks_ag(32)
} }
} }
} }
...@@ -1840,17 +1844,17 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const ...@@ -1840,17 +1844,17 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const
if (comm_launch_event) { if (comm_launch_event) {
SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) callranks_rsMC(16) callranks_rsMC(32)
} else { } else {
callranks_rs(2) callranks_rs(4) callranks_rs(8) callranks_rs(2) callranks_rs(4) callranks_rs(8) callranks_rs(16) callranks_rs(32)
} }
} else { } else {
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) callranks_rsMC(16) callranks_rsMC(32)
} else { } else {
callranks_rs(2) callranks_rs(4) callranks_rs(8) callranks_rs(2) callranks_rs(4) callranks_rs(8) callranks_rs(16) callranks_rs(32)
} }
} }
} }
...@@ -1873,17 +1877,21 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons ...@@ -1873,17 +1877,21 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons
if (comm_launch_event) { if (comm_launch_event) {
SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) callranks_rs_oopMC(16)
callranks_rs_oopMC(32)
} else { } else {
callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) callranks_rs_oop(16)
callranks_rs_oop(32)
} }
} else { } else {
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) { if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) callranks_rs_oopMC(16)
callranks_rs_oopMC(32)
} else { } else {
callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) callranks_rs_oop(16)
callranks_rs_oop(32)
} }
} }
} }
...@@ -1915,10 +1923,12 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const ...@@ -1915,10 +1923,12 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const
if (comm_launch_event) { if (comm_launch_event) {
SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event); SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) callranks_rs_oop_fp8(16)
callranks_rs_oop_fp8(32)
} else { } else {
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream); SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8) callranks_rs_oop_fp8(16)
callranks_rs_oop_fp8(32)
} }
} }
......
...@@ -34,11 +34,7 @@ using ExtBarrierOp = std::function<void(ExtComm)>; ...@@ -34,11 +34,7 @@ using ExtBarrierOp = std::function<void(ExtComm)>;
#define NVTE_MAX_REQUESTS 1024 #define NVTE_MAX_REQUESTS 1024
#define NVTE_LAUNCH_GPU 1 #define NVTE_LAUNCH_GPU 1
#define NVTE_LAUNCH_CPU 2 #define NVTE_LAUNCH_CPU 2
#define NVTE_MAX_NVLINK 8 #define NVTE_MAX_NVLINK 32
#define UB_MEM_UC_CONTIG 1
#define UB_MEM_MC_CREATED 2
#define UB_MEM_ALLOCATED 4
#define NVTE_UB_MEM_UC_CONTIG 1 #define NVTE_UB_MEM_UC_CONTIG 1
#define NVTE_UB_MEM_MC_CREATED 2 #define NVTE_UB_MEM_MC_CREATED 2
...@@ -124,11 +120,8 @@ struct communicator { ...@@ -124,11 +120,8 @@ struct communicator {
ar_nvrank; // number of gpus(and first gpu in a group) of gpus per node in reduction subgroup ar_nvrank; // number of gpus(and first gpu in a group) of gpus per node in reduction subgroup
// (_splitar init used) would be equal to (nvsize,0) for regular comm_create // (_splitar init used) would be equal to (nvsize,0) for regular comm_create
int ar2_nvsize, ar2_firstgpu, ar2_nvrank; // with ar_nvsize as a step int ar2_nvsize, ar2_firstgpu, ar2_nvrank; // with ar_nvsize as a step
int pipe_id; // which allreduce set of groups (pipeline rank in range of 0..pipeline_size)
int sm_arch; int sm_arch;
int num_nodes, my_node, int num_nodes, my_node;
first_node; // comm_inter communicator, per-rail allreduce (might have subset of nodes)
int num2_nodes, my2_node, first2_node; // with num_nodes as a stride
// max value for running block counters in hostflags // max value for running block counters in hostflags
int basecounter[userbuffers_op_types]; // NOLINT(*) int basecounter[userbuffers_op_types]; // NOLINT(*)
...@@ -136,20 +129,11 @@ struct communicator { ...@@ -136,20 +129,11 @@ struct communicator {
void *mem_mr[NVTE_MAX_REGIONS]; void *mem_mr[NVTE_MAX_REGIONS];
ub_request *fifo;
int nblocks, alignblock, minblock, asyncblocks, active_nreqs;
ub_request active_req[userbuffers_op_types]; // NOLINT(*)
int padding[7];
volatile int head;
int padding2[15];
volatile int tail;
// Abstract communication callbacks to support external bootstrapping (e.g. DL frameworks) // Abstract communication callbacks to support external bootstrapping (e.g. DL frameworks)
ExtAllgatherOp _allgather; ExtAllgatherOp _allgather;
ExtBarrierOp _barrier; ExtBarrierOp _barrier;
ExtComm comm_world; ExtComm comm_world;
ExtComm comm_inter; // reduction group communicator (subset of the nodes) along GPU rail
ExtComm comm_intra; // full intranode (all ndev GPUS) ExtComm comm_intra; // full intranode (all ndev GPUS)
#ifdef NVTE_UB_WITH_MPI #ifdef NVTE_UB_WITH_MPI
MPI_Request mpihndl[NVTE_MAX_SHARP]; MPI_Request mpihndl[NVTE_MAX_SHARP];
...@@ -199,11 +183,6 @@ void destroy_communicator_mpi(communicator *comm); ...@@ -199,11 +183,6 @@ void destroy_communicator_mpi(communicator *comm);
returned offset is offset of gpubuff relative to buffer registered returned offset is offset of gpubuff relative to buffer registered
*/ */
int pipe_rank(communicator *comm,
int step); // helper function to help walk across allreduce1 x allreduce2 groups
// data-parallel and tensor-parallel position within data and tensor
// groups would be preserved
int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc); int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc);
/* returns handler and registers buffers. assumed to be collective i.e. you use same groups and /* returns handler and registers buffers. assumed to be collective i.e. you use same groups and
dont mix buffers for different operations returns -1 if cant register (too many preregistered dont mix buffers for different operations returns -1 if cant register (too many preregistered
......
...@@ -4,8 +4,6 @@ ...@@ -4,8 +4,6 @@
* See LICENSE for license information. * See LICENSE for license information.
************************************************************************/ ************************************************************************/
#include <dlfcn.h>
#include <filesystem> #include <filesystem>
#include "../common.h" #include "../common.h"
...@@ -13,84 +11,6 @@ ...@@ -13,84 +11,6 @@
namespace transformer_engine { namespace transformer_engine {
namespace {
/*! \brief Wrapper class for a shared library
*
* \todo Windows support
*/
class Library {
public:
explicit Library(const char *filename) {
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
// TODO Windows support
NVTE_ERROR("Shared library initialization is not supported with Windows");
#else
handle_ = dlopen(filename, RTLD_LAZY | RTLD_LOCAL);
NVTE_CHECK(handle_ != nullptr, "Lazy library initialization failed");
#endif // _WIN32 or _WIN64 or __WINDOW__
}
~Library() {
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
// TODO Windows support
#else
if (handle_ != nullptr) {
dlclose(handle_);
}
#endif // _WIN32 or _WIN64 or __WINDOW__
}
Library(const Library &) = delete; // move-only
Library(Library &&other) noexcept { swap(*this, other); }
Library &operator=(Library other) noexcept {
// Copy-and-swap idiom
swap(*this, other);
return *this;
}
friend void swap(Library &first, Library &second) noexcept;
void *get() noexcept { return handle_; }
const void *get() const noexcept { return handle_; }
/*! \brief Get pointer corresponding to symbol in shared library */
void *get_symbol(const char *symbol) {
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
// TODO Windows support
NVTE_ERROR("Shared library initialization is not supported with Windows");
#else
void *ptr = dlsym(handle_, symbol);
NVTE_CHECK(ptr != nullptr, "Could not find symbol in lazily-initialized library");
return ptr;
#endif // _WIN32 or _WIN64 or __WINDOW__
}
private:
void *handle_ = nullptr;
};
void swap(Library &first, Library &second) noexcept {
using std::swap;
swap(first.handle_, second.handle_);
}
/*! \brief Lazily-initialized shared library for CUDA driver */
Library &cuda_driver_lib() {
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
constexpr char lib_name[] = "nvcuda.dll";
#else
constexpr char lib_name[] = "libcuda.so.1";
#endif
static Library lib(lib_name);
return lib;
}
} // namespace
namespace cuda_driver { namespace cuda_driver {
void *get_symbol(const char *symbol) { void *get_symbol(const char *symbol) {
......
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "cuda_nvml.h"
#include "shared_lib_wrapper.h"
namespace transformer_engine {
namespace cuda_nvml {
/*! \brief Lazily-initialized shared library for CUDA NVML
 *
 * The handle is created on first call (function-local static, thread-safe
 * since C++11) and shared by every caller for the lifetime of the process.
 * libnvidia-ml.so.1 ships with the NVIDIA driver, so dlopen-ing it avoids a
 * link-time dependency on an NVML development package.
 */
Library &cuda_nvml_lib() {
  // NOTE(review): Linux-only soname; Windows would need nvml.dll — the
  // underlying Library wrapper currently errors out on Windows anyway.
  constexpr char lib_name[] = "libnvidia-ml.so.1";
  static Library lib(lib_name);
  return lib;
}
/*! \brief Get pointer corresponding to symbol in CUDA NVML library
 *
 * Aborts (via the Library wrapper's checks) if the symbol is missing.
 */
void *get_symbol(const char *symbol) { return cuda_nvml_lib().get_symbol(symbol); }
}  // namespace cuda_nvml
}  // namespace transformer_engine
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_

#include <nvml.h>

#include <string>

#include "../common.h"
#include "../util/string.h"

namespace transformer_engine {
namespace cuda_nvml {

/*! \brief Get pointer corresponding to symbol in CUDA NVML library
 *
 * The library is loaded lazily on first use; missing symbols are a
 * fatal error.
 */
void *get_symbol(const char *symbol);

/*! \brief Call function in CUDA NVML library
 *
 * The CUDA NVML library (libnvidia-ml.so.1 on Linux) may be different at
 * compile-time and run-time.
 *
 * \param[in] symbol Function name
 * \param[in] args Function arguments
 *
 * NOTE(review): the deduced ArgTs must match the real NVML prototype
 * exactly — the function pointer is reinterpret_cast from the dlsym
 * result, so an argument-type mismatch is undefined behavior. Callers
 * should pass arguments with the exact NVML parameter types.
 */
template <typename... ArgTs>
inline nvmlReturn_t call(const char *symbol, ArgTs... args) {
  using FuncT = nvmlReturn_t(ArgTs...);
  FuncT *func = reinterpret_cast<FuncT *>(get_symbol(symbol));
  return (*func)(args...);
}

/*! \brief Get NVML error string
 *
 * \param[in] rc NVML return code
 *
 * Resolves nvmlErrorString through the lazily-loaded library rather
 * than calling it directly, so no link-time NVML dependency is needed.
 */
inline const char *get_nvml_error_string(nvmlReturn_t rc) {
  using FuncT = const char *(nvmlReturn_t);
  FuncT *func = reinterpret_cast<FuncT *>(get_symbol("nvmlErrorString"));
  return (*func)(rc);
}

}  // namespace cuda_nvml
}  // namespace transformer_engine

/*! \brief Evaluate an NVML expression and abort with the NVML error
 * string if it did not return NVML_SUCCESS.
 */
#define NVTE_CHECK_CUDA_NVML(expr)                                                                 \
  do {                                                                                             \
    const nvmlReturn_t status_NVTE_CHECK_CUDA_NVML = (expr);                                       \
    if (status_NVTE_CHECK_CUDA_NVML != NVML_SUCCESS) {                                             \
      const char *desc_NVTE_CHECK_CUDA_NVML =                                                      \
          ::transformer_engine::cuda_nvml::get_nvml_error_string(status_NVTE_CHECK_CUDA_NVML);     \
      NVTE_ERROR("NVML Error: ", desc_NVTE_CHECK_CUDA_NVML);                                       \
    }                                                                                              \
  } while (false)

// Helper that prepends a comma only when arguments are present.
// NOTE(review): `, ##__VA_ARGS__` is a GNU/Clang/MSVC extension, not
// standard C++ — fine for the compilers this project targets.
#define VA_ARGS(...) , ##__VA_ARGS__

/*! \brief Call an NVML function by (unquoted) name and check its
 * return code, e.g. NVTE_CALL_CHECK_CUDA_NVML(nvmlInit_v2).
 */
#define NVTE_CALL_CHECK_CUDA_NVML(symbol, ...)                                                     \
  do {                                                                                             \
    NVTE_CHECK_CUDA_NVML(::transformer_engine::cuda_nvml::call(#symbol VA_ARGS(__VA_ARGS__)));     \
  } while (false)

#endif  // TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_SHARED_LIB_WRAPPER_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_SHARED_LIB_WRAPPER_H_
#include <dlfcn.h>
namespace transformer_engine {
/*! \brief Wrapper class for a shared library
*
* \todo Windows support
*/
/*! \brief RAII wrapper class for a shared library
 *
 * Owns a dlopen handle and closes it on destruction. Move-only:
 * copying is deleted; moves transfer ownership of the handle.
 *
 * \todo Windows support
 */
class Library {
 public:
  /*! \brief Open the shared library
   *
   * \param[in] filename Library file name passed to dlopen
   *
   * Aborts (NVTE_CHECK) if the library cannot be loaded.
   */
  explicit Library(const char *filename) {
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
    // TODO Windows support
    NVTE_ERROR("Shared library initialization is not supported with Windows");
#else
    handle_ = dlopen(filename, RTLD_LAZY | RTLD_LOCAL);
    NVTE_CHECK(handle_ != nullptr, "Lazy library initialization failed");
#endif  // _WIN32 or _WIN64 or __WINDOW__
  }

  ~Library() {
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
    // TODO Windows support
#else
    if (handle_ != nullptr) {
      dlclose(handle_);
    }
#endif  // _WIN32 or _WIN64 or __WINDOW__
  }

  // Non-copyable: two Library objects must never dlclose the same handle.
  Library(const Library &) = delete;
  Library &operator=(const Library &) = delete;

  // Movable: transfer ownership and null out the source so the moved-from
  // object's destructor does not dlclose a handle it no longer owns.
  // (Without these, the deleted copy constructor above would suppress the
  // implicit move operations and the class would be immovable, not
  // "move-only" as intended.)
  Library(Library &&other) noexcept : handle_(other.handle_) { other.handle_ = nullptr; }
  Library &operator=(Library &&other) noexcept {
    if (this != &other) {
#if !(defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__))
      // Release any handle we currently own before taking the new one.
      if (handle_ != nullptr) {
        dlclose(handle_);
      }
#endif  // not _WIN32 or _WIN64 or __WINDOW__
      handle_ = other.handle_;
      other.handle_ = nullptr;
    }
    return *this;
  }

  /*! \brief Raw dlopen handle (nullptr if moved-from) */
  void *get() noexcept { return handle_; }

  const void *get() const noexcept { return handle_; }

  /*! \brief Get pointer corresponding to symbol in shared library
   *
   * Aborts (NVTE_CHECK) if the symbol cannot be resolved.
   */
  void *get_symbol(const char *symbol) {
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
    // TODO Windows support
    NVTE_ERROR("Shared library initialization is not supported with Windows");
#else
    void *ptr = dlsym(handle_, symbol);
    NVTE_CHECK(ptr != nullptr, "Could not find symbol in lazily-initialized library");
    return ptr;
#endif  // _WIN32 or _WIN64 or __WINDOW__
  }

 private:
  void *handle_ = nullptr;  // dlopen handle; nullptr when empty/moved-from
};
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_COMMON_UTIL_SHARED_LIB_WRAPPER_H_
...@@ -390,8 +390,7 @@ class CommOverlapHelper : torch::CustomClassHolder { ...@@ -390,8 +390,7 @@ class CommOverlapHelper : torch::CustomClassHolder {
CommOverlapHelper(); CommOverlapHelper();
CommOverlapHelper(c10d::ProcessGroup *world_group, CommOverlapHelper(c10d::ProcessGroup *world_group,
std::optional<c10d::ProcessGroup *> intra_node_group, std::optional<c10d::ProcessGroup *> intra_node_group);
std::optional<c10d::ProcessGroup *> inter_node_group);
~CommOverlapHelper(); ~CommOverlapHelper();
......
...@@ -26,8 +26,7 @@ CommOverlapHelper::CommOverlapHelper() { ...@@ -26,8 +26,7 @@ CommOverlapHelper::CommOverlapHelper() {
} // empty constructor for NVTE_UB_WITH_MPI=1 } // empty constructor for NVTE_UB_WITH_MPI=1
CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group, CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group,
std::optional<c10d::ProcessGroup *> intra_domain_group, std::optional<c10d::ProcessGroup *> intra_domain_group) {
std::optional<c10d::ProcessGroup *> inter_domain_group) {
#ifndef NVTE_UB_WITH_MPI #ifndef NVTE_UB_WITH_MPI
pgs.insert({"world", world_group}); pgs.insert({"world", world_group});
myrank = pgs["world"]->getRank(); myrank = pgs["world"]->getRank();
...@@ -53,20 +52,9 @@ CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group, ...@@ -53,20 +52,9 @@ CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group,
mynode = 0; mynode = 0;
numnodes = 1; numnodes = 1;
} else { } else {
// Intra-node group is different than the world group so there must be multiple nodes
NVTE_CHECK(
inter_domain_group.has_value(),
"Internal TE error: Inter-node group cannot be `None` when intra-node group is not ",
"identical to the world_group!");
// Get node ID and number of nodes // Get node ID and number of nodes
NVTE_CHECK( mynode = myrank / numlocal;
inter_domain_group.value()->getBackendType() == backend, numnodes = numranks / numlocal;
"Internal TE error: Inter-node group must be on the same backend (%s) as the world ",
"group!", pgs["world"]->getBackendName());
pgs.insert({"inter", inter_domain_group.value()});
mynode = pgs["inter"]->getRank();
numnodes = pgs["inter"]->getSize();
} }
} else { } else {
// Intra-node group is not set so we assume there is only 1 node // Intra-node group is not set so we assume there is only 1 node
......
...@@ -285,10 +285,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -285,10 +285,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::class_<CommOverlapHelper>(m, "CommOverlapHelper") py::class_<CommOverlapHelper>(m, "CommOverlapHelper")
.def(py::init<>(), py::call_guard<py::gil_scoped_release>()) .def(py::init<>(), py::call_guard<py::gil_scoped_release>())
.def(py::init<c10d::ProcessGroup *, std::optional<c10d::ProcessGroup *>, .def(py::init<c10d::ProcessGroup *, std::optional<c10d::ProcessGroup *>>(),
std::optional<c10d::ProcessGroup *>>(),
py::call_guard<py::gil_scoped_release>(), py::arg("world_group"), py::call_guard<py::gil_scoped_release>(), py::arg("world_group"),
py::arg("intra_node_group") = py::none(), py::arg("inter_node_group") = py::none()); py::arg("intra_node_group") = py::none());
py::class_<CommOverlap, std::shared_ptr<CommOverlap>, transformer_engine::CommOverlapBase, py::class_<CommOverlap, std::shared_ptr<CommOverlap>, transformer_engine::CommOverlapBase,
transformer_engine::CommOverlapCore>(m, "CommOverlap") transformer_engine::CommOverlapCore>(m, "CommOverlap")
......
...@@ -7,9 +7,6 @@ import io ...@@ -7,9 +7,6 @@ import io
import os import os
import pickle import pickle
import warnings import warnings
import socket
import fcntl
import struct
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
from contextlib import contextmanager from contextlib import contextmanager
...@@ -177,85 +174,32 @@ def initialize_ub( ...@@ -177,85 +174,32 @@ def initialize_ub(
world_rank = torch.distributed.get_rank(world_group) world_rank = torch.distributed.get_rank(world_group)
world_size = torch.distributed.get_world_size(world_group) world_size = torch.distributed.get_world_size(world_group)
# We have single-node NVLink so we can color based on physical node hostnames. num_domains = world_size // tp_size
# NOTE: Prefer a network interface defined via the NVTE_UB_SOCKET_IFNAME variable, and mydomain_idx = world_rank // tp_size
# otherwise fall back on NCCL_SOCKET_IFNAME or GLOO_SOCKET_IFNAME depending on
# the chosen bootstrap backend.
mydomain = socket.gethostname()
ifname = os.getenv(
"NVTE_UB_SOCKET_IFNAME", os.getenv(f"{bootstrap_backend.upper()}_SOCKET_IFNAME")
)
if ifname is not None:
# Make sure the ifname found in the environment is a valid network interface
if ifname in [name for _, name in socket.if_nameindex()]:
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
mydomain = socket.inet_ntoa(
fcntl.ioctl(
s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8"))
)[20:24]
)
except OSError as err:
raise OSError(f"Invalid network interface: {ifname}") from err
finally:
s.close()
else:
ifname_warning = (
f"'{ifname}' is not a valid network interface! `te.initialize_ub()` will"
+ " attempt to detect ranks on the same node by matching "
+ "'socket.gethostname()', which is known to fail on virtual clusters like "
+ "Kubernetes. If Userbuffers initialization fails, please set the "
+ "'NVTE_UB_SOCKET_IFNAME' variable in your environment to the correct network "
+ "interface."
)
warnings.warn(ifname_warning, UserWarning)
# Allgather the domain colors across ranks and reduce to a list of unique domains
domain_per_rank_list = [None for _ in range(world_size)]
torch.distributed.all_gather_object(domain_per_rank_list, mydomain, world_group)
unique_domains = []
for domain in domain_per_rank_list:
if domain not in unique_domains:
unique_domains.append(domain)
num_domains = len(unique_domains)
if num_domains > 1: if num_domains > 1:
# DP/TP model replicated on multiple NVLink domains ranks_per_domain_list = [
ranks_per_domain_list = [[] for _ in range(num_domains)] [i * tp_size + t for t in range(tp_size)] for i in range(num_domains)
mydomain_idx = -1 ]
for i, domain in enumerate(domain_per_rank_list): tp_domain_group, _ = torch.distributed.new_subgroups_by_enumeration(
domain_idx = unique_domains.index(domain)
ranks_per_domain_list[domain_idx].append(i)
if domain == mydomain:
mydomain_idx = domain_idx
assert mydomain_idx >= 0, "Internal TE error!"
intra_domain_group, _ = torch.distributed.new_subgroups_by_enumeration(
ranks_per_domain_list, backend=bootstrap_backend ranks_per_domain_list, backend=bootstrap_backend
) )
local_rank = torch.distributed.get_rank(intra_domain_group) local_rank = torch.distributed.get_rank(tp_domain_group)
intra_domain_ranks = torch.distributed.get_process_group_ranks(intra_domain_group) tp_domain_ranks = torch.distributed.get_process_group_ranks(tp_domain_group)
inter_domain_group, _ = torch.distributed.new_subgroups_by_enumeration(
[list(ranks) for ranks in zip(*ranks_per_domain_list)],
backend=bootstrap_backend,
)
helper = tex.CommOverlapHelper(world_group, intra_domain_group, inter_domain_group)
helper = tex.CommOverlapHelper(world_group, tp_domain_group)
else: else:
# TP model on single NVLink domain, no replication, no data-parallelism # TP model on single NVLink domain, no replication, no data-parallelism
mydomain_idx = 0 mydomain_idx = 0
local_rank = world_rank local_rank = world_rank
intra_domain_ranks = list(range(world_size)) tp_domain_ranks = list(range(world_size))
helper = tex.CommOverlapHelper(world_group) helper = tex.CommOverlapHelper(world_group)
if world_rank == 0: if world_rank == 0:
print(f"!!! [UB] Number of NVLink domains: {num_domains}\n", end="", flush=True) print(f"!!! [UB] Number of TP domains: {num_domains}\n", end="", flush=True)
if local_rank == 0: if local_rank == 0:
print( print(
f"!!! [UB] Global ranks on domain {mydomain_idx}: {intra_domain_ranks}\n", f"!!! [UB] Global ranks on TP domain {mydomain_idx}: {tp_domain_ranks}\n",
end="", end="",
flush=True, flush=True,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment