Unverified commit c3407300 authored by Kirthi Shankar Sivamani, committed by GitHub

Move userbuffer to PyTorch (#162)



* Initial refactor; linker error
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix linking issue and make mpi conditional
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix TF/JAX build
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Use max SMs at the last RS chunk in pipelined overlap
Co-authored-by: Sangkug Lym <slym@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Make userbuffers support opt-in

Decouple userbuffers from MPI. Refactor MPI handling in build system. Standardize names to "userbuffers".
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Lint
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent 4a1efe89
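Build note: userbuffers support becomes an opt-in feature controlled by the NVTE_WITH_USERBUFFERS and MPI_HOME environment variables introduced in the diffs below. A minimal sketch of how a build script might opt in, assuming a standard pip install and an example MPI location (not part of this commit):

# Illustrative only: enable the optional userbuffers component when building.
# NVTE_WITH_USERBUFFERS and MPI_HOME are the variables added in this commit;
# "/usr/local/mpi" is just an example MPI installation path.
import os
import subprocess

env = dict(os.environ)
env["NVTE_WITH_USERBUFFERS"] = "1"   # default is "0", i.e. userbuffers disabled
env["MPI_HOME"] = "/usr/local/mpi"   # required whenever NVTE_WITH_USERBUFFERS=1
subprocess.run(["pip", "install", "."], env=env, check=True)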
@@ -9,11 +9,7 @@ set -e
 TE_LIB_PATH=`pip show transformer-engine | grep Location | cut -d ' ' -f 2`
 export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH

-# Find MPI
-MPI_HOME=${MPI_HOME:-/usr/local/mpi}
-NVTE_MPI_INCLUDE="$MPI_HOME/lib"
-
 cd $TE_PATH/tests/cpp
-cmake -GNinja -Bbuild -DNVTE_MPI_INCLUDE=$NVTE_MPI_INCLUDE .
+cmake -GNinja -Bbuild .
 cmake --build build
 ctest --test-dir build -j4
@@ -21,9 +21,10 @@ with open(path + "/VERSION", "r") as f:
     te_version = f.readline()

 CUDA_HOME = os.environ.get("CUDA_HOME", "/usr/local/cuda")
-MPI_HOME = os.environ.get("MPI_HOME", "/usr/local/mpi")
-NVTE_MPI_FOUND = os.path.exists(MPI_HOME)
-NVTE_MPI_INCLUDE = os.path.join(MPI_HOME, "include")
+NVTE_WITH_USERBUFFERS = int(os.environ.get("NVTE_WITH_USERBUFFERS", "0"))
+if NVTE_WITH_USERBUFFERS:
+    MPI_HOME = os.environ.get("MPI_HOME", "")
+    assert MPI_HOME, "MPI_HOME must be set if NVTE_WITH_USERBUFFERS=1"

 def get_cuda_bare_metal_version(cuda_dir):
     raw_output = subprocess.check_output(
@@ -70,8 +71,8 @@ def extra_compiler_flags():
         "--expt-extended-lambda",
         "--use_fast_math",
     ]
-    if NVTE_MPI_FOUND:
-        extra_flags.append("-DNVTE_MPI_FOUND")
+    if NVTE_WITH_USERBUFFERS:
+        extra_flags.append("-DNVTE_WITH_USERBUFFERS")
     return extra_flags
@@ -105,8 +106,9 @@ include_dirs = [
     "transformer_engine/common/include",
     "transformer_engine/pytorch/csrc",
 ]
-if (framework in ("all", "pytorch")) and NVTE_MPI_FOUND:
-    include_dirs.append(NVTE_MPI_INCLUDE)
+if NVTE_WITH_USERBUFFERS:
+    if MPI_HOME:
+        include_dirs.append(os.path.join(MPI_HOME, "include"))
 include_dirs = make_abs_path(include_dirs)

 args = sys.argv.copy()
@@ -165,9 +167,7 @@ class PyTorchBuilder(FrameworkBuilderBase):
         self.pytorch_build_extensions.run()

     def cmake_flags(self):
-        if not NVTE_MPI_FOUND:
-            return []
-        return ["-DNVTE_MPI_FOUND=1", f"-DNVTE_MPI_INCLUDE={NVTE_MPI_INCLUDE}"]
+        return []

     @staticmethod
     def install_requires():
@@ -338,6 +338,8 @@ class TEBuildExtension(build_ext, object):
             self.dlfw_builder.append(functor(*args, **kwargs))

         flags = []
+        if NVTE_WITH_USERBUFFERS:
+            flags.append('-DNVTE_WITH_USERBUFFERS=ON')
         for builder in self.dlfw_builder:
             flags = flags + builder.cmake_flags()
......
@@ -28,11 +28,6 @@ endif()
 find_library(TE_LIB NAMES transformer_engine PATHS ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED)

-if(EXISTS ${NVTE_MPI_INCLUDE})
-  find_library(MPI_LIB NAMES mpi PATHS ${NVTE_MPI_INCLUDE} REQUIRED)
-  message(STATUS "Found MPI library: ${MPI_LIB}")
-endif()
-
 message(STATUS "Found transformer_engine library: ${TE_LIB}")

 include_directories(../../transformer_engine/common/include)
 include_directories(${CMAKE_SOURCE_DIR})
......
@@ -19,10 +19,6 @@ add_executable(test_operator
 list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB})

-if(EXISTS ${NVTE_MPI_INCLUDE})
-  list(APPEND test_operator_LINKER_LIBS ${MPI_LIB})
-endif()
-
 target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS})
 target_compile_options(test_operator PRIVATE -O2)
......
@@ -8,7 +8,6 @@ if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
   set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
 endif()

-set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CUDA_STANDARD 17)
 set(CMAKE_CUDA_STANDARD_REQUIRED ON)
@@ -26,6 +25,9 @@ find_package(Python COMPONENTS Interpreter Development REQUIRED)
 include_directories(${PROJECT_SOURCE_DIR})

 add_subdirectory(common)
+if(NVTE_WITH_USERBUFFERS)
+  add_subdirectory(pytorch/csrc/userbuffers)
+endif()

 option(ENABLE_JAX "Enable JAX in the building workflow." OFF)
 if(ENABLE_JAX)
......
@@ -2,8 +2,10 @@
 #
 # See LICENSE for license information.

+# Configure Transformer Engine library
 set(transformer_engine_SOURCES)
-list(APPEND transformer_engine_SOURCES transformer_engine.cpp
+list(APPEND transformer_engine_SOURCES
+            transformer_engine.cpp
             transpose/cast_transpose.cu
             transpose/transpose.cu
             transpose/cast_transpose_fusion.cu
@@ -20,36 +22,22 @@ list(APPEND transformer_engine_SOURCES transformer_engine.cpp
             util/cast.cu
             fused_softmax/scaled_masked_softmax.cu
             fused_softmax/scaled_upper_triang_masked_softmax.cu)
-if(NVTE_MPI_FOUND)
-  list(APPEND transformer_engine_SOURCES comm_gemm_overlap/userbuffers.cu
-              comm_gemm_overlap/userbuffers-host.cpp)
-endif()
 add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
-target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include")
+target_include_directories(transformer_engine PUBLIC
+                           "${CMAKE_CURRENT_SOURCE_DIR}/include")

-list(APPEND transformer_engine_LINKER_LIBS CUDA::cublas CUDA::cudart CUDA::nvToolsExt)
-if(NVTE_MPI_FOUND)
-  list(APPEND transformer_engine_LINKER_LIBS gdrapi)
-endif()
-
-target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS})
-target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
+# Configure dependencies
+target_link_libraries(transformer_engine PUBLIC
+                      CUDA::cublas
+                      CUDA::cudart
+                      CUDA::nvToolsExt)
+target_include_directories(transformer_engine PRIVATE
+                           ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})

+# Compiler options
 set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
                             fused_softmax/scaled_upper_triang_masked_softmax.cu
                             PROPERTIES
                             COMPILE_OPTIONS "--use_fast_math")
-if(NVTE_MPI_FOUND)
-  set_source_files_properties(comm_gemm_overlap/userbuffers.cu
-                              comm_gemm_overlap/userbuffers-host.cpp
-                              PROPERTIES
-                              INCLUDE_DIRECTORIES ${NVTE_MPI_INCLUDE}
-                              COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:-maxrregcount=64>")
-endif()

 set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
 set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
@@ -37,8 +37,8 @@ def _load_library():
     return ctypes.CDLL(dll_path, mode=ctypes.RTLD_GLOBAL)

-def _load_mpi():
-    """Load MPI shared library"""
+def _load_userbuffers():
+    """Load shared library with userbuffers"""
     system = platform.system()
     if system == "Linux":
@@ -49,15 +49,14 @@ def _load_mpi():
         extension = "dll"
     else:
         raise RuntimeError(f"Unsupported operating system ({system})")
-    lib_name = "libmpi." + extension
-    MPI_HOME = os.environ.get("MPI_HOME", "/usr/local/mpi")
-    NVTE_MPI_FOUND = os.path.exists(MPI_HOME)
-    dll_path = os.path.join(MPI_HOME, "lib", lib_name)
-    if NVTE_MPI_FOUND:
+    lib_name = "libtransformer_engine_userbuffers." + extension
+    dll_path = get_te_path()
+    dll_path = os.path.join(dll_path, lib_name)
+    if os.path.exists(dll_path):
         return ctypes.CDLL(dll_path, mode=ctypes.RTLD_GLOBAL)
     return None

-_TE_LIB_CTYPES = _load_mpi()
 _TE_LIB_CTYPES = _load_library()
+_UB_LIB_CTYPES = _load_userbuffers()
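The new loader only returns a handle when the optional library was actually built into the package directory. A standalone sketch of the same check, mirroring the logic above (the helper name and the Darwin extension are my own, not part of this commit):

# Sketch of the availability check performed by _load_userbuffers().
import ctypes
import os
import platform

def userbuffers_lib_available(te_path: str) -> bool:
    ext = {"Linux": "so", "Darwin": "dylib", "Windows": "dll"}.get(platform.system())
    if ext is None:
        return False
    lib = os.path.join(te_path, "libtransformer_engine_userbuffers." + ext)
    if not os.path.exists(lib):
        return False
    ctypes.CDLL(lib, mode=ctypes.RTLD_GLOBAL)  # keep symbols globally visible
    return True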
@@ -14,9 +14,10 @@
 #include <torch/custom_class.h>
 #include <torch/extension.h>
 #include <torch/types.h>
-#include <transformer_engine/userbuffers.h>
+#include "userbuffers/userbuffers.h"

 #define HALF_BYTES 2
+#define UB_MAX_SM 32

 #define CHECK_CUDA(call) \
   do { \
@@ -174,6 +175,7 @@ struct UbufCommOverlap : torch::CustomClassHolder {
     char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
     int ubuf_offset = 0;
+    int ori_sms = _ub_comm->sms;

     // Catch up the default torch stream
     at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
@@ -232,7 +234,8 @@ struct UbufCommOverlap : torch::CustomClassHolder {
         cudaEventRecord(_start_comm, (cudaStream_t)_stream_compute[last_compute_stream_id]));
     CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));

-    // Communication chunk
+    // Last communication chunk with max SM
+    _ub_comm->sms = UB_MAX_SM;
     reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg,
                                           (_num_splits - 1) * output_chunk_size, m_chunk, n, m,
                                           _ub_comm, (cudaStream_t)_stream_comm);
@@ -255,7 +258,10 @@ struct UbufCommOverlap : torch::CustomClassHolder {
                          (cudaStream_t)_stream_compute[i % _stream_compute.size()]));
       CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));

-      // Communication chunk
+      // Communication chunk. Uses MAX_SM at the last chunk
+      if (i == _num_splits-1) {
+        _ub_comm->sms = UB_MAX_SM;
+      }
       reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size,
                                             m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm);
@@ -264,6 +270,7 @@ struct UbufCommOverlap : torch::CustomClassHolder {
       output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size();
     }
   }
+  _ub_comm->sms = ori_sms;
   int last_compute_stream_id =
       (_num_splits + _stream_compute.size() - 1) % _stream_compute.size();
   CHECK_CUDA(
......
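The scheduling idea behind the UB_MAX_SM change above: earlier reduce-scatter chunks share SMs with the GEMM they overlap, but the final chunk has no compute left to overlap with, so it can use the maximum SM count, and the communicator's original setting is restored afterwards. A placeholder sketch of that pattern (the names below are illustrative stand-ins, not the extension's API):

# Placeholder sketch of "max SMs for the last communication chunk".
UB_MAX_SM = 32          # cap used for the final chunk in this commit
ori_sms = 16            # stand-in for the communicator's configured SM count
num_splits = 4

def launch_rs_chunk(chunk_id: int, sms: int) -> None:
    """Hypothetical stand-in for reducescatter2_userbuff_stridedoutput."""
    print(f"chunk {chunk_id}: reduce-scatter with {sms} SMs")

sms = ori_sms
for i in range(num_splits):
    if i == num_splits - 1:     # last chunk: nothing left to overlap with
        sms = UB_MAX_SM
    launch_rs_chunk(i, sms)
sms = ori_sms                   # restore the original setting afterwards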
@@ -5,9 +5,9 @@
  ************************************************************************/

 #include "extensions.h"
-#ifdef NVTE_MPI_FOUND
+#ifdef NVTE_WITH_USERBUFFERS
 #include "comm_gemm_overlap.h"
-#endif  // NVTE_MPI_FOUND
+#endif  // NVTE_WITH_USERBUFFERS

 void te_gemm(at::Tensor A,
              at::Tensor A_scale_inverse,
@@ -1022,7 +1022,7 @@ size_t get_cublasLt_version() {

 bool userbuf_comm_available() {  // TODO(ksivamani) check on python side
-#ifdef NVTE_MPI_FOUND
+#ifdef NVTE_WITH_USERBUFFERS
   return true;
 #else
   return false;
@@ -1080,7 +1080,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     .def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv)
     .def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history);

-#ifdef NVTE_MPI_FOUND
+#ifdef NVTE_WITH_USERBUFFERS
   py::enum_<ubuf::UBOverlapAlgo>(m, "UbufOverlapAlgo")
     .value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG)
     .value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS)
@@ -1099,11 +1099,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     .def("split_overlap_ag", &ubuf::UbufP2PCommOverlap::split_overlap_ag)
     .def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf)
     .def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output);
-#else  // NVTE_MPI_FOUND
+#else  // NVTE_WITH_USERBUFFERS
   m.def("UbufOverlapAlgo", &placeholder, "Dummy function for python side annotations");
   m.def("UbufCommOverlap", &placeholder, "Dummy function for python side annotations");
   m.def("UbufP2PCommOverlap", &placeholder, "Dummy function for python side annotations");
-#endif  // NVTE_MPI_FOUND
+#endif  // NVTE_WITH_USERBUFFERS

   py::enum_<transformer_engine::DType>(m, "DType", py::module_local())
     .value("kByte", transformer_engine::DType::kByte)
......
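On the Python side, callers can gate on whether the extension was built with userbuffers before touching the Ubuf* classes. A hedged usage sketch, assuming userbuf_comm_available() shown above is exposed on the PyTorch extension module and that the module import name matches the setup script of this era (both are assumptions, not confirmed by this diff):

# Assumption: the pybind module is importable as transformer_engine_extensions
# and exposes userbuf_comm_available(); both follow the C++ code above.
import transformer_engine_extensions as tex

if tex.userbuf_comm_available():
    print("Built with NVTE_WITH_USERBUFFERS; Ubuf* overlap classes are usable")
else:
    print("Userbuffers not compiled in; Ubuf* bindings are placeholder stubs")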
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Configure userbuffers library
add_library(transformer_engine_userbuffers SHARED
userbuffers.cu
userbuffers-host.cpp)
target_include_directories(transformer_engine_userbuffers PUBLIC
"${CMAKE_CURRENT_SOURCE_DIR}")
# Configure dependencies
find_package(MPI REQUIRED)
find_library(GDRCOPY_LIBRARY gdrapi
HINTS "${GDRCOPY_LIBRARY_DIR}" "$ENV{GDRCOPY_LIBRARY_DIR}")
if(NOT GDRCOPY_LIBRARY)
message(FATAL_ERROR "Could not find GDRCopy, please set GDRCOPY_LIBRARY_DIR")
endif()
message(STATUS "Found GDRCopy: ${GDRCOPY_LIBRARY}")
target_link_libraries(transformer_engine_userbuffers PUBLIC
CUDA::cudart
MPI::MPI_CXX
${GDRCOPY_LIBRARY})
target_include_directories(transformer_engine_userbuffers PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
# Compiler options
set_source_files_properties(userbuffers.cu
userbuffers-host.cpp
PROPERTIES
COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:-maxrregcount=64>")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
@@ -13,12 +13,11 @@
 #include <sched.h>
 #include <stdio.h>
 #include <string.h>
-#include <transformer_engine/userbuffers.h>
-#include <transformer_engine/logging.h>
 #include <unistd.h>
 #include <x86intrin.h>
 #include <chrono>
 #include <iostream>
+#include "userbuffers.h"

 static int oob_bcast(void *comm_context, void *buf, int size, int root) {
   MPI_Bcast(buf, size, MPI_BYTE, root,
@@ -48,6 +47,12 @@ int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (co
   } \
   } while (0)

+#define NVTE_UB_ERROR(x) \
+  do { \
+    throw std::runtime_error(std::string(__FILE__ ":") + std::to_string(__LINE__) + \
+                             " in function " + __func__ + ": " + x); \
+  } while (false)
+
 int pipe_rank(communicator *comm, int step) {
   int mynode = comm->myrank / comm->nvsize;
   int mylocal = comm->nvrank;
@@ -347,7 +352,7 @@ int allgather2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons

 void allreduce_nonsharp_inplace(const int handler, const int offset, const int elements,
                                 communicator *comm, cudaStream_t stream, int op) {
-  if (elements < 64) NVTE_ERROR("Userbuffer comm for given config not implemented.");
+  if (elements < 64) NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
   // if(comm->myrank==0) fprintf(stderr,"AR2(%d) user call launch_mode=%d\n",op,comm->launch_mode);
   const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
   int blocksize = elements * 2;
@@ -394,7 +399,7 @@ void allreduce2_userbuff_inplace(const int handler, const int offset, const int

 void allreduce_userbuff_inplace(const int handler, const int offset, const int elements,
                                 communicator *comm, cudaStream_t stream) {
-  if (elements < 64) NVTE_ERROR("Userbuffer comm for given config not implemented.");
+  if (elements < 64) NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
   allreduce_nonsharp_inplace(handler, offset, elements, comm, stream,
                              userbuffers_allreduceop_nonsharp);
   return;
@@ -402,7 +407,7 @@ void allreduce_userbuff_inplace(const int handler, const int offset, const int e

 void reducescatter_userbuff_inplace(const int handler, const int offset, const int elements,
                                     communicator *comm, cudaStream_t stream) {
-  if (elements < 64) NVTE_ERROR("Userbuffer comm for given config not implemented.");
+  if (elements < 64) NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
   int op = userbuffers_allreduceop_nonsharp;
   const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
@@ -443,7 +448,7 @@ void reducescatter_userbuff_inplace(const int handler, const int offset, const i

 void allgather_userbuff_inplace(const int handler, const int offset, const int elements,
                                 communicator *comm, cudaStream_t stream) {
-  if (elements < 64) NVTE_ERROR("Userbuffer comm for given config not implemented.");
+  if (elements < 64) NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
   int op = userbuffers_allreduceop_nonsharp;
   const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
   int blocksize = elements * 2;
......
@@ -14,7 +14,7 @@
 #endif
 #include <assert.h>
 #include <stdio.h>
-#include <transformer_engine/userbuffers.h>
+#include "userbuffers.h"

 #define MAX_THREADS 1024
 #define TIMEOUT 200000000000ull
......
@@ -8,7 +8,7 @@
 #define TRANSFORMER_ENGINE_USERBUFFERS_H_

 #include <cuda.h>
-#include <mpi.h>
+#include <mpi.h>  // TODO (tym): Removing will remove PyT extension dependence on MPI
 #include "cuda_runtime.h"
 #include <pthread.h>
 #include <chrono>
......