Unverified Commit f8eddcf9 authored by Nicolas Castet, committed by GitHub

Add support for UB MNNVL (#1470)



* Add support for UB MNNVL
Signed-off-by: Nicolas Castet <ncastet@nvidia.com>

* Address review comments
Signed-off-by: Nicolas Castet <ncastet@nvidia.com>

* Fix lint
Signed-off-by: Nicolas Castet <ncastet@nvidia.com>

* Dlopen nvml lib since it comes with the cuda driver
Signed-off-by: Nicolas Castet <ncastet@nvidia.com>

* Add initial copyright date
Signed-off-by: Nicolas Castet <ncastet@nvidia.com>

---------
Signed-off-by: Nicolas Castet <ncastet@nvidia.com>
parent e85d1806
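For orientation before the diff: this commit changes `--num-replicas` in the test scripts from a global replica count to a per-node count, with the new `--use-global-replica-count` flag restoring the old meaning. A minimal sketch of the resulting arithmetic, assuming torchrun-style environment variables (this restates the test script's logic; it is not a TE API):

# Sketch of the new replica math (assumption: launched under torchrun).
import os

WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
NUM_NODES = WORLD_SIZE // LOCAL_SIZE

num_replicas = 2                   # --num-replicas
use_global_replica_count = False   # --use-global-replica-count

# Per-node count by default; total count when the new flag is set.
total_replicas = num_replicas if use_global_replica_count else num_replicas * NUM_NODES
tp_size = WORLD_SIZE // total_replicas  # each replica is one tensor-parallel group
print(f"{total_replicas} replicas x {tp_size} GPUs per TP group")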
@@ -77,7 +77,16 @@ def _parse_args(argv=None, namespace=None):
         help="Disable the comm+GEMM overlap.",
     )
     parser.add_argument(
-        "--num-replicas", type=int, default=1, help="Number of data-parallel model replicas."
+        "--num-replicas",
+        type=int,
+        default=1,
+        help="Number of data-parallel model replicas per node.",
+    )
+    parser.add_argument(
+        "--use-global-replica-count",
+        action="store_true",
+        default=False,
+        help="Treat '--num-replicas' as the total number of replicas.",
     )
     parser.add_argument(
         "--tcp-init",
@@ -173,13 +182,12 @@ def _train(opts):
         opts.tcp_init = True
         opts.bind_to_device = True
         opts.bootstrap_backend = "mpi"
-    elif "TORCHELASTIC_RUN_ID" in os.environ:
+    else:  # TORCHELASTIC, SLURM, etc...
         WORLD_RANK = int(os.getenv("RANK", "0"))
         WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
         LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
-        LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
-    else:
-        raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!")
+        LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", str(torch.cuda.device_count())))

     NUM_NODES = WORLD_SIZE // LOCAL_SIZE
     # Initialize torch.distributed global process group and get DP/TP groups
@@ -214,90 +222,24 @@ def _train(opts):
     dist_print(f"Initialized default NCCL process group with {WORLD_SIZE} GPUs")

-    # Figure out process groups for tensor- and data-parallelism (if any)
-    if NUM_NODES > 1:
-        # Create a list of world ranks on this node
-        hostname = socket.gethostname()
-        ifname = os.getenv(
-            "NVTE_UB_SOCKET_IFNAME",
-            os.getenv("NCCL_SOCKET_IFNAME", os.getenv("GLOO_SOCKET_IFNAME")),
-        )
-        if ifname is not None:
-            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-            try:
-                hostname = socket.inet_ntoa(
-                    fcntl.ioctl(
-                        s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8"))
-                    )[20:24]
-                )
-            except OSError as err:
-                raise OSError(f"Invalid network interface: {ifname}") from err
-        hostnames = [None for _ in range(WORLD_SIZE)]
-        dist.all_gather_object(hostnames, hostname)
-        unique_hosts = []
-        for host in hostnames:
-            if host not in unique_hosts:
-                unique_hosts.append(host)
-        assert len(unique_hosts) == NUM_NODES
-        ranks_per_node_list = [[] for _ in range(NUM_NODES)]
-        self_node_idx = -1
-        for i, host in enumerate(hostnames):
-            node_idx = unique_hosts.index(host)
-            ranks_per_node_list[node_idx].append(i)
-            if host == hostname:
-                self_node_idx = node_idx
-        assert self_node_idx >= 0
-        self_node_ranks = ranks_per_node_list[self_node_idx]
-        if opts.num_replicas > 1:
-            # Split node ranks into multiple replicas
-            assert len(self_node_ranks) % opts.num_replicas == 0
-            tp_size = len(self_node_ranks) // opts.num_replicas
-            ranks_per_replica_list = []
-            for node_ranks in ranks_per_node_list:
-                for i in range(opts.num_replicas):
-                    start = i * tp_size
-                    end = start + tp_size
-                    ranks_per_replica_list.append(node_ranks[start:end])
-            self_replica_idx = -1
-            for i, replica_ranks in enumerate(ranks_per_replica_list):
-                if WORLD_RANK in replica_ranks:
-                    self_replica_idx = i
-                    break
-            assert self_replica_idx >= 0
-        else:
-            # The entire node is the tensor-parallel group
-            ranks_per_replica_list = ranks_per_node_list
-            self_replica_idx = self_node_idx
+    total_replicas = (
+        opts.num_replicas if opts.use_global_replica_count else opts.num_replicas * NUM_NODES
+    )
+    tp_size = WORLD_SIZE // total_replicas
+    if total_replicas > 1:
+        ranks_per_replica_list = [
+            [i * tp_size + t for t in range(tp_size)] for i in range(total_replicas)
+        ]
         tp_group, _ = dist.new_subgroups_by_enumeration(ranks_per_replica_list, backend="nccl")
         ranks_per_replica_tensor = torch.tensor(ranks_per_replica_list, dtype=torch.int32)
         dp_group, _ = dist.new_subgroups_by_enumeration(
             ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl"
         )
     else:
-        if opts.num_replicas > 1:
-            # Mixed data- and tensor-parallelism on a single node
-            # NOTE: Avoid dist.init_device_mesh() to support older PyTorch versions
-            all_ranks = torch.tensor(list(range(LOCAL_SIZE)), dtype=torch.uint8, device="cpu")
-            ranks_per_replica_tensor = all_ranks.reshape(
-                (opts.num_replicas, LOCAL_SIZE // opts.num_replicas)
-            )
-            tp_group, _ = dist.new_subgroups_by_enumeration(
-                ranks_per_replica_tensor.tolist(), backend="nccl"
-            )
-            dp_group, _ = dist.new_subgroups_by_enumeration(
-                ranks_per_replica_tensor.transpose(0, 1).tolist(), backend="nccl"
-            )
-        else:
-            dp_group = None
-            tp_group = nccl_world
+        dp_group = None
+        tp_group = nccl_world

     tp_rank = dist.get_rank(tp_group)
     tp_size = dist.get_world_size(tp_group)
...
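For intuition on the new contiguous enumeration above: with 2 nodes of 4 GPUs and one replica per node, total_replicas = 2 and tp_size = 4, so the TP groups are [0..3] and [4..7], and the DP groups pair matching TP ranks across replicas (the transpose in the diff). A standalone sketch with no process groups:

# Sketch: rank layout produced by the new enumeration.
WORLD_SIZE, total_replicas = 8, 2
tp_size = WORLD_SIZE // total_replicas

tp_groups = [[i * tp_size + t for t in range(tp_size)] for i in range(total_replicas)]
# -> [[0, 1, 2, 3], [4, 5, 6, 7]]
dp_groups = [list(col) for col in zip(*tp_groups)]
# -> [[0, 4], [1, 5], [2, 6], [3, 7]]  (transpose, as in the diff)
print(tp_groups, dp_groups)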
@@ -180,15 +180,22 @@ def _main(opts):
         LOCAL_SIZE = int(os.getenv("OMPI_COMM_WORLD_LOCAL_SIZE", "1"))
         opts.tcp_init = True
         opts.bootstrap_backend = "mpi"
-    elif "TORCHELASTIC_RUN_ID" in os.environ:
+    else:  # TORCHELASTIC, SLURM, etc...
         WORLD_RANK = int(os.getenv("RANK", "0"))
         WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
         LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
-        LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
-    else:
-        raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!")
-    assert WORLD_SIZE == LOCAL_SIZE  # this test supports only 1 node
-    assert LOCAL_SIZE <= torch.cuda.device_count()
+        LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", str(torch.cuda.device_count())))
+
+    result = subprocess.run(
+        "nvidia-smi -q | grep -m1 CliqueId | awk '{printf $3}'",
+        capture_output=True,
+        text=True,
+        shell=True,
+    )
+    if result.stdout == "0":  # Extra checks for non-MNNVL platforms
+        assert WORLD_SIZE == LOCAL_SIZE
+        assert LOCAL_SIZE <= torch.cuda.device_count()

     # Fix clock speed
     torch.cuda.set_device(LOCAL_RANK)
...
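The CliqueId probe above is how the tests distinguish MNNVL fabrics (nonzero clique ID) from single-node NVLink platforms, where the old single-node assertions still apply. Factored out as a helper, this is a sketch of the same shell pipeline the tests inline (not a function in the diff):

# Sketch: MNNVL platform detection via the NVLink fabric CliqueId.
import subprocess

def is_mnnvl_platform() -> bool:
    """True if nvidia-smi reports a nonzero NVLink fabric CliqueId."""
    result = subprocess.run(
        "nvidia-smi -q | grep -m1 CliqueId | awk '{printf $3}'",
        capture_output=True, text=True, shell=True,
    )
    return result.stdout not in ("", "0")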
@@ -7,6 +7,7 @@
 import os
 import sys
 import socket
+import subprocess
 import argparse
 import warnings
 import pprint
@@ -209,14 +210,21 @@ def _train(opts):
         opts.tcp_init = True
         opts.bind_to_device = True
         opts.bootstrap_backend = "mpi"
-    elif "TORCHELASTIC_RUN_ID" in os.environ:
+    else:
         WORLD_RANK = int(os.getenv("RANK", "0"))
         WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
         LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
-        LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
-    else:
-        raise RuntimeError(f"{__file__} must be launched with either `mpirun` or `torchrun`!")
-    assert LOCAL_SIZE == WORLD_SIZE
+        LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", str(torch.cuda.device_count())))
+
+    result = subprocess.run(
+        "nvidia-smi -q | grep -m1 CliqueId | awk '{printf $3}'",
+        capture_output=True,
+        text=True,
+        shell=True,
+    )
+    if result.stdout == "0":  # Extra checks for non-MNNVL platforms
+        assert WORLD_SIZE == LOCAL_SIZE

     def dist_print(msg, src=None, end="\n", debug=False, error=False):
         if debug and not opts.debug:
@@ -227,7 +235,7 @@ def _train(opts):
         dist.barrier()

     # Set device and initialize RNG states
-    torch.cuda.set_device(WORLD_RANK)
+    torch.cuda.set_device(LOCAL_RANK)
     torch.manual_seed(opts.seed)
     torch.cuda.manual_seed(opts.seed)
@@ -312,7 +320,7 @@ def _train(opts):
         return out

     torch_rng_state = torch.get_rng_state()
-    cuda_rng_state = torch.cuda.get_rng_state(torch.device(f"cuda:{WORLD_RANK}"))
+    cuda_rng_state = torch.cuda.get_rng_state(torch.device(f"cuda:{LOCAL_RANK}"))
     if opts.use_cuda_graphs:
         test_graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(test_graph):
@@ -329,7 +337,7 @@ def _train(opts):
         names.append(test_name + ".grad")

     torch.set_rng_state(torch_rng_state)
-    torch.cuda.set_rng_state(cuda_rng_state, torch.device(f"cuda:{WORLD_RANK}"))
+    torch.cuda.set_rng_state(cuda_rng_state, torch.device(f"cuda:{LOCAL_RANK}"))
     if opts.use_cuda_graphs:
         ref_graph = torch.cuda.CUDAGraph()
         with torch.cuda.graph(ref_graph):
...
@@ -78,6 +78,7 @@ list(APPEND transformer_engine_SOURCES
   util/cast.cu
   util/padding.cu
   util/cuda_driver.cpp
+  util/cuda_nvml.cpp
   util/cuda_runtime.cpp
   util/rtc.cpp
   swizzle/swizzle.cu
...
@@ -1682,6 +1682,7 @@ void reducescatter2_userbuff_strided(void *output, const int handler, const int
   SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
   callranks_rs_oop_stride(2) callranks_rs_oop_stride(4) callranks_rs_oop_stride(8)
+      callranks_rs_oop_stride(16) callranks_rs_oop_stride(32)
 }

 void reducescatter2_userbuff_strided_atomic(void *output, const int handler, const int offset,
                                             const int rowelements, const int colelements,
@@ -1703,7 +1704,8 @@ void reducescatter2_userbuff_strided_atomic(void *output, const int handler, con
   SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
   callranks_rs_oop_stride_atomic(2) callranks_rs_oop_stride_atomic(4)
-      callranks_rs_oop_stride_atomic(8)
+      callranks_rs_oop_stride_atomic(8) callranks_rs_oop_stride_atomic(16)
+      callranks_rs_oop_stride_atomic(32)
 }

 template <typename fp8type>
@@ -1729,6 +1731,7 @@ void reducescatter2_userbuff_strided_universal_fp8(void *output, float *scale, c
   SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
   callranks_rs_oop_atomic_fp8(2) callranks_rs_oop_atomic_fp8(4) callranks_rs_oop_atomic_fp8(8)
+      callranks_rs_oop_atomic_fp8(16) callranks_rs_oop_atomic_fp8(32)
 }

 template <typename fp8type>
@@ -1773,7 +1776,8 @@ void reducescatter2_userbuff_strided_multiatomic(void *output, const int handler
   SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
   callranks_rs_oop_stride_multiatomic(2) callranks_rs_oop_stride_multiatomic(4)
-      callranks_rs_oop_stride_multiatomic(8)
+      callranks_rs_oop_stride_multiatomic(8) callranks_rs_oop_stride_multiatomic(16)
+      callranks_rs_oop_stride_multiatomic(32)
 }

 void allgather2_userbuff_inplace(const int handler, const int offset, const int elements,
@@ -1793,17 +1797,17 @@ void allgather2_userbuff_inplace(const int handler, const int offset, const int
   if (comm_launch_event) {
     SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
-    if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
-      callranks_agMC(2) callranks_agMC(4) callranks_agMC(8)
+    if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
+      callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) callranks_agMC(16) callranks_agMC(32)
     } else {
-      callranks_ag(2) callranks_ag(4) callranks_ag(8)
+      callranks_ag(2) callranks_ag(4) callranks_ag(8) callranks_ag(16) callranks_ag(32)
     }
   } else {
     SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
-    if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
-      callranks_agMC(2) callranks_agMC(4) callranks_agMC(8)
+    if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
+      callranks_agMC(2) callranks_agMC(4) callranks_agMC(8) callranks_agMC(16) callranks_agMC(32)
     } else {
-      callranks_ag(2) callranks_ag(4) callranks_ag(8)
+      callranks_ag(2) callranks_ag(4) callranks_ag(8) callranks_ag(16) callranks_ag(32)
     }
   }
 }
@@ -1840,17 +1844,17 @@ void reducescatter2_userbuff_inplace(const int handler, const int offset, const
   if (comm_launch_event) {
     SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
-    if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
-      callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8)
+    if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
+      callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) callranks_rsMC(16) callranks_rsMC(32)
     } else {
-      callranks_rs(2) callranks_rs(4) callranks_rs(8)
+      callranks_rs(2) callranks_rs(4) callranks_rs(8) callranks_rs(16) callranks_rs(32)
     }
   } else {
     SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
-    if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
-      callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8)
+    if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
+      callranks_rsMC(2) callranks_rsMC(4) callranks_rsMC(8) callranks_rsMC(16) callranks_rsMC(32)
     } else {
-      callranks_rs(2) callranks_rs(4) callranks_rs(8)
+      callranks_rs(2) callranks_rs(4) callranks_rs(8) callranks_rs(16) callranks_rs(32)
     }
   }
 }
@@ -1873,17 +1877,21 @@ void reducescatter2_userbuff_stridedoutput(void *output, const int handler, cons
   if (comm_launch_event) {
     SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
-    if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
-      callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8)
+    if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
+      callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) callranks_rs_oopMC(16)
+          callranks_rs_oopMC(32)
     } else {
-      callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8)
+      callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) callranks_rs_oop(16)
+          callranks_rs_oop(32)
     }
   } else {
     SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
-    if (comm->use_mc && (comm->memflags[handler] & UB_MEM_MC_CREATED)) {
-      callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8)
+    if (comm->use_mc && (comm->memflags[handler] & NVTE_UB_MEM_MC_CREATED)) {
+      callranks_rs_oopMC(2) callranks_rs_oopMC(4) callranks_rs_oopMC(8) callranks_rs_oopMC(16)
+          callranks_rs_oopMC(32)
     } else {
-      callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8)
+      callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8) callranks_rs_oop(16)
+          callranks_rs_oop(32)
     }
   }
 }
@@ -1915,10 +1923,12 @@ void reducescatter2_userbuff_stridedoutput_fp8(void *output, float *scale, const
   if (comm_launch_event) {
     SETUP_LAUNCH_CONFIG_WITH_COMPLETION_EVENT(sms, warps * 32, stream, comm_launch_event);
-    callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
+    callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
+        callranks_rs_oop_fp8(16) callranks_rs_oop_fp8(32)
   } else {
     SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
-    callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
+    callranks_rs_oop_fp8(2) callranks_rs_oop_fp8(4) callranks_rs_oop_fp8(8)
+        callranks_rs_oop_fp8(16) callranks_rs_oop_fp8(32)
   }
 }
...
@@ -34,11 +34,7 @@ using ExtBarrierOp = std::function<void(ExtComm)>;
 #define NVTE_MAX_REQUESTS 1024
 #define NVTE_LAUNCH_GPU 1
 #define NVTE_LAUNCH_CPU 2
-#define NVTE_MAX_NVLINK 8
-#define UB_MEM_UC_CONTIG 1
-#define UB_MEM_MC_CREATED 2
-#define UB_MEM_ALLOCATED 4
+#define NVTE_MAX_NVLINK 32

 #define NVTE_UB_MEM_UC_CONTIG 1
 #define NVTE_UB_MEM_MC_CREATED 2
@@ -124,11 +120,8 @@ struct communicator {
       ar_nvrank;  // number of gpus(and first gpu in a group) of gpus per node in reduction subgroup
                   // (_splitar init used) would be equal to (nvsize,0) for regular comm_create
   int ar2_nvsize, ar2_firstgpu, ar2_nvrank;  // with ar_nvsize as a step
-  int pipe_id;  // which allreduce set of groups (pipeline rank in range of 0..pipeline_size)
   int sm_arch;
-  int num_nodes, my_node,
-      first_node;  // comm_inter communicator, per-rail allreduce (might have subset of nodes)
-  int num2_nodes, my2_node, first2_node;  // with num_nodes as a stride
+  int num_nodes, my_node;

   // max value for running block counters in hostflags
   int basecounter[userbuffers_op_types];  // NOLINT(*)
@@ -136,20 +129,11 @@ struct communicator {
   void *mem_mr[NVTE_MAX_REGIONS];

-  ub_request *fifo;
-  int nblocks, alignblock, minblock, asyncblocks, active_nreqs;
-  ub_request active_req[userbuffers_op_types];  // NOLINT(*)
-  int padding[7];
-  volatile int head;
-  int padding2[15];
-  volatile int tail;
-
   // Abstract communication callbacks to support external bootstrapping (e.g. DL frameworks)
   ExtAllgatherOp _allgather;
   ExtBarrierOp _barrier;

   ExtComm comm_world;
-  ExtComm comm_inter;  // reduction group communicator (subset of the nodes) along GPU rail
   ExtComm comm_intra;  // full intranode (all ndev GPUS)
 #ifdef NVTE_UB_WITH_MPI
   MPI_Request mpihndl[NVTE_MAX_SHARP];
@@ -199,11 +183,6 @@ void destroy_communicator_mpi(communicator *comm);
    returned offset is offset of gpubuff relative to buffer registered
 */

-int pipe_rank(communicator *comm,
-              int step);  // helper function to help walk across allreduce1 x allreduce2 groups
-                          // data-parallel and tensor-parallel position within data and tensor
-                          // groups would be preserved
-
 int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc);
 /* returns handler and registers buffers. assumed to be collective i.e. you use same groups and
    dont mix buffers for different operations returns -1 if cant register (too many preregistered
...
@@ -4,8 +4,6 @@
  * See LICENSE for license information.
  ************************************************************************/

-#include <dlfcn.h>
-
 #include <filesystem>

 #include "../common.h"
@@ -13,84 +11,6 @@
 namespace transformer_engine {

-namespace {
-
-/*! \brief Wrapper class for a shared library
- *
- * \todo Windows support
- */
-class Library {
- public:
-  explicit Library(const char *filename) {
-#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
-    // TODO Windows support
-    NVTE_ERROR("Shared library initialization is not supported with Windows");
-#else
-    handle_ = dlopen(filename, RTLD_LAZY | RTLD_LOCAL);
-    NVTE_CHECK(handle_ != nullptr, "Lazy library initialization failed");
-#endif  // _WIN32 or _WIN64 or __WINDOW__
-  }
-
-  ~Library() {
-#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
-    // TODO Windows support
-#else
-    if (handle_ != nullptr) {
-      dlclose(handle_);
-    }
-#endif  // _WIN32 or _WIN64 or __WINDOW__
-  }
-
-  Library(const Library &) = delete;  // move-only
-  Library(Library &&other) noexcept { swap(*this, other); }
-
-  Library &operator=(Library other) noexcept {
-    // Copy-and-swap idiom
-    swap(*this, other);
-    return *this;
-  }
-
-  friend void swap(Library &first, Library &second) noexcept;
-
-  void *get() noexcept { return handle_; }
-  const void *get() const noexcept { return handle_; }
-
-  /*! \brief Get pointer corresponding to symbol in shared library */
-  void *get_symbol(const char *symbol) {
-#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
-    // TODO Windows support
-    NVTE_ERROR("Shared library initialization is not supported with Windows");
-#else
-    void *ptr = dlsym(handle_, symbol);
-    NVTE_CHECK(ptr != nullptr, "Could not find symbol in lazily-initialized library");
-    return ptr;
-#endif  // _WIN32 or _WIN64 or __WINDOW__
-  }
-
- private:
-  void *handle_ = nullptr;
-};
-
-void swap(Library &first, Library &second) noexcept {
-  using std::swap;
-  swap(first.handle_, second.handle_);
-}
-
-/*! \brief Lazily-initialized shared library for CUDA driver */
-Library &cuda_driver_lib() {
-#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
-  constexpr char lib_name[] = "nvcuda.dll";
-#else
-  constexpr char lib_name[] = "libcuda.so.1";
-#endif
-  static Library lib(lib_name);
-  return lib;
-}
-
-}  // namespace
-
 namespace cuda_driver {

 void *get_symbol(const char *symbol) {
...
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "cuda_nvml.h"

#include "shared_lib_wrapper.h"

namespace transformer_engine {

namespace cuda_nvml {

/*! \brief Lazily-initialized shared library for CUDA NVML */
Library &cuda_nvml_lib() {
  constexpr char lib_name[] = "libnvidia-ml.so.1";
  static Library lib(lib_name);
  return lib;
}

void *get_symbol(const char *symbol) { return cuda_nvml_lib().get_symbol(symbol); }

}  // namespace cuda_nvml

}  // namespace transformer_engine
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_

#include <nvml.h>

#include <string>

#include "../common.h"
#include "../util/string.h"

namespace transformer_engine {

namespace cuda_nvml {

/*! \brief Get pointer corresponding to symbol in CUDA NVML library */
void *get_symbol(const char *symbol);

/*! \brief Call function in CUDA NVML library
 *
 * The CUDA NVML library (libnvidia-ml.so.1 on Linux) may be different at
 * compile-time and run-time.
 *
 * \param[in] symbol Function name
 * \param[in] args   Function arguments
 */
template <typename... ArgTs>
inline nvmlReturn_t call(const char *symbol, ArgTs... args) {
  using FuncT = nvmlReturn_t(ArgTs...);
  FuncT *func = reinterpret_cast<FuncT *>(get_symbol(symbol));
  return (*func)(args...);
}

/*! \brief Get NVML error string
 *
 * \param[in] rc NVML return code
 */
inline const char *get_nvml_error_string(nvmlReturn_t rc) {
  using FuncT = const char *(nvmlReturn_t);
  FuncT *func = reinterpret_cast<FuncT *>(get_symbol("nvmlErrorString"));
  return (*func)(rc);
}

}  // namespace cuda_nvml

}  // namespace transformer_engine

#define NVTE_CHECK_CUDA_NVML(expr)                                                               \
  do {                                                                                           \
    const nvmlReturn_t status_NVTE_CHECK_CUDA_NVML = (expr);                                     \
    if (status_NVTE_CHECK_CUDA_NVML != NVML_SUCCESS) {                                           \
      const char *desc_NVTE_CHECK_CUDA_NVML =                                                    \
          ::transformer_engine::cuda_nvml::get_nvml_error_string(status_NVTE_CHECK_CUDA_NVML);   \
      NVTE_ERROR("NVML Error: ", desc_NVTE_CHECK_CUDA_NVML);                                     \
    }                                                                                            \
  } while (false)

#define VA_ARGS(...) , ##__VA_ARGS__

#define NVTE_CALL_CHECK_CUDA_NVML(symbol, ...)                                                   \
  do {                                                                                           \
    NVTE_CHECK_CUDA_NVML(::transformer_engine::cuda_nvml::call(#symbol VA_ARGS(__VA_ARGS__)));   \
  } while (false)

#endif  // TRANSFORMER_ENGINE_COMMON_UTIL_CUDA_NVML_H_
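This header lets TE call NVML through the driver-installed libnvidia-ml.so.1 at run time, with no link-time dependency (the reason for the "Dlopen nvml lib" commit above). For comparison only, the same queries can be made from Python via the pynvml bindings; this assumes the external nvidia-ml-py package and is not part of the commit:

# Sketch: NVML from Python via pynvml (assumption: `pip install nvidia-ml-py`).
# TE's C++ path instead dlopens libnvidia-ml.so.1 and resolves symbols lazily.
import pynvml

pynvml.nvmlInit()
try:
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
    print("GPU 0:", pynvml.nvmlDeviceGetName(handle))
finally:
    pynvml.nvmlShutdown()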
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_SHARED_LIB_WRAPPER_H_
#define TRANSFORMER_ENGINE_COMMON_UTIL_SHARED_LIB_WRAPPER_H_

#include <dlfcn.h>

namespace transformer_engine {

/*! \brief Wrapper class for a shared library
 *
 * \todo Windows support
 */
class Library {
 public:
  explicit Library(const char *filename) {
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
    // TODO Windows support
    NVTE_ERROR("Shared library initialization is not supported with Windows");
#else
    handle_ = dlopen(filename, RTLD_LAZY | RTLD_LOCAL);
    NVTE_CHECK(handle_ != nullptr, "Lazy library initialization failed");
#endif  // _WIN32 or _WIN64 or __WINDOW__
  }

  ~Library() {
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
    // TODO Windows support
#else
    if (handle_ != nullptr) {
      dlclose(handle_);
    }
#endif  // _WIN32 or _WIN64 or __WINDOW__
  }

  Library(const Library &) = delete;  // move-only

  void *get() noexcept { return handle_; }
  const void *get() const noexcept { return handle_; }

  /*! \brief Get pointer corresponding to symbol in shared library */
  void *get_symbol(const char *symbol) {
#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
    // TODO Windows support
    NVTE_ERROR("Shared library initialization is not supported with Windows");
#else
    void *ptr = dlsym(handle_, symbol);
    NVTE_CHECK(ptr != nullptr, "Could not find symbol in lazily-initialized library");
    return ptr;
#endif  // _WIN32 or _WIN64 or __WINDOW__
  }

 private:
  void *handle_ = nullptr;
};

}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_COMMON_UTIL_SHARED_LIB_WRAPPER_H_
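The `Library` class factored out here is a thin RAII shell over `dlopen`/`dlsym`, now shared by the CUDA driver and NVML loaders. A rough Python analogue using `ctypes`, for illustration only (the real wrapper is the C++ header above):

# Sketch: lazy shared-library loading, mirroring the C++ Library wrapper.
import ctypes

class Library:
    def __init__(self, filename: str):
        # ctypes.CDLL wraps dlopen(); raises OSError if the library is missing.
        self._lib = ctypes.CDLL(filename)

    def get_symbol(self, symbol: str):
        # Attribute lookup wraps dlsym(); raises AttributeError for unknown symbols.
        return getattr(self._lib, symbol)

nvml = Library("libnvidia-ml.so.1")
nvml_init = nvml.get_symbol("nvmlInit_v2")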
@@ -390,8 +390,7 @@ class CommOverlapHelper : torch::CustomClassHolder {
   CommOverlapHelper();

   CommOverlapHelper(c10d::ProcessGroup *world_group,
-                    std::optional<c10d::ProcessGroup *> intra_node_group,
-                    std::optional<c10d::ProcessGroup *> inter_node_group);
+                    std::optional<c10d::ProcessGroup *> intra_node_group);

   ~CommOverlapHelper();
...
@@ -26,8 +26,7 @@ CommOverlapHelper::CommOverlapHelper() {
 }  // empty constructor for NVTE_UB_WITH_MPI=1

 CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group,
-                                     std::optional<c10d::ProcessGroup *> intra_domain_group,
-                                     std::optional<c10d::ProcessGroup *> inter_domain_group) {
+                                     std::optional<c10d::ProcessGroup *> intra_domain_group) {
 #ifndef NVTE_UB_WITH_MPI
   pgs.insert({"world", world_group});
   myrank = pgs["world"]->getRank();
@@ -53,20 +52,9 @@ CommOverlapHelper::CommOverlapHelper(c10d::ProcessGroup *world_group,
       mynode = 0;
       numnodes = 1;
     } else {
-      // Intra-node group is different than the world group so there must be multiple nodes
-      NVTE_CHECK(
-          inter_domain_group.has_value(),
-          "Internal TE error: Inter-node group cannot be `None` when intra-node group is not ",
-          "identical to the world_group!");
-
       // Get node ID and number of nodes
-      NVTE_CHECK(
-          inter_domain_group.value()->getBackendType() == backend,
-          "Internal TE error: Inter-node group must be on the same backend (%s) as the world ",
-          "group!", pgs["world"]->getBackendName());
-      pgs.insert({"inter", inter_domain_group.value()});
-      mynode = pgs["inter"]->getRank();
-      numnodes = pgs["inter"]->getSize();
+      mynode = myrank / numlocal;
+      numnodes = numranks / numlocal;
     }
   } else {
     // Intra-node group is not set so we assume there is only 1 node
...
@@ -285,10 +285,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   py::class_<CommOverlapHelper>(m, "CommOverlapHelper")
       .def(py::init<>(), py::call_guard<py::gil_scoped_release>())
-      .def(py::init<c10d::ProcessGroup *, std::optional<c10d::ProcessGroup *>,
-                    std::optional<c10d::ProcessGroup *>>(),
-           py::call_guard<py::gil_scoped_release>(), py::arg("world_group"),
-           py::arg("intra_node_group") = py::none(), py::arg("inter_node_group") = py::none());
+      .def(py::init<c10d::ProcessGroup *, std::optional<c10d::ProcessGroup *>>(),
+           py::call_guard<py::gil_scoped_release>(), py::arg("world_group"),
+           py::arg("intra_node_group") = py::none());

   py::class_<CommOverlap, std::shared_ptr<CommOverlap>, transformer_engine::CommOverlapBase,
              transformer_engine::CommOverlapCore>(m, "CommOverlap")
...
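On the Python side, the binding change above means the helper now takes at most a world group and a TP-domain group; the `inter_node_group` argument is gone and node topology is derived arithmetically. A sketch of the two remaining call shapes, matching the binding above and the `base.py` change below:

# Sketch: CommOverlapHelper construction after this change.
helper = tex.CommOverlapHelper(world_group)                   # single NVLink/TP domain
helper = tex.CommOverlapHelper(world_group, tp_domain_group)  # multiple TP domains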
@@ -7,9 +7,6 @@ import io
 import os
 import pickle
 import warnings
-import socket
-import fcntl
-import struct
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
 from contextlib import contextmanager
@@ -177,85 +174,32 @@ def initialize_ub(
     world_rank = torch.distributed.get_rank(world_group)
     world_size = torch.distributed.get_world_size(world_group)

-    # We have single-node NVLink so we can color based on physical node hostnames.
-    # NOTE: Prefer a network interface defined via the NVTE_UB_SOCKET_IFNAME variable, and
-    #       otherwise fall back on NCCL_SOCKET_IFNAME or GLOO_SOCKET_IFNAME depending on
-    #       the chosen bootstrap backend.
-    mydomain = socket.gethostname()
-    ifname = os.getenv(
-        "NVTE_UB_SOCKET_IFNAME", os.getenv(f"{bootstrap_backend.upper()}_SOCKET_IFNAME")
-    )
-    if ifname is not None:
-        # Make sure the ifname found in the environment is a valid network interface
-        if ifname in [name for _, name in socket.if_nameindex()]:
-            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-            try:
-                mydomain = socket.inet_ntoa(
-                    fcntl.ioctl(
-                        s.fileno(), 0x8915, struct.pack("256s", ifname[:15].encode("UTF-8"))
-                    )[20:24]
-                )
-            except OSError as err:
-                raise OSError(f"Invalid network interface: {ifname}") from err
-            finally:
-                s.close()
-        else:
-            ifname_warning = (
-                f"'{ifname}' is not a valid network interface! `te.initialize_ub()` will"
-                + " attempt to detect ranks on the same node by matching "
-                + "'socket.gethostname()', which is known to fail on virtual clusters like "
-                + "Kubernetes. If Userbuffers initialization fails, please set the "
-                + "'NVTE_UB_SOCKET_IFNAME' variable in your environment to the correct network "
-                + "interface."
-            )
-            warnings.warn(ifname_warning, UserWarning)
-
-    # Allgather the domain colors across ranks and reduce to a list of unique domains
-    domain_per_rank_list = [None for _ in range(world_size)]
-    torch.distributed.all_gather_object(domain_per_rank_list, mydomain, world_group)
-    unique_domains = []
-    for domain in domain_per_rank_list:
-        if domain not in unique_domains:
-            unique_domains.append(domain)
-    num_domains = len(unique_domains)
-
+    num_domains = world_size // tp_size
+    mydomain_idx = world_rank // tp_size
     if num_domains > 1:
-        # DP/TP model replicated on multiple NVLink domains
-        ranks_per_domain_list = [[] for _ in range(num_domains)]
-        mydomain_idx = -1
-        for i, domain in enumerate(domain_per_rank_list):
-            domain_idx = unique_domains.index(domain)
-            ranks_per_domain_list[domain_idx].append(i)
-            if domain == mydomain:
-                mydomain_idx = domain_idx
-        assert mydomain_idx >= 0, "Internal TE error!"
-
-        intra_domain_group, _ = torch.distributed.new_subgroups_by_enumeration(
+        ranks_per_domain_list = [
+            [i * tp_size + t for t in range(tp_size)] for i in range(num_domains)
+        ]
+        tp_domain_group, _ = torch.distributed.new_subgroups_by_enumeration(
             ranks_per_domain_list, backend=bootstrap_backend
         )
-        local_rank = torch.distributed.get_rank(intra_domain_group)
-        intra_domain_ranks = torch.distributed.get_process_group_ranks(intra_domain_group)
-
-        inter_domain_group, _ = torch.distributed.new_subgroups_by_enumeration(
-            [list(ranks) for ranks in zip(*ranks_per_domain_list)],
-            backend=bootstrap_backend,
-        )
-
-        helper = tex.CommOverlapHelper(world_group, intra_domain_group, inter_domain_group)
+        local_rank = torch.distributed.get_rank(tp_domain_group)
+        tp_domain_ranks = torch.distributed.get_process_group_ranks(tp_domain_group)
+
+        helper = tex.CommOverlapHelper(world_group, tp_domain_group)
     else:
         # TP model on single NVLink domain, no replication, no data-parallelism
         mydomain_idx = 0
         local_rank = world_rank
-        intra_domain_ranks = list(range(world_size))
+        tp_domain_ranks = list(range(world_size))

         helper = tex.CommOverlapHelper(world_group)

     if world_rank == 0:
-        print(f"!!! [UB] Number of NVLink domains: {num_domains}\n", end="", flush=True)
+        print(f"!!! [UB] Number of TP domains: {num_domains}\n", end="", flush=True)
     if local_rank == 0:
         print(
-            f"!!! [UB] Global ranks on domain {mydomain_idx}: {intra_domain_ranks}\n",
+            f"!!! [UB] Global ranks on TP domain {mydomain_idx}: {tp_domain_ranks}\n",
             end="",
             flush=True,
         )
...
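The new `initialize_ub()` domain bookkeeping above is purely arithmetic: hostname and socket-interface detection is replaced by contiguous rank blocks of size `tp_size`. For example, with world_size = 16 and tp_size = 8 there are two TP domains and rank 11 falls in domain 1. A standalone sketch of that indexing:

# Sketch: TP-domain indexing used by the new initialize_ub() logic.
world_size, tp_size, world_rank = 16, 8, 11

num_domains = world_size // tp_size   # -> 2
mydomain_idx = world_rank // tp_size  # -> 1
ranks_per_domain = [
    [i * tp_size + t for t in range(tp_size)] for i in range(num_domains)
]
# -> [[0..7], [8..15]]; rank 11 sits in ranks_per_domain[1]
print(num_domains, mydomain_idx, ranks_per_domain[mydomain_idx])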