Unverified commit c3407300, authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

Move userbuffer to PyTorch (#162)



* Initial refactor; linker error
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix linking issue and make mpi conditional
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix TF/JAX build
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Use max SMs at the last RS chunk in pipelined overlap
Co-authored-by: Sangkug Lym <slym@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Make userbuffers support opt-in

Decouple userbuffers from MPI. Refactor MPI handling in build system. Standardize names to "userbuffers".
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Lint
Signed-off-by: Tim Moon <tmoon@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
parent 4a1efe89
......@@ -9,11 +9,7 @@ set -e
TE_LIB_PATH=`pip show transformer-engine | grep Location | cut -d ' ' -f 2`
export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH
# Find MPI
MPI_HOME=${MPI_HOME:-/usr/local/mpi}
NVTE_MPI_INCLUDE="$MPI_HOME/lib"
cd $TE_PATH/tests/cpp
cmake -GNinja -Bbuild -DNVTE_MPI_INCLUDE=$NVTE_MPI_INCLUDE .
cmake -GNinja -Bbuild .
cmake --build build
ctest --test-dir build -j4
......@@ -21,9 +21,10 @@ with open(path + "/VERSION", "r") as f:
te_version = f.readline()
CUDA_HOME = os.environ.get("CUDA_HOME", "/usr/local/cuda")
MPI_HOME = os.environ.get("MPI_HOME", "/usr/local/mpi")
NVTE_MPI_FOUND = os.path.exists(MPI_HOME)
NVTE_MPI_INCLUDE = os.path.join(MPI_HOME, "include")
NVTE_WITH_USERBUFFERS = int(os.environ.get("NVTE_WITH_USERBUFFERS", "0"))
if NVTE_WITH_USERBUFFERS:
MPI_HOME = os.environ.get("MPI_HOME", "")
assert MPI_HOME, "MPI_HOME must be set if NVTE_WITH_USERBUFFERS=1"
def get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output(
......@@ -70,8 +71,8 @@ def extra_compiler_flags():
"--expt-extended-lambda",
"--use_fast_math",
]
if NVTE_MPI_FOUND:
extra_flags.append("-DNVTE_MPI_FOUND")
if NVTE_WITH_USERBUFFERS:
extra_flags.append("-DNVTE_WITH_USERBUFFERS")
return extra_flags
......@@ -105,8 +106,9 @@ include_dirs = [
"transformer_engine/common/include",
"transformer_engine/pytorch/csrc",
]
if (framework in ("all", "pytorch")) and NVTE_MPI_FOUND:
include_dirs.append(NVTE_MPI_INCLUDE)
if NVTE_WITH_USERBUFFERS:
if MPI_HOME:
include_dirs.append(os.path.join(MPI_HOME, "include"))
include_dirs = make_abs_path(include_dirs)
args = sys.argv.copy()
......@@ -165,9 +167,7 @@ class PyTorchBuilder(FrameworkBuilderBase):
self.pytorch_build_extensions.run()
def cmake_flags(self):
if not NVTE_MPI_FOUND:
return []
return ["-DNVTE_MPI_FOUND=1", f"-DNVTE_MPI_INCLUDE={NVTE_MPI_INCLUDE}"]
return []
@staticmethod
def install_requires():
......@@ -338,6 +338,8 @@ class TEBuildExtension(build_ext, object):
self.dlfw_builder.append(functor(*args, **kwargs))
flags = []
if NVTE_WITH_USERBUFFERS:
flags.append('-DNVTE_WITH_USERBUFFERS=ON')
for builder in self.dlfw_builder:
flags = flags + builder.cmake_flags()
......
......@@ -19,7 +19,7 @@ add_subdirectory(../../3rdparty/googletest ${PROJECT_BINARY_DIR}/googletest)
enable_testing()
include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
if(NOT DEFINED TE_LIB_PATH)
execute_process(COMMAND bash -c "pip show transformer-engine | grep Location | cut -d ' ' -f 2 | tr -d '\n'"
......@@ -28,11 +28,6 @@ endif()
find_library(TE_LIB NAMES transformer_engine PATHS ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED)
if(EXISTS ${NVTE_MPI_INCLUDE})
find_library(MPI_LIB NAMES mpi PATHS ${NVTE_MPI_INCLUDE} REQUIRED)
message(STATUS "Found MPI library: ${MPI_LIB}")
endif()
message(STATUS "Found transformer_engine library: ${TE_LIB}")
include_directories(../../transformer_engine/common/include)
include_directories(${CMAKE_SOURCE_DIR})
......
......@@ -19,10 +19,6 @@ add_executable(test_operator
list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB})
if(EXISTS ${NVTE_MPI_INCLUDE})
list(APPEND test_operator_LINKER_LIBS ${MPI_LIB})
endif()
target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS})
target_compile_options(test_operator PRIVATE -O2)
......
......@@ -8,7 +8,6 @@ if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90)
endif()
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
......@@ -26,6 +25,9 @@ find_package(Python COMPONENTS Interpreter Development REQUIRED)
include_directories(${PROJECT_SOURCE_DIR})
add_subdirectory(common)
if(NVTE_WITH_USERBUFFERS)
add_subdirectory(pytorch/csrc/userbuffers)
endif()
option(ENABLE_JAX "Enable JAX in the building workflow." OFF)
if(ENABLE_JAX)
......
......@@ -2,54 +2,42 @@
#
# See LICENSE for license information.
# Configure Transformer Engine library
set(transformer_engine_SOURCES)
list(APPEND transformer_engine_SOURCES transformer_engine.cpp
transpose/cast_transpose.cu
transpose/transpose.cu
transpose/cast_transpose_fusion.cu
transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu
activation/gelu.cu
gemm/cublaslt_gemm.cu
layer_norm/ln_api.cpp
layer_norm/ln_bwd_semi_cuda_kernel.cu
layer_norm/ln_fwd_cuda_kernel.cu
rmsnorm/rmsnorm_api.cpp
rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
rmsnorm/rmsnorm_fwd_cuda_kernel.cu
util/cast.cu
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu)
if(NVTE_MPI_FOUND)
list(APPEND transformer_engine_SOURCES comm_gemm_overlap/userbuffers.cu
comm_gemm_overlap/userbuffers-host.cpp)
endif()
list(APPEND transformer_engine_SOURCES
transformer_engine.cpp
transpose/cast_transpose.cu
transpose/transpose.cu
transpose/cast_transpose_fusion.cu
transpose/transpose_fusion.cu
transpose/multi_cast_transpose.cu
activation/gelu.cu
gemm/cublaslt_gemm.cu
layer_norm/ln_api.cpp
layer_norm/ln_bwd_semi_cuda_kernel.cu
layer_norm/ln_fwd_cuda_kernel.cu
rmsnorm/rmsnorm_api.cpp
rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
rmsnorm/rmsnorm_fwd_cuda_kernel.cu
util/cast.cu
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu)
add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include")
list(APPEND transformer_engine_LINKER_LIBS CUDA::cublas CUDA::cudart CUDA::nvToolsExt)
if(NVTE_MPI_FOUND)
list(APPEND transformer_engine_LINKER_LIBS gdrapi)
endif()
target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS})
target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_include_directories(transformer_engine PUBLIC
"${CMAKE_CURRENT_SOURCE_DIR}/include")
# Configure dependencies
target_link_libraries(transformer_engine PUBLIC
CUDA::cublas
CUDA::cudart
CUDA::nvToolsExt)
target_include_directories(transformer_engine PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
# Compiler options
set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
PROPERTIES
COMPILE_OPTIONS "--use_fast_math")
if(NVTE_MPI_FOUND)
set_source_files_properties(comm_gemm_overlap/userbuffers.cu
comm_gemm_overlap/userbuffers-host.cpp
PROPERTIES
INCLUDE_DIRECTORIES ${NVTE_MPI_INCLUDE}
COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:-maxrregcount=64>")
endif()
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
......@@ -37,8 +37,8 @@ def _load_library():
return ctypes.CDLL(dll_path, mode=ctypes.RTLD_GLOBAL)
def _load_mpi():
"""Load MPI shared library"""
def _load_userbuffers():
"""Load shared library with userbuffers"""
system = platform.system()
if system == "Linux":
......@@ -49,15 +49,14 @@ def _load_mpi():
extension = "dll"
else:
raise RuntimeError(f"Unsupported operating system ({system})")
lib_name = "libmpi." + extension
MPI_HOME = os.environ.get("MPI_HOME", "/usr/local/mpi")
NVTE_MPI_FOUND = os.path.exists(MPI_HOME)
dll_path = os.path.join(MPI_HOME, "lib", lib_name)
lib_name = "libtransformer_engine_userbuffers." + extension
dll_path = get_te_path()
dll_path = os.path.join(dll_path, lib_name)
if NVTE_MPI_FOUND:
if os.path.exists(dll_path):
return ctypes.CDLL(dll_path, mode=ctypes.RTLD_GLOBAL)
return None
_TE_LIB_CTYPES = _load_mpi()
_TE_LIB_CTYPES = _load_library()
_UB_LIB_CTYPES = _load_userbuffers()
......@@ -14,9 +14,10 @@
#include <torch/custom_class.h>
#include <torch/extension.h>
#include <torch/types.h>
#include <transformer_engine/userbuffers.h>
#include "userbuffers/userbuffers.h"
#define HALF_BYTES 2
#define UB_MAX_SM 32
#define CHECK_CUDA(call) \
do { \
......@@ -174,6 +175,7 @@ struct UbufCommOverlap : torch::CustomClassHolder {
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
int ubuf_offset = 0;
int ori_sms = _ub_comm->sms;
// Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
......@@ -232,7 +234,8 @@ struct UbufCommOverlap : torch::CustomClassHolder {
cudaEventRecord(_start_comm, (cudaStream_t)_stream_compute[last_compute_stream_id]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
// Communication chunk
// Last communication chunk with max SM
_ub_comm->sms = UB_MAX_SM;
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg,
(_num_splits - 1) * output_chunk_size, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm);
......@@ -255,7 +258,10 @@ struct UbufCommOverlap : torch::CustomClassHolder {
(cudaStream_t)_stream_compute[i % _stream_compute.size()]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
// Communication chunk
// Communication chunk. Uses MAX_SM at the last chunk
if (i == _num_splits-1) {
_ub_comm->sms = UB_MAX_SM;
}
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size,
m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm);
......@@ -264,6 +270,7 @@ struct UbufCommOverlap : torch::CustomClassHolder {
output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size();
}
}
_ub_comm->sms = ori_sms;
int last_compute_stream_id =
(_num_splits + _stream_compute.size() - 1) % _stream_compute.size();
CHECK_CUDA(
......
......@@ -5,9 +5,9 @@
************************************************************************/
#include "extensions.h"
#ifdef NVTE_MPI_FOUND
#ifdef NVTE_WITH_USERBUFFERS
#include "comm_gemm_overlap.h"
#endif // NVTE_MPI_FOUND
#endif // NVTE_WITH_USERBUFFERS
void te_gemm(at::Tensor A,
at::Tensor A_scale_inverse,
......@@ -1022,7 +1022,7 @@ size_t get_cublasLt_version() {
bool userbuf_comm_available() { // TODO(ksivamani) check on python side
#ifdef NVTE_MPI_FOUND
#ifdef NVTE_WITH_USERBUFFERS
return true;
#else
return false;
......@@ -1080,7 +1080,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv)
.def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history);
#ifdef NVTE_MPI_FOUND
#ifdef NVTE_WITH_USERBUFFERS
py::enum_<ubuf::UBOverlapAlgo>(m, "UbufOverlapAlgo")
.value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG)
.value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS)
......@@ -1099,11 +1099,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def("split_overlap_ag", &ubuf::UbufP2PCommOverlap::split_overlap_ag)
.def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf)
.def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output);
#else // NVTE_MPI_FOUND
#else // NVTE_WITH_USERBUFFERS
m.def("UbufOverlapAlgo", &placeholder, "Dummy function for python side annotations");
m.def("UbufCommOverlap", &placeholder, "Dummy function for python side annotations");
m.def("UbufP2PCommOverlap", &placeholder, "Dummy function for python side annotations");
#endif // NVTE_MPI_FOUND
#endif // NVTE_WITH_USERBUFFERS
py::enum_<transformer_engine::DType>(m, "DType", py::module_local())
.value("kByte", transformer_engine::DType::kByte)
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# Configure userbuffers library
# Build userbuffers as its own shared library so that the main
# transformer_engine library does not need to link MPI or GDRCopy.
add_library(transformer_engine_userbuffers SHARED
userbuffers.cu
userbuffers-host.cpp)
# Expose this directory publicly so dependents can include userbuffers.h
target_include_directories(transformer_engine_userbuffers PUBLIC
"${CMAKE_CURRENT_SOURCE_DIR}")
# Configure dependencies
# MPI is used by the host-side bootstrap code (e.g. MPI_Bcast in userbuffers-host.cpp)
find_package(MPI REQUIRED)
# Locate the GDRCopy library (gdrapi); the search path can be overridden
# via the GDRCOPY_LIBRARY_DIR cache entry or environment variable
find_library(GDRCOPY_LIBRARY gdrapi
HINTS "${GDRCOPY_LIBRARY_DIR}" "$ENV{GDRCOPY_LIBRARY_DIR}")
if(NOT GDRCOPY_LIBRARY)
message(FATAL_ERROR "Could not find GDRCopy, please set GDRCOPY_LIBRARY_DIR")
endif()
message(STATUS "Found GDRCopy: ${GDRCOPY_LIBRARY}")
target_link_libraries(transformer_engine_userbuffers PUBLIC
CUDA::cudart
MPI::MPI_CXX
${GDRCOPY_LIBRARY})
target_include_directories(transformer_engine_userbuffers PRIVATE
${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
# Compiler options
# Limit registers per thread to 64 when compiling the CUDA sources
set_source_files_properties(userbuffers.cu
userbuffers-host.cpp
PROPERTIES
COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:-maxrregcount=64>")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
......@@ -13,12 +13,11 @@
#include <sched.h>
#include <stdio.h>
#include <string.h>
#include <transformer_engine/userbuffers.h>
#include <transformer_engine/logging.h>
#include <unistd.h>
#include <x86intrin.h>
#include <chrono>
#include <iostream>
#include "userbuffers.h"
static int oob_bcast(void *comm_context, void *buf, int size, int root) {
MPI_Bcast(buf, size, MPI_BYTE, root,
......@@ -48,6 +47,12 @@ int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (co
} \
} while (0)
// Throw a std::runtime_error for userbuffers failures, prefixing the message x
// with the source file, line number, and enclosing function name.
// do { ... } while (false) makes the macro expand to a single statement.
#define NVTE_UB_ERROR(x) \
do { \
throw std::runtime_error(std::string(__FILE__ ":") + std::to_string(__LINE__) + \
" in function " + __func__ + ": " + x); \
} while (false)
int pipe_rank(communicator *comm, int step) {
int mynode = comm->myrank / comm->nvsize;
int mylocal = comm->nvrank;
......@@ -347,7 +352,7 @@ int allgather2_userbuff_inplace_gpu(const int maxcredit, const int handler, cons
void allreduce_nonsharp_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream, int op) {
if (elements < 64) NVTE_ERROR("Userbuffer comm for given config not implemented.");
if (elements < 64) NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
// if(comm->myrank==0) fprintf(stderr,"AR2(%d) user call launch_mode=%d\n",op,comm->launch_mode);
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
int blocksize = elements * 2;
......@@ -394,7 +399,7 @@ void allreduce2_userbuff_inplace(const int handler, const int offset, const int
void allreduce_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
if (elements < 64) NVTE_ERROR("Userbuffer comm for given config not implemented.");
if (elements < 64) NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
allreduce_nonsharp_inplace(handler, offset, elements, comm, stream,
userbuffers_allreduceop_nonsharp);
return;
......@@ -402,7 +407,7 @@ void allreduce_userbuff_inplace(const int handler, const int offset, const int e
void reducescatter_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
if (elements < 64) NVTE_ERROR("Userbuffer comm for given config not implemented.");
if (elements < 64) NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
int op = userbuffers_allreduceop_nonsharp;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
......@@ -443,7 +448,7 @@ void reducescatter_userbuff_inplace(const int handler, const int offset, const i
void allgather_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
if (elements < 64) NVTE_ERROR("Userbuffer comm for given config not implemented.");
if (elements < 64) NVTE_UB_ERROR("Userbuffer comm for given config not implemented.");
int op = userbuffers_allreduceop_nonsharp;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
int blocksize = elements * 2;
......
......@@ -14,7 +14,7 @@
#endif
#include <assert.h>
#include <stdio.h>
#include <transformer_engine/userbuffers.h>
#include "userbuffers.h"
#define MAX_THREADS 1024
#define TIMEOUT 200000000000ull
......
......@@ -8,7 +8,7 @@
#define TRANSFORMER_ENGINE_USERBUFFERS_H_
#include <cuda.h>
#include <mpi.h>
#include <mpi.h> // TODO (tym): Removing will remove PyT extension dependence on MPI
#include "cuda_runtime.h"
#include <pthread.h>
#include <chrono>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment