Unverified commit e706e5fa authored by Alp Dener, committed by GitHub

[C/PyTorch] Removed MPI dependence in Userbuffers (#901)



* added DL framework callbacks for bootstrapping userbuffers without MPI
Signed-off-by: Alp Dener <adener@nvidia.com>

* removed userbuffers availability check in TE modules since userbuffers is now always compiled
Signed-off-by: Alp Dener <adener@nvidia.com>

* added comm+GEMM overlap example with LayerNormMLP
Signed-off-by: Alp Dener <adener@nvidia.com>

* linting and review fixes
Signed-off-by: Alp Dener <adener@nvidia.com>

* linting and review fixes
Signed-off-by: Alp Dener <adener@nvidia.com>

* added header guards
Signed-off-by: Alp Dener <adener@nvidia.com>

* removed defunct userbuffers checks in build_utils and setup.py
Signed-off-by: Alp Dener <adener@nvidia.com>

* added exposed API in modules/base.py to __all__
Signed-off-by: Alp Dener <adener@nvidia.com>

* removed transformer_engine/CMakeLists.txt and shifted all TE/common compile into transformer_engine/common/CMakeLists.txt
Signed-off-by: Alp Dener <adener@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
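For reference, a minimal sketch of the MPI-free workflow this commit enables, distilled from the LayerNormMLP example and the initialize_ub/destroy_ub API shown in the diff below. The sizes are illustrative, and a torchrun launch is assumed so that RANK and WORLD_SIZE are set:

import os
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te

# Assumes `torchrun --nproc-per-node=<num_gpus> script.py`; no MPI involved.
rank = int(os.getenv("RANK"))
world_size = int(os.getenv("WORLD_SIZE"))
torch.cuda.set_device(rank)
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
tp_group = dist.new_group(backend="nccl")

# Userbuffers bootstrap over torch.distributed callbacks (wired up internally).
seq_length, batch_size, hidden_size = 2048, 2, 8192  # illustrative sizes
te.initialize_ub([seq_length * batch_size, hidden_size], tp_group,
                 use_fp8=False, dtype=torch.bfloat16)

# ... construct TE modules with ub_overlap_* flags and train ...

te.destroy_ub()
dist.destroy_process_group()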
parent 7d576ed2
......@@ -11,7 +11,7 @@ import setuptools
from .utils import (
all_files_in_dir,
cuda_version,
userbuffers_enabled,
cuda_path,
)
......@@ -28,6 +28,9 @@ def setup_pytorch_extension(
sources = [
csrc_source_files / "common.cu",
csrc_source_files / "ts_fp8_op.cpp",
csrc_source_files / "userbuffers" / "ipcsocket.cc",
csrc_source_files / "userbuffers" / "userbuffers.cu",
csrc_source_files / "userbuffers" / "userbuffers-host.cpp",
] + all_files_in_dir(extensions_dir)
# Header files
......@@ -37,8 +40,12 @@ def setup_pytorch_extension(
common_header_files / "common" / "include",
csrc_header_files,
]
# Compiler flags
cxx_flags = ["-O3"]
cxx_flags = [
"-O3",
"-fvisibility=hidden",
]
nvcc_flags = [
"-O3",
"-gencode",
......@@ -67,13 +74,19 @@ def setup_pytorch_extension(
if version >= (11, 8):
nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"])
# userbuffers support
if userbuffers_enabled():
if os.getenv("MPI_HOME"):
# Libraries -- PyTorch CUDAExtension links to libcudart.so but not to libcuda.so
cuda_home, _ = cuda_path()
library_dirs = [ cuda_home / "compat" / "lib" ]
libraries = [ "cuda" ]
if os.getenv("UB_MPI_BOOTSTRAP"):
assert os.getenv("MPI_HOME") is not None, \
"MPI_HOME must be set when compiling with UB_MPI_BOOTSTRAP=1"
mpi_home = Path(os.getenv("MPI_HOME"))
include_dirs.append(mpi_home / "include")
cxx_flags.append("-DNVTE_WITH_USERBUFFERS")
nvcc_flags.append("-DNVTE_WITH_USERBUFFERS")
cxx_flags.append("-DUB_MPI_BOOTSTRAP")
nvcc_flags.append("-DUB_MPI_BOOTSTRAP")
library_dirs.append(mpi_home / "lib")
libraries.append("mpi")
# Construct PyTorch CUDA extension
sources = [str(path) for path in sources]
......@@ -82,10 +95,12 @@ def setup_pytorch_extension(
return CUDAExtension(
name="transformer_engine_torch",
sources=sources,
include_dirs=include_dirs,
sources=[ str(src) for src in sources ],
include_dirs=[ str(inc) for inc in include_dirs ],
extra_compile_args={
"cxx": cxx_flags,
"nvcc": nvcc_flags,
},
libraries=[ str(lib) for lib in libraries ],
library_dirs=[ str(lib_dir) for lib_dir in library_dirs ],
)
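The hunk above turns MPI into an opt-in build-time path: userbuffers always compile, and MPI is included and linked only when UB_MPI_BOOTSTRAP is set (with MPI_HOME pointing at an MPI installation). A hedged sketch of such a build invocation, where the MPI_HOME path is a placeholder:

import os
import subprocess

# Sketch only: enable the optional MPI bootstrap path at build time.
env = dict(os.environ)
env["UB_MPI_BOOTSTRAP"] = "1"     # compiles with -DUB_MPI_BOOTSTRAP
env["MPI_HOME"] = "/opt/openmpi"  # placeholder; required when UB_MPI_BOOTSTRAP=1
subprocess.run(["pip", "install", "-v", "."], env=env, check=True)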
......@@ -16,15 +16,6 @@ from subprocess import CalledProcessError
from typing import List, Optional, Tuple
@cache
def userbuffers_enabled() -> bool:
"""Check if userbuffers support is enabled"""
if int(os.getenv("NVTE_WITH_USERBUFFERS", "0")):
assert os.getenv("MPI_HOME"), "MPI_HOME must be set if NVTE_WITH_USERBUFFERS=1"
return True
return False
@cache
def debug_build_enabled() -> bool:
"""Whether to build with a debug configuration"""
......
#!/usr/bin/python3
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import sys
import subprocess
import argparse
import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
def parse_args(argv=None, namespace=None):
parser = argparse.ArgumentParser(
description="Test a te.LayerNormMLP module with GEMM+comm overlap via Userbuffers.")
parser.add_argument('-i', "--num-iters", type=int, default=5,
help="Number of dummy 'training' iterations.")
parser.add_argument('-b', "--batch-size", type=int, default=2,
help="Input batch size.")
parser.add_argument('-s', "--seq-length", type=int, default=2048,
help="Input sequence length.")
parser.add_argument('-n', "--num-heads", type=int, default=64,
help="Number of attention heads.")
parser.add_argument('-d', "--head-dim", type=int, default=128,
help="Dimension of each attention head.")
parser.add_argument("--mlp-expansion-factor", type=int, default=4,
help="MLP block intermediate size as a factor of hidden dimension.")
parser.add_argument("--seed", type=int, default=1234,
help="RNG seed.")
parser.add_argument("--fp8", action="store_true", default=False,
help="Enables the te.fp8_autocast() context.")
parser.add_argument("--no-comm-overlap", action="store_true", default=False,
help="Disable the comm+GEMM overlap.")
parser.add_argument('-v', "--verbose", action="store_true", default=False)
return parser.parse_args(argv, namespace)
def train(opts):
WORLD_RANK = int(os.getenv("RANK"))
WORLD_SIZE = int(os.getenv("WORLD_SIZE"))
def dist_print(msg, end='\n', all_ranks=False):
if WORLD_RANK == 0 or all_ranks:
print(f"[RANK-{WORLD_RANK}] {msg}", end=end)
# Seed RNG
torch.cuda.set_device(WORLD_RANK)
torch.manual_seed(opts.seed+WORLD_RANK)
torch.cuda.manual_seed(opts.seed+WORLD_RANK)
# Initialize torch.distributed global process group and get TP group
dist.init_process_group(backend="nccl",
rank=WORLD_RANK,
world_size=WORLD_SIZE,
device_id=torch.device(f'cuda:{WORLD_RANK}'))
tp_group = dist.new_group(backend="nccl")
tp_size = dist.get_world_size(tp_group)
# Initialize userbuffers
ag_cfg = { # Ring-exchange All-Gather overlap for fc1_fprop and fc2_dgrad
'method': 'ring_exchange',
'num_splits' : 8,
'num_sm' : 1,
'set_sm_margin' : False,
}
rs_cfg = { # Reduce-scatter overlap for fc1_dgrad and fc2_fprop
'method': 'ring_exchange',
'num_splits' : 4,
'num_sm' : 1,
'set_sm_margin' : True,
}
hidden_size = opts.num_heads * opts.head_dim
batched_size = opts.seq_length * opts.batch_size
if not opts.no_comm_overlap:
te.initialize_ub(
[batched_size, hidden_size],
tp_group,
use_fp8 = opts.fp8,
dtype = torch.bfloat16,
ub_cfgs = {
'fc1_fprop': ag_cfg,
'fc1_dgrad': rs_cfg,
'fc2_fprop': rs_cfg,
'fc2_dgrad': ag_cfg,
},
)
#
model = te.LayerNormMLP(
hidden_size, opts.mlp_expansion_factor * hidden_size,
params_dtype = torch.bfloat16,
device = 'cuda',
tp_group = tp_group,
tp_size = tp_size,
set_parallel_mode = True,
sequence_parallel = True, # this is required for comm+GEMM overlap
seq_length = opts.seq_length,
micro_batch_size = opts.batch_size,
ub_overlap_rs_dgrad = not opts.no_comm_overlap,
ub_overlap_rs = not opts.no_comm_overlap,
ub_overlap_ag = not opts.no_comm_overlap,
)
# Initialize optimizer with model parameters
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
# Fp8 recipe setup
fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=32,
amax_compute_algo="max")
# Start dummy "training" iterations
for i in range(opts.num_iters):
dist_print(f"Iter {i+1}", all_ranks=opts.verbose)
dist_print("|-- Generate random input batch", all_ranks=opts.verbose)
x = torch.rand((opts.seq_length // tp_size, opts.batch_size, hidden_size),
dtype=torch.bfloat16, device='cuda', requires_grad=True)
dist_print("|-- Forward pass", all_ranks=opts.verbose)
with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=tp_group):
y = model(x)
dist_print("|-- Compute loss", all_ranks=opts.verbose)
loss = y.flatten().sum()
dist_print("|-- Backward pass", all_ranks=opts.verbose)
loss.backward()
dist_print("|-- Optimizer step", all_ranks=opts.verbose)
optim.step()
te.destroy_ub()
dist.destroy_process_group()
if __name__ == "__main__":
if "TORCHELASTIC_RUN_ID" in os.environ.keys():
args = parse_args()
train(args)
else:
subprocess.run(
[
'torchrun', f'--nproc-per-node={torch.cuda.device_count()}',
*sys.argv
],
env=os.environ,
check=True
)
os._exit(0)
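Because the script relaunches itself under torchrun when TORCHELASTIC_RUN_ID is absent, a plain `python` invocation is enough; the explicit equivalent is sketched below. The file name is a placeholder, since this page does not show the example's path in the repository:

import subprocess
import torch

# Hypothetical file name; mirrors the script's own fallback launch logic above.
subprocess.run(
    ["torchrun", f"--nproc-per-node={torch.cuda.device_count()}",
     "ln_mlp_with_overlap.py", "--num-iters", "10", "--fp8", "--verbose"],
    check=True,
)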
......@@ -16,7 +16,6 @@ from build_tools.utils import (
found_ninja,
found_pybind11,
remove_dups,
userbuffers_enabled,
get_frameworks,
install_and_import,
)
......@@ -41,21 +40,13 @@ CMakeBuildExtension = get_build_ext(BuildExtension)
def setup_common_extension() -> CMakeExtension:
"""Setup CMake extension for common library
Also builds JAX or userbuffers support if needed.
"""
cmake_flags = []
if userbuffers_enabled():
cmake_flags.append("-DNVTE_WITH_USERBUFFERS=ON")
"""Setup CMake extension for common library"""
# Project directory root
root_path = Path(__file__).resolve().parent
return CMakeExtension(
name="transformer_engine",
cmake_path=root_path / Path("transformer_engine"),
cmake_flags=cmake_flags,
cmake_path=root_path / Path("transformer_engine/common"),
cmake_flags=[],
)
......
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
cmake_minimum_required(VERSION 3.18)
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
endif()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
project(transformer_engine LANGUAGES CUDA CXX)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --threads 4")
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G")
endif()
find_package(CUDAToolkit REQUIRED cublas nvToolsExt)
# Check for cuDNN frontend API
set(CUDNN_FRONTEND_INCLUDE_DIR
"${CMAKE_SOURCE_DIR}/../3rdparty/cudnn-frontend/include")
if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}")
message(FATAL_ERROR
"Could not find cuDNN frontend API. "
"Try running 'git submodule update --init --recursive' "
"within the Transformer Engine source.")
endif()
include(${CMAKE_SOURCE_DIR}/../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
include_directories(${PROJECT_SOURCE_DIR})
add_subdirectory(common)
if(NVTE_WITH_USERBUFFERS)
message(STATUS "userbuffers support enabled")
add_subdirectory(pytorch/csrc/userbuffers)
endif()
......@@ -2,6 +2,40 @@
#
# See LICENSE for license information.
cmake_minimum_required(VERSION 3.18)
if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
endif()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
project(transformer_engine LANGUAGES CUDA CXX)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --threads 4")
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G")
endif()
find_package(CUDAToolkit REQUIRED cublas nvToolsExt)
# Check for cuDNN frontend API
set(CUDNN_FRONTEND_INCLUDE_DIR
"${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/include")
if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}")
message(FATAL_ERROR
"Could not find cuDNN frontend API. "
"Try running 'git submodule update --init --recursive' "
"within the Transformer Engine source.")
endif()
include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake)
find_package(Python COMPONENTS Interpreter Development.Module REQUIRED)
include_directories(${PROJECT_SOURCE_DIR}/..)
# Configure Transformer Engine library
set(transformer_engine_SOURCES)
list(APPEND transformer_engine_SOURCES
......
......@@ -42,17 +42,5 @@ def _load_library():
return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
def _load_userbuffers():
"""Load shared library with userbuffers"""
so_dir = get_te_path() / "transformer_engine"
so_file = so_dir / f"libtransformer_engine_userbuffers.{_get_sys_extension()}"
if so_file.exists():
return ctypes.CDLL(so_file, mode=ctypes.RTLD_GLOBAL)
return None
if "NVTE_PROJECT_BUILDING" not in os.environ:
_TE_LIB_CTYPES = _load_library()
_UB_LIB_CTYPES = _load_userbuffers()
......@@ -37,6 +37,8 @@ from transformer_engine.pytorch.module import Linear
from transformer_engine.pytorch.module import LayerNormMLP
from transformer_engine.pytorch.module import LayerNorm
from transformer_engine.pytorch.module import RMSNorm
from transformer_engine.pytorch.module import initialize_ub
from transformer_engine.pytorch.module import destroy_ub
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention import MultiheadAttention
......
......@@ -4,6 +4,9 @@
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_COMM_GEMM_OVERLAP_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_COMM_GEMM_OVERLAP_H_
#include <stdio.h>
#include <stdlib.h>
......@@ -20,6 +23,7 @@
#include "common/util/logging.h"
#include "common/util/system.h"
#include "userbuffers/userbuffers.h"
#include "extensions.h"
#define HALF_BYTES 2
#define UB_MAX_SM 32
......@@ -36,6 +40,66 @@
using namespace torch::indexing;
namespace ubuf {
/*
** Static container for Python callbacks to torch.distributed collectives
*/
static struct TorchCallbacks : torch::CustomClassHolder {
bool initialized{false};
std::unordered_map<void *, at::Tensor> gathered_tensors;
std::function<at::Tensor(at::Tensor&, const std::string &)> allgather;
std::function<void(const std::string &)> barrier;
std::function<void(at::Tensor &)> free;
} torch_callbacks;
/*
** Helper function for setting Python callbacks to torch.distributed collectives.
*/
void set_ubuf_bootstrap_callbacks(
std::function<at::Tensor(at::Tensor&, const std::string &)> allgather,
std::function<void(const std::string &)> barrier,
std::function<void(at::Tensor &)> free
) {
torch_callbacks.allgather = allgather;
torch_callbacks.barrier = barrier;
torch_callbacks.free = free;
torch_callbacks.initialized = true;
}
/*
** Python callback for globaldata = torch.distributed.all_gather(localdata, tp_group).
** This *creates* a new tensor, which Userbuffers later frees with a separate callback.
*/
void ub_alloc_copy_allgather(void **globaldata, void *localdata, size_t localbytes, char *group) {
assert(torch_callbacks.initialized);
auto localtensor = torch::from_blob(
localdata, {static_cast<int64_t>(localbytes / sizeof(uint8_t))},
at::device(torch::kCPU).dtype(torch::kUInt8));
auto globaltensor = torch_callbacks.allgather(localtensor, group);
*globaldata = globaltensor.data_ptr();
torch_callbacks.gathered_tensors[*globaldata] = globaltensor;
}
/*
** Python callback for torch.distributed.barrier(tp_group).
*/
void ub_barrier(char *group) {
assert(torch_callbacks.initialized);
torch_callbacks.barrier(group);
}
/*
** Python callback for freeing up tensors created in the ub_alloc_copy_allgather(...) callback.
*/
void ub_free(void *ptr) {
assert(torch_callbacks.initialized);
auto i = torch_callbacks.gathered_tensors.find(ptr);
if (i == torch_callbacks.gathered_tensors.end())
return;
auto tensor = std::move(i->second);
torch_callbacks.gathered_tensors.erase(i);
torch_callbacks.free(tensor);
}
enum class COMM_TYPE { RS = 0, AG = 1 };
enum class UBOverlapAlgo {
......@@ -74,15 +138,21 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
int use_ce;
bool _atomic_gemm;
UbufCommOverlap(torch::Tensor sample, int rank, int tp_size, int num_comm_sm, int comm_cga_size,
int num_splits, bool set_sm_margin, int num_max_streams, bool atomic_gemm,
torch::Tensor empty_tensor) {
UbufCommOverlap(torch::Tensor sample, int rank, int world_size, int tp_rank, int tp_size,
int num_comm_sm, int comm_cga_size, int num_splits, bool set_sm_margin,
int num_max_streams, bool atomic_gemm, torch::Tensor empty_tensor) {
// Initialize userbuf communicator
if (!comm_created) {
if (rank == 0) {
printf("!!! [UB] Create UbufCommOverlap Communicator\n");
}
create_communicator_grouped2(&_ub_comm, 1, 1, tp_size, 1);
if (transformer_engine::getenv<bool>("UB_MPI_BOOTSTRAP")) {
create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1);
} else {
create_communicator_grouped2(&_ub_comm, rank, world_size, tp_rank, tp_size, 1, 1,
&ub_alloc_copy_allgather, &ub_barrier, &ub_free,
1, 1, tp_size, 1);
}
comm_created = true;
}
use_ce = 0;
......@@ -349,7 +419,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
// Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getCurrentCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
for (int i = 0; i < _stream_compute.size(); i++) {
for (size_t i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
}
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0));
......@@ -467,7 +537,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase {
output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size();
}
}
for (int i = 0; i < _stream_compute.size(); i++) {
for (size_t i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(
cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
......@@ -552,15 +622,22 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
int cga_size;
bool _atomic_gemm;
UbufP2PCommOverlap(torch::Tensor sample, int rank, int tp_size, int num_comm_sm,
int comm_cga_size, bool set_sm_margin, bool aggregate2, int num_max_streams,
bool is_reduce_scatter, bool atomic_gemm, torch::Tensor empty_tensor) {
UbufP2PCommOverlap(torch::Tensor sample, int rank, int world_size, int tp_rank, int tp_size,
int num_comm_sm, int comm_cga_size, bool set_sm_margin, bool aggregate2,
int num_max_streams, bool is_reduce_scatter, bool atomic_gemm,
torch::Tensor empty_tensor) {
// Initialize userbuf communicator
if (!comm_created) {
if (rank == 0) {
printf("!!! [UB] Create UbufP2PCommOverlap Communicator\n");
}
create_communicator_grouped2(&_ub_comm, 1, 1, tp_size, 1);
if (transformer_engine::getenv<bool>("UB_MPI_BOOTSTRAP")) {
create_communicator_grouped2_mpi(&_ub_comm, 1, 1, tp_size, 1);
} else {
create_communicator_grouped2(&_ub_comm, rank, world_size, tp_rank, tp_size, 1, 1,
&ub_alloc_copy_allgather, &ub_barrier, &ub_free,
1, 1, tp_size, 1);
}
comm_created = true;
}
use_ce = 1;
......@@ -666,7 +743,6 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
_ub_comm->cga_size = cga_size;
// Get GEMM dimensions between TN and NN input layouts
const int m = (transa) ? A.size(0) : A.size(1);
const int k = (transa) ? A.size(1) : A.size(0);
const int n = _ubuf.size(0);
const int n_chunk = n / _tp_size;
......@@ -806,7 +882,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
for (int i = 0; i < _stream_compute.size(); i++) {
for (size_t i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
}
if (_aggregate2) {
......@@ -924,7 +1000,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
}
}
}
for (int i = 0; i < _stream_compute.size(); i++) {
for (size_t i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(
cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
......@@ -953,16 +1029,11 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
_ub_comm->use_ce = use_ce;
_ub_comm->sms = sms;
_ub_comm->cga_size = cga_size;
int k = A.size(1);
int n = B.size(0);
// Get communication and GEMM input chunk sizes
int n_chunk = n / _tp_size;
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const int input_b_chunk_bytes = n_chunk * k * B.element_size();
// Get input and workspace data pointers
char *input_b_ptr = reinterpret_cast<char *>(B.data_ptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
int *counter_ptr = reinterpret_cast<int *>(counter.data_ptr());
int workspace_size_chunk = workspaceSize / _stream_compute.size();
......@@ -1059,7 +1130,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_send, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0));
for (int i = 0; i < _stream_compute.size(); i++) {
for (size_t i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) _stream_compute[i], _start_compute, 0));
}
......@@ -1115,7 +1186,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options());
torch::sum_out(rs_output, reduce_buf, 0);
}
for (int i = 0; i < _stream_compute.size(); i++) {
for (size_t i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(
cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[i]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
......@@ -1170,3 +1241,5 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase {
}; // UbufP2PCommOverlap
} // namespace ubuf
#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMM_GEMM_OVERLAP_H_
......@@ -4,6 +4,9 @@
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
#include "common.h"
#include "common/common.h"
......@@ -700,8 +703,6 @@ size_t get_cublasLt_version();
size_t get_cudnn_version();
bool userbuf_comm_available();
void placeholder();
......@@ -786,3 +787,5 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, float wd,
float momentum, float dampening, float lr, bool nesterov, bool first_run,
bool wd_after_momentum, float scale);
#endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_EXTENSIONS_H_
......@@ -17,13 +17,4 @@ size_t get_cudnn_version() {
return cudnnGetVersion();
}
bool userbuf_comm_available() { // TODO(ksivamani) check on python side
#ifdef NVTE_WITH_USERBUFFERS
return true;
#else
return false;
#endif
}
void placeholder() {} // TODO(ksivamani) clean this up
......@@ -4,10 +4,10 @@
* See LICENSE for license information.
************************************************************************/
#include <pybind11/functional.h>
#include "../extensions.h"
#ifdef NVTE_WITH_USERBUFFERS
#include "comm_gemm_overlap.h"
#endif // NVTE_WITH_USERBUFFERS
#include "../comm_gemm_overlap.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Softmax functions
......@@ -203,7 +203,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Misc
m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version");
m.def("get_cudnn_version", &get_cudnn_version, "Get cuDNN version");
m.def("userbuf_comm_available", &userbuf_comm_available, "If userbuf backend is available");
// Support THD format for Context Parallel
m.def("thd_read_half_tensor", &thd_read_half_tensor,
......@@ -246,7 +245,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv)
.def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history);
#ifdef NVTE_WITH_USERBUFFERS
// comm+GEMM overlap w/ userbuffers
m.def("set_ubuf_bootstrap_callbacks", &ubuf::set_ubuf_bootstrap_callbacks);
py::enum_<ubuf::UBOverlapAlgo>(m, "UbufOverlapAlgo")
.value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG)
.value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS)
......@@ -258,7 +259,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.value("ATOMIC_GEMM_RS_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS_P2P);
py::class_<ubuf::UbufCommOverlap>(m, "UbufCommOverlap")
.def(py::init<torch::Tensor&, int, int, int, int, int, bool, int, bool, torch::Tensor>())
.def(py::init<torch::Tensor&, int, int, int, int, int, int, int, bool, int, bool,
torch::Tensor>())
.def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap)
.def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs)
.def("set_ubuf_scale_inv", &ubuf::UbufCommOverlap::set_ubuf_scale_inv)
......@@ -270,7 +272,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("is_p2p_overlap", &ubuf::UbufCommOverlap::is_p2p_overlap);
py::class_<ubuf::UbufP2PCommOverlap>(m, "UbufP2PCommOverlap")
.def(py::init<torch::Tensor&, int, int, int, int, bool, bool, int, bool, bool, torch::Tensor>())
.def(py::init<torch::Tensor&, int, int, int, int, int, int, bool, bool, int, bool, bool,
torch::Tensor>())
.def("split_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_ag)
.def("split_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_rs)
.def("atomic_gemm_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_ag)
......@@ -281,11 +284,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("is_atomic_gemm", &ubuf::UbufP2PCommOverlap::is_atomic_gemm)
.def("is_p2p_overlap", &ubuf::UbufP2PCommOverlap::is_p2p_overlap)
.def("set_ubuf_scale_inv", &ubuf::UbufP2PCommOverlap::set_ubuf_scale_inv);
#else // NVTE_WITH_USERBUFFERS
m.def("UbufOverlapAlgo", &placeholder, "Dummy function for python side annotations");
m.def("UbufCommOverlap", &placeholder, "Dummy function for python side annotations");
m.def("UbufP2PCommOverlap", &placeholder, "Dummy function for python side annotations");
#endif // NVTE_WITH_USERBUFFERS
py::enum_<transformer_engine::DType>(m, "DType", py::module_local())
.value("kByte", transformer_engine::DType::kByte)
......
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Configure userbuffers library
add_library(transformer_engine_userbuffers SHARED
userbuffers.cu
userbuffers-host.cpp)
target_include_directories(transformer_engine_userbuffers PUBLIC
"${CMAKE_CURRENT_SOURCE_DIR}")
# Configure dependencies
find_package(MPI REQUIRED)
target_link_libraries(transformer_engine_userbuffers PUBLIC
CUDA::cudart
CUDA::cuda_driver
MPI::MPI_CXX
)
target_include_directories(transformer_engine_userbuffers PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
# Compiler options
set_source_files_properties(userbuffers.cu
userbuffers-host.cpp
PROPERTIES
COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:-maxrregcount=64>")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
# Install library
install(TARGETS transformer_engine_userbuffers DESTINATION .)
......@@ -47,7 +47,7 @@ ncclResult_t ncclIpcSocketInit(ncclIpcSocket *handle, int rank, uint64_t hash,
cliaddr.sun_family = AF_UNIX;
// Create unique name for the socket.
int len =
size_t len =
snprintf(temp, NCCL_IPC_SOCKNAME_LEN, NCCL_IPC_SOCKNAME_STR, rank, hash);
if (len > (sizeof(cliaddr.sun_path) - 1)) {
WARN("UDS: Cannot bind provided name to socket. Name too large");
......@@ -76,9 +76,8 @@ ncclResult_t ncclIpcSocketInit(ncclIpcSocket *handle, int rank, uint64_t hash,
handle->abortFlag = abortFlag;
// Mark socket as non-blocking
if (handle->abortFlag) {
int flags;
EQCHECK(flags = fcntl(fd, F_GETFL), -1);
SYSCHECK(fcntl(fd, F_SETFL, flags | O_NONBLOCK), "fcntl");
int flags = fcntl(fd, F_GETFL);
fcntl(fd, F_SETFL, flags | O_NONBLOCK);
}
return ncclSuccess;
......@@ -193,7 +192,7 @@ ncclResult_t ncclIpcSocketSendMsg(ncclIpcSocket *handle, void *hdr, int hdrLen,
bzero(&cliaddr, sizeof(cliaddr));
cliaddr.sun_family = AF_UNIX;
int len =
size_t len =
snprintf(temp, NCCL_IPC_SOCKNAME_LEN, NCCL_IPC_SOCKNAME_STR, rank, hash);
if (len > (sizeof(cliaddr.sun_path) - 1)) {
WARN("UDS: Cannot connect to provided name for socket. Name too large");
......
......@@ -4,7 +4,6 @@
* See LICENSE for license information.
************************************************************************/
#include "ipcsocket.cc"
#include "ipcsocket.h"
#include "userbuffers.h"
#include <assert.h>
......@@ -13,29 +12,24 @@
#include <cuda_runtime_api.h>
#include <iostream>
#include <math.h>
#include <mpi.h>
#include <sched.h>
#include <stdio.h>
#include <string.h>
#include <unistd.h>
#define MULTICAST_GB_TOTAL 512
static int oob_bcast(void *comm_context, void *buf, int size, int root) {
MPI_Bcast(buf, size, MPI_BYTE, root,
(reinterpret_cast<communicator *>(comm_context))->comm_inter);
return 0;
}
#include <inttypes.h>
static int oob_barrier(void *comm_context) {
MPI_Barrier((reinterpret_cast<communicator *>(comm_context))->comm_inter);
return 0;
}
#ifdef UB_MPI_BOOTSTRAP
#include <mpi.h>
static MPI_Comm EXT_COMM_WORLD = MPI_COMM_WORLD;
static MPI_Comm EXT_COMM_INTRA;
static MPI_Comm EXT_COMM_INTER;
#else
static char EXT_COMM_WORLD[] = "world";
static char EXT_COMM_INTRA[] = "intra";
static char EXT_COMM_INTER[] = "inter";
#endif
static int oob_gather(void *comm_context, int root, void *sbuf, void *rbuf, int len) {
MPI_Gather(sbuf, len, MPI_BYTE, rbuf, len, MPI_BYTE, root,
(reinterpret_cast<communicator *>(comm_context))->comm_inter);
return 0;
}
#define MULTICAST_GB_TOTAL 512
int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (const char *)b); }
......@@ -95,18 +89,24 @@ int pipe_rank(communicator *comm, int step) {
return newnode * numlocal + newlocal;
}
int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenodes, int tensorgpus,
int tensornodes) {
int create_communicator_grouped2(communicator **comm,
int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes,
std::function<void(void**, void*, size_t, ExtComm)> ext_alloc_copy_allgather,
std::function<void(ExtComm)> ext_barrier,
std::function<void(void*)> ext_free,
int pipegpus, int pipenodes, int tensorgpus, int tensornodes) {
*comm = reinterpret_cast<communicator *>(malloc(sizeof(communicator)));
int myrank, nranks, cur_dev, ndev;
MPI_Comm_rank(MPI_COMM_WORLD, &myrank);
MPI_Comm_size(MPI_COMM_WORLD, &nranks);
(*comm)->nranks = nranks;
(*comm)->comm_world = EXT_COMM_WORLD;
(*comm)->_alloc_copy_allgather = ext_alloc_copy_allgather;
(*comm)->_barrier = ext_barrier;
(*comm)->_free = ext_free;
(*comm)->nranks = numranks;
(*comm)->myrank = myrank;
(*comm)->free_region = 0;
(*comm)->launch_mode = NVTE_LAUNCH_GPU | NVTE_LAUNCH_CPU;
int cur_dev, ndev;
cudaDeviceProp device_prop;
CUDACHECK(cudaGetDevice(&cur_dev));
CUDACHECK(cudaGetDeviceCount(&ndev));
......@@ -135,34 +135,7 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
sec_timeout, (*comm)->ub_timeout, device_clock);
}
int ret = 0;
// split communicator
char host_name[MPI_MAX_PROCESSOR_NAME];
char(*host_names)[MPI_MAX_PROCESSOR_NAME];
int namelen, bytes, color, my_node, mylocal, numlocal, num_nodes;
int rank = (*comm)->myrank, size = (*comm)->nranks;
MPI_Get_processor_name(host_name, &namelen);
bytes = size * sizeof(char[MPI_MAX_PROCESSOR_NAME]);
host_names = (char(*)[MPI_MAX_PROCESSOR_NAME])malloc(bytes);
strcpy(host_names[rank], host_name); // NOLINT(*)
for (int n = 0; n < size; n++)
MPI_Bcast(&(host_names[n]), MPI_MAX_PROCESSOR_NAME, MPI_CHAR, n, MPI_COMM_WORLD);
qsort(host_names, size, sizeof(char[MPI_MAX_PROCESSOR_NAME]), stringCmp);
color = 0;
for (int n = 0; n < size; n++) {
if (n > 0 && strcmp(host_names[n - 1], host_names[n]))
color++;
if (strcmp(host_name, host_names[n]) == 0)
break;
}
free(host_names);
MPI_Comm_split(MPI_COMM_WORLD, color, rank, &(*comm)->comm_intra);
// find intranode numbers and make internode communicator
// figure out mylocal
MPI_Comm_rank((*comm)->comm_intra, &mylocal);
MPI_Comm_size((*comm)->comm_intra, &numlocal);
(*comm)->comm_intra = EXT_COMM_INTRA;
(*comm)->nvrank = mylocal;
(*comm)->nvsize = numlocal;
......@@ -198,7 +171,7 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
if (ndev == numlocal) { // all visible devices
if (cur_dev != mylocal)
printf("%d: device used %d[%d] ,resetting device to %d\n", rank, cur_dev, ndev, mylocal);
printf("%d: device used %d[%d] ,resetting device to %d\n", myrank, cur_dev, ndev, mylocal);
CUDACHECK(cudaSetDevice(mylocal));
}
(*comm)->mydev = cur_dev;
......@@ -214,31 +187,22 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
(*comm)->ar2_firstgpu = mylocal - mylocal % tensorgpus;
(*comm)->ar2_nvrank = mylocal - (*comm)->ar2_firstgpu;
// ar2 has step equal to ar_nvsize
int allnodes = nranks / numlocal;
int mynode = myrank / numlocal;
int allnodes = numranks / numlocal;
int nodeid = myrank / numlocal;
int datanodes = allnodes / pipenodes / tensornodes;
int pipenodegroup_id = myrank / numlocal / (datanodes * tensornodes);
(*comm)->pipe_id = pipegpus * pipenodegroup_id + mylocal / (datagpus * tensorgpus);
CUDACHECK(cudaFree(0));
int datanodegroup_id =
myrank / numlocal / datanodes; // data reduction group node belongs, equals 0 for all if both
// pipenodes=1 and tensornodes=1
// mpi communicator only needed for SHARP which is always
// allreduce1/data-parallel
MPI_Comm_split(MPI_COMM_WORLD, mylocal + numlocal * datanodegroup_id, rank, &(*comm)->comm_inter);
// different rails from same group are in different subcommunicators
MPI_Comm_size((*comm)->comm_inter, &num_nodes);
MPI_Comm_rank((*comm)->comm_inter, &my_node);
(*comm)->first_node = mynode - my_node;
(*comm)->num_nodes = num_nodes;
(*comm)->my_node = my_node;
(*comm)->comm_inter = EXT_COMM_INTER;
(*comm)->first_node = nodeid - mynode;
(*comm)->num_nodes = numnodes;
(*comm)->my_node = mynode;
(*comm)->num2_nodes = tensornodes;
(*comm)->my2_node = (mynode / datanodes) % tensornodes;
(*comm)->first2_node = mynode - (*comm)->my2_node * datanodes;
(*comm)->fifo = reinterpret_cast<ub_request *>(malloc(sizeof(ub_request) * NVTE_MAX_REQUESTS));
(*comm)->nblocks = 8;
(*comm)->alignblock = 1024 * 512;
......@@ -262,28 +226,33 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
mcProp.size = mc_maxsize;
(*comm)->mc_maxsize = mc_maxsize;
// Broadcast the a POSIX file descriptor from the local root rank to other local ranks.
// NOTE: This cannot be done via MPI_Bcast or other external comm libraries. They mangle the
// file descriptor and prevent cuMemImportFromShareableHandle() from correctly
// interpreting the file. Instead, we use system socket to send/recv the file handle
// without mangling.
int fd;
volatile uint32_t abortFlag = 0;
struct ncclIpcSocket ipcSock = {0};
uint64_t opId = 0xdeadcafeb000 + (*comm)->ar2_firstgpu;
ncclResult_t ret = ncclSuccess;
NCCLCHECK(ncclIpcSocketInit(&ipcSock, (*comm)->ar2_nvrank, (uint64_t)opId, &abortFlag));
MPI_Barrier(MPI_COMM_WORLD);
(*comm)->_barrier((*comm)->comm_world);
if ((*comm)->ar2_nvrank == 0) {
CUCHECK(cuMulticastCreate(&(*comm)->mc_handle, &mcProp));
CUCHECK(cuMemExportToShareableHandle(&fd, (*comm)->mc_handle,
CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0 /*flags*/));
for (int p = 1; p < (*comm)->ar2_nvsize; p++) {
MPI_Barrier((*comm)->comm_intra);
(*comm)->_barrier((*comm)->comm_intra);
NCCLCHECKGOTO(ncclIpcSocketSendFd(&ipcSock, fd, p, (uint64_t)opId), ret, error);
}
} else {
for (int i = 0; i < (*comm)->ar2_nvrank; i++)
MPI_Barrier((*comm)->comm_intra);
(*comm)->_barrier((*comm)->comm_intra);
NCCLCHECKGOTO(ncclIpcSocketRecvFd(&ipcSock, &fd), ret, error);
for (int i = 0; i < (*comm)->ar2_nvsize - (*comm)->ar2_nvrank - 1; i++)
MPI_Barrier((*comm)->comm_intra);
(*comm)->_barrier((*comm)->comm_intra);
CUCHECK(cuMemImportFromShareableHandle(&(*comm)->mc_handle, reinterpret_cast<void *>(fd),
CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
}
......@@ -303,7 +272,7 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
CUCHECK(cuMemSetAccess(mc_va, mc_maxsize, &accessDesc, 1));
(*comm)->mc_baseptr = reinterpret_cast<void *>(mc_va);
MPI_Barrier(MPI_COMM_WORLD);
(*comm)->_barrier((*comm)->comm_world);
if (!(*comm)->myrank)
printf("MC initialized succesfully, window size = %ld\n", mc_maxsize);
} else {
......@@ -317,12 +286,10 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
#define LOCALSIZE 4 * (NVTE_REG0_OFFSET(*comm) + NVTE_REG0_FLAGS + NVTE_REG0_COMMBUFFER * NBUF)
// peer pointers + op flags + comm buffer
CUDACHECK(cudaMalloc(&(*comm)->gpu_ptrs,
LOCALSIZE)); // flags and pointers, no block data yet
CUDACHECK(cudaMalloc(&(*comm)->gpu_ptrs, LOCALSIZE)); // flags and pointers, no block data yet
CUDACHECK(cudaMemset((*comm)->gpu_ptrs, 0, LOCALSIZE));
CUDACHECK(cudaDeviceSynchronize());
register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE,
*comm); // will use handler 0
register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm, false);
CUDACHECK(cudaMalloc(&(*comm)->send_id, (*comm)->nranks * sizeof(int)));
CUDACHECK(cudaMalloc(&(*comm)->recv_id, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int)));
CUDACHECK(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int)));
......@@ -336,12 +303,12 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
#define GPU_PAGE_MASK (~GPU_PAGE_OFFSET)
CUDACHECK(cudaMalloc(&(*comm)->flags, 2 * GPU_PAGE_SIZE));
unsigned int flag = 1;
CUDACHECK(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE));
(*comm)->flags =
reinterpret_cast<int *>(((CUdeviceptr)(*comm)->flags + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
using namespace std;
sched_param param;
pthread_attr_t attr;
pthread_attr_init(&attr);
......@@ -353,7 +320,7 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
if (getenv("NVTE_UBDEBUG"))
printf("%d/%d:(%d x %d): DP %d x %d TP %d x %d, DPGROUP %dx%d TPGROUP "
"%dx%d PIPE_ID %d/%d\n",
myrank, nranks, myrank / numlocal, myrank % numlocal, (*comm)->my_node,
myrank, numranks, myrank / numlocal, myrank % numlocal, (*comm)->my_node,
(*comm)->ar_nvrank, (*comm)->my2_node, (*comm)->ar2_nvrank, (*comm)->num_nodes,
(*comm)->ar_nvsize, (*comm)->num2_nodes, (*comm)->ar2_nvsize, (*comm)->pipe_id,
pipegpus * pipenodes);
......@@ -361,15 +328,141 @@ int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenode
return 0;
}
int create_communicator_grouped(communicator **comm, int pipegpus, int pipenodes) {
return create_communicator_grouped2(comm, pipegpus, pipenodes, 1, 1);
int create_communicator_grouped(communicator **comm,
int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes,
std::function<void(void**, void*, size_t, ExtComm)> ext_alloc_copy_allgather,
std::function<void(ExtComm)> ext_barrier,
std::function<void(void*)> ext_free,
int pipegpus, int pipenodes) {
return create_communicator_grouped2(
comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
ext_alloc_copy_allgather, ext_barrier, ext_free,
pipegpus, pipenodes, 1, 1);
}
int create_communicator(communicator **comm) {
return create_communicator_grouped2(comm, 1, 1, 1, 1);
int create_communicator(communicator **comm,
int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes,
std::function<void(void**, void*, size_t, ExtComm)> ext_alloc_copy_allgather,
std::function<void(ExtComm)> ext_barrier,
std::function<void(void*)> ext_free) {
return create_communicator_grouped2(
comm, myrank, numranks, mylocal, numlocal, mynode, numnodes,
ext_alloc_copy_allgather, ext_barrier, ext_free,
1, 1, 1, 1);
}
int create_communicator_grouped2_mpi(communicator **comm,
int pipegpus, int pipenodes, int tensorgpus, int tensornodes) {
#ifdef UB_MPI_BOOTSTRAP
// get global numbers
int myrank, numranks;
MPI_Comm_rank(EXT_COMM_WORLD, &myrank);
MPI_Comm_size(EXT_COMM_WORLD, &numranks);
// find intranode numbers and make internode communicator
char host_name[MPI_MAX_PROCESSOR_NAME];
char(*host_names)[MPI_MAX_PROCESSOR_NAME];
int namelen, bytes, color;
int rank = myrank, size = numranks;  // *comm is not allocated until create_communicator_grouped2 runs below
MPI_Get_processor_name(host_name, &namelen);
bytes = size * sizeof(char[MPI_MAX_PROCESSOR_NAME]);
host_names = (char(*)[MPI_MAX_PROCESSOR_NAME])malloc(bytes);
strcpy(host_names[rank], host_name); // NOLINT(*)
for (int n = 0; n < size; n++)
MPI_Bcast(&(host_names[n]), MPI_MAX_PROCESSOR_NAME, MPI_CHAR, n, EXT_COMM_WORLD);
qsort(host_names, size, sizeof(char[MPI_MAX_PROCESSOR_NAME]), stringCmp);
color = 0;
for (int n = 0; n < size; n++) {
if (n > 0 && strcmp(host_names[n - 1], host_names[n]))
color++;
if (strcmp(host_name, host_names[n]) == 0)
break;
}
free(host_names);
int mylocal, numlocal;
MPI_Comm_split(EXT_COMM_WORLD, color, rank, &EXT_COMM_INTRA);
MPI_Comm_rank(EXT_COMM_INTRA, &mylocal);
MPI_Comm_size(EXT_COMM_INTRA, &numlocal);
// find internode numbers and make internode communicator
CUDACHECK(cudaFree(0));
int allnodes = numranks / numlocal;
int datanodes = allnodes / pipenodes / tensornodes;
// data reduction group node belongs, equals 0 for all if both pipenodes=1 and tensornodes=1
int datanodegroup_id = myrank / numlocal / datanodes;
// mpi communicator only needed for SHARP which is always allreduce1/data-parallel
MPI_Comm_split(EXT_COMM_WORLD, mylocal + numlocal * datanodegroup_id, rank, &EXT_COMM_INTER);
// different rails from same group are in different subcommunicators
int mynode, numnodes;
MPI_Comm_size(EXT_COMM_INTER, &numnodes);
MPI_Comm_rank(EXT_COMM_INTER, &mynode);
// finally call the abstracted constructor with MPI info
return create_communicator_grouped2(comm,
myrank, numranks, mylocal, numlocal, mynode, numnodes,
&ub_alloc_copy_allgather, &ub_barrier, &ub_free,
pipegpus, pipenodes, tensorgpus, tensornodes);
#else
NVTE_UB_ERROR(std::string("Bootstrapping Userbuffers with MPI requires ") +
std::string("building Transformer Engine with UB_MPI_BOOTSTRAP=1"));
#endif
}
int create_communicator_grouped_mpi(communicator **comm, int pipegpus, int pipenodes) {
return create_communicator_grouped2_mpi(comm, pipegpus, pipenodes, 1, 1);
}
int create_communicator_mpi(communicator **comm) {
return create_communicator_grouped2_mpi(comm, 1, 1, 1, 1);
}
void destroy_communicator(communicator *comm) {
for (int hndl = 0; hndl < comm->free_region; hndl++) {
if (comm->mem_dealloc[hndl]) {
cuMemAddressFree(reinterpret_cast<CUdeviceptr>(comm->ucbase_ptr[hndl]),
comm->mem_size[hndl] * comm->nvsize);
for (int rank = 0; rank < comm->nvsize; rank++) {
cuMemRelease(comm->uchandles[hndl][rank]);
}
free(reinterpret_cast<void *>(comm->uchandles[hndl]));
} else {
for (int rank = 0; rank < comm->nvsize; rank++) {
if (rank != comm->nvrank) {
cudaIpcCloseMemHandle(comm->peer_ptr[hndl][rank]);
} else {
comm->peer_ptr[hndl][rank] = nullptr; // remove reference to external buffer
}
}
free(comm->peer_ptr[hndl]);
}
comm->mem_ptr[hndl] = nullptr;
}
cudaFree(reinterpret_cast<void *>(comm->flags));
cudaFree(reinterpret_cast<void *>(comm->recv_id));
cudaFree(reinterpret_cast<void *>(comm->send_id));
if (comm->use_mc) {
cuMemAddressFree(reinterpret_cast<CUdeviceptr>(comm->mc_baseptr), comm->mc_maxsize);
cuMemRelease(comm->mc_handle);
}
if (comm->mem_dealloc[0]) {
cudaFree(comm->gpu_ptrs);
}
free(comm->fifo);
free(comm);
}
void destroy_communicator_mpi(communicator *comm) {
#ifdef UB_MPI_BOOTSTRAP
MPI_Comm_free(&comm->comm_inter);
MPI_Comm_free(&comm->comm_intra);
destroy_communicator(comm);
#else
NVTE_UB_ERROR(std::string("Communicator is not bootstrapped with MPI and ") +
std::string("can only be deallocated with destroy_communicator()."));
#endif
}
int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc) {
......@@ -379,6 +472,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
comm->peer_ptr[hndl] = reinterpret_cast<void **>(malloc(sizeof(void *) * (comm->nvsize)));
size_t aligned_size = bytes;
comm->memflags[hndl] = 0;
comm->mem_dealloc[hndl] = alloc;
if (alloc) {
int nranks = comm->nvsize; // total GPUs in NVLINK domain
......@@ -420,9 +514,14 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
uint64_t opId = 0xdeadcafebeef;
ncclResult_t ret = ncclSuccess;
// All-gather POSIX file descriptors across local ranks.
// NOTE: This cannot be done via MPI_Allgather or other external comm libraries. They mangle
// the file descriptor and prevent cuMemImportFromShareableHandle() from correctly
// interpreting the file. Instead, we use system socket to send/recv the file handle
// without mangling.
NCCLCHECK(ncclIpcSocketInit(&ipcSock, myrank, (uint64_t)opId, &abortFlag));
for (int p = 1; p < nranks; p++) {
MPI_Barrier(comm->comm_intra);
comm->_barrier(comm->comm_intra);
NCCLCHECKGOTO(
ncclIpcSocketSendFd(&ipcSock, peerfd[myrank], (myrank + p) % nranks, (uint64_t)opId), ret,
error);
......@@ -482,18 +581,20 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
} else {
assert(comm->nvsize <= 8);
cudaIpcMemHandle_t *memhndl =
reinterpret_cast<cudaIpcMemHandle_t *>(malloc(sizeof(cudaIpcMemHandle_t) * (comm->nvsize)));
CUDACHECK(cudaIpcGetMemHandle(&memhndl[comm->nvrank], *gpubuff));
MPI_Allgather(&memhndl[comm->nvrank], sizeof(cudaIpcMemHandle_t), MPI_BYTE, memhndl,
sizeof(cudaIpcMemHandle_t), MPI_BYTE, comm->comm_intra);
for (int i = 0; i < comm->nvsize; i++)
if (i != comm->nvrank)
CUDACHECK(cudaIpcOpenMemHandle((void **)&(comm->peer_ptr[hndl][i]), // NOLINT(*)
memhndl[i], cudaIpcMemLazyEnablePeerAccess));
cudaIpcMemHandle_t memhndl;
CUDACHECK(cudaIpcGetMemHandle(&memhndl, *gpubuff));
cudaIpcMemHandle_t *tmp;
comm->_alloc_copy_allgather(
reinterpret_cast<void **>(&tmp), reinterpret_cast<void *>(&memhndl),
sizeof(cudaIpcMemHandle_t), comm->comm_intra);
for (int i = 0; i < comm->nvsize; i++) {
if (i != comm->nvrank) {
CUDACHECK(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i], // NOLINT(*)
cudaIpcMemLazyEnablePeerAccess));
}
}
comm->peer_ptr[hndl][comm->nvrank] = *gpubuff;
CUDACHECK(cudaDeviceSynchronize());
......@@ -502,7 +603,7 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
comm->peer_ptr[hndl], comm->nvsize * sizeof(void *), cudaMemcpyHostToDevice));
CUDACHECK(cudaDeviceSynchronize());
free(memhndl);
comm->_free(tmp);
}
comm->mem_size[hndl] = aligned_size;
......
......@@ -8,11 +8,56 @@
#define TRANSFORMER_ENGINE_USERBUFFERS_H_
#include <cuda.h>
#include <mpi.h> // TODO (tym): Removing will remove PyT extension dependence on MPI
#include "cuda_runtime.h"
#include <cuda_runtime.h>
#include <pthread.h>
#include <chrono>
#include <stdexcept>
#include <functional>
#ifdef UB_MPI_BOOTSTRAP
#include <cstdio>
#include <vector>
#include <mpi.h>
#define UB_MPI_CHECK(expr) \
do { \
const int mpicode = (expr); \
if (mpicode != MPI_SUCCESS) { \
char mpimsg[MPI_MAX_ERROR_STRING]; \
int mpilen; \
MPI_Error_string(mpicode, mpimsg, &mpilen); \
std::vector<char> errmsg(1024); \
snprintf(errmsg.data(), errmsg.size(), "%s:%d in function %s: %s", \
__FILE__, __LINE__, __func__, mpimsg); \
throw std::runtime_error(errmsg.data()); \
} \
} while (false)
typedef MPI_Comm ExtComm;
void ub_alloc_copy_allgather(void **globaldata, void *localdata, size_t localbytes, ExtComm comm) {
int nranks;
UB_MPI_CHECK(MPI_Comm_size(comm, &nranks));
*globaldata = malloc(nranks * localbytes);
// NOTE: MPI_Allgather takes the per-rank receive count, not the total buffer size.
UB_MPI_CHECK(MPI_Allgather(localdata, localbytes, MPI_BYTE,
*globaldata, localbytes, MPI_BYTE, comm));
}
void ub_barrier(ExtComm comm) {
UB_MPI_CHECK(MPI_Barrier(comm));
}
void ub_free(void *ptr) {
free(ptr);
}
#else
typedef char* ExtComm;
#endif
#define NVTE_MAX_REGIONS 16
#define NVTE_MAX_SMS 32
......@@ -99,6 +144,7 @@ struct communicator {
CUmemGenericAllocationHandle *uchandles[NVTE_MAX_REGIONS];
void* ucbase_ptr[NVTE_MAX_REGIONS]; // only for cuMem allocated memory
size_t mem_size[NVTE_MAX_REGIONS];
bool mem_dealloc[NVTE_MAX_REGIONS];
void* mc_ptr[NVTE_MAX_REGIONS];
void* mc_baseptr;
......@@ -130,9 +176,18 @@ struct communicator {
int padding2[15];
volatile int tail;
MPI_Request mpihndl[NVTE_MAX_SHARP];
MPI_Comm comm_inter, // reduction group communicator (subset of the nodes) along GPU rail
// Abstract communication callbacks to support external bootstrapping (e.g. DL frameworks)
std::function<void(void**, void*, size_t, ExtComm)> _alloc_copy_allgather;
std::function<void(ExtComm)> _barrier;
std::function<void(void*)> _free;
ExtComm comm_world,
comm_inter, // reduction group communicator (subset of the nodes) along GPU rail
comm_intra; // full intranode (all ndev GPUS)
#ifdef UB_MPI_BOOTSTRAP
MPI_Request mpihndl[NVTE_MAX_SHARP];
#endif
int *send_id, *recv_id;
int mydev;
uint64_t ub_timeout;
......@@ -142,18 +197,38 @@ typedef struct communicator communicator;
void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream);
void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream);
void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStream_t stream);
int create_communicator(communicator **comm);
/* creates communicator, allocates all internal buffers if necessary */
int create_communicator_grouped2(communicator **comm,
int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes,
std::function<void(void**, void*, size_t, ExtComm)> ext_alloc_copy_allgather,
std::function<void(ExtComm)> ext_barrier,
std::function<void(void*)> ext_free,
int pipegpus, int pipenodes, int tensorgpus, int tensornodes);
int create_communicator_grouped(communicator **comm,
int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes,
std::function<void(void**, void*, size_t, ExtComm)> ext_alloc_copy_allgather,
std::function<void(ExtComm)> ext_barrier,
std::function<void(void*)> ext_free,
int pipegpus, int pipenodes);
int create_communicator(communicator **comm,
int myrank, int numranks, int mylocal, int numlocal, int mynode, int numnodes,
std::function<void(void**, void*, size_t, ExtComm)> ext_alloc_copy_allgather,
std::function<void(ExtComm)> ext_barrier,
std::function<void(void*)> ext_free);
int create_communicator_grouped2_mpi(communicator **comm,
int pipegpus, int pipenodes, int tensorgpus, int tensornodes);
int create_communicator_grouped_mpi(communicator **comm, int pipegpus, int pipenodes);
int create_communicator_mpi(communicator **comm);
void destroy_communicator(communicator *comm);
int create_communicator_grouped(communicator **comm, int pipegpus, int pipenodes);
int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenodes, int tensorgpus,
int tensornodes);
/* creates communicator with
allreduce1 to happen in datagpus x datanodes groups,
allreduce2 to happen in tensorgpus x tensor nodes,
where num_nodes = pipenodes x tensornodes x datanodes
nvlink_size = pipegpus x tensorgpus x datagpus
*/
void destroy_communicator_mpi(communicator *comm);
// int check_user_buffer_registration(void* gpubuff, int bytes, communicator* comm, size_t* offset);
/*
......@@ -167,8 +242,7 @@ int pipe_rank(communicator *comm,
// data-parallel and tensor-parallel position within data and tensor
// groups would be preserved
int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm,
bool alloc = false);
int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc);
/* returns handler and registers buffers. assumed to be collective i.e. you use same groups and
dont mix buffers for different operations returns -1 if cant register (too many preregistered
regions already) if alloc==true will allocate memory and fill the pointers (required for NVL
......
......@@ -8,3 +8,4 @@ from .linear import Linear
from .layernorm_mlp import LayerNormMLP
from .layernorm import LayerNorm
from .rmsnorm import RMSNorm
from .base import initialize_ub, destroy_ub
......@@ -36,6 +36,8 @@ from ..cpp_extensions import (
from ..constants import dist_group_type
from ..float8_tensor import Float8Tensor
__all__ = ["initialize_ub", "destroy_ub"]
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True
......@@ -64,7 +66,7 @@ def get_workspace() -> torch.Tensor:
def initialize_ub(
shape: list,
tp_size: int,
tp_group: dist_group_type,
use_fp8: bool = False,
dtype: torch.dtype = torch.bfloat16,
ub_cfgs: Optional[dict] = None
......@@ -74,6 +76,9 @@ def initialize_ub(
assert _ub_communicators is None, "UB communicators are already initialized."
_ub_communicators = {}
rank_id = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
tp_id = torch.distributed.get_rank(tp_group)
tp_size = torch.distributed.get_world_size(tp_group)
# Increase the workspace by the number of maximum concurrent streams
global _cublas_workspace
......@@ -158,6 +163,8 @@ def initialize_ub(
ub_obj = tex.UbufP2PCommOverlap(
sample_buffer, # Sample userbuffer
rank_id, # Rank id
world_size, # World size
tp_id, # TP id
tp_size, # TP size
num_sm, # Number of communication SMs
cga_size, # CGA cluster size
......@@ -172,6 +179,8 @@ def initialize_ub(
ub_obj = tex.UbufCommOverlap(
sample_buffer, # Sample userbuffer
rank_id, # Rank id
world_size, # World size
tp_id, # TP id
tp_size, # TP size
num_sm, # Number of communication SMs
cga_size, # CGA cluster size
......@@ -183,6 +192,26 @@ def initialize_ub(
)
_ub_communicators[name] = ub_obj
def alloc_copy_allgather_callback(local_data: torch.Tensor, group: str) -> torch.Tensor:
pg = None if group == "world" else tp_group
global_size = local_data.numel() * torch.distributed.get_world_size(pg)
global_data = torch.zeros(global_size, dtype=local_data.dtype, device='cuda')
torch.distributed.all_gather_into_tensor(global_data, local_data.cuda(), group=pg)
return global_data.cpu()
def barrier_callback(group: str) -> None:
pg = None if group == "world" else tp_group
torch.distributed.barrier(group=pg)
def free_callback(data: torch.Tensor) -> None:
data.data = torch.Tensor()
tex.set_ubuf_bootstrap_callbacks(
alloc_copy_allgather_callback,
barrier_callback,
free_callback
)
if ub_cfgs is not None:
for name in dgrad_reduce_scatter_overlap:
if name in ub_cfgs and 'method' in ub_cfgs[name] and ub_cfgs[name]['method'] != 'bulk':
......@@ -235,6 +264,13 @@ def get_ub(name: str):
assert name in _ub_communicators, f"UB for {name} is not registered."
return _ub_communicators[name]
def destroy_ub():
"""Destroy all allocated userbuffer communicators."""
global _ub_communicators
_ub_communicators = None
global layers_atomic_ring_exchange
layers_atomic_ring_exchange = []
class TransformerEngineBaseModule(torch.nn.Module, ABC):
"""Base TE module."""
......
......@@ -865,11 +865,6 @@ class LayerNormLinear(TransformerEngineBaseModule):
assert ub_name is not None, "Userbuffer name [string] is not set."
self.ub_name = ub_name
if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_ag]):
assert (
tex.userbuf_comm_available()
), "Userbuffer communication backend not available."
if tp_group is None:
self.tp_size = tp_size
if tp_size == 1:
......
......@@ -1290,11 +1290,6 @@ class LayerNormMLP(TransformerEngineBaseModule):
(bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) and
self.activation == 'gelu' and not get_ub("fc1_fprop").is_atomic_gemm())
if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_rs, ub_overlap_ag, ub_overlap_rs_dgrad]):
assert (
tex.userbuf_comm_available()
), "Userbuffer communication backend not available."
if tp_group is None:
self.tp_size = tp_size
if tp_size == 1:
......