Unverified commit 18da4e88, authored by Kirthi Shankar Sivamani, committed by GitHub

TP communication overlap with userbuffers (#147)



* Port initial changes
Co-authored-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* readd FA include for PyTorch
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Re-enable sm_70 + cleanup
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* LICENSE, cleanup header
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* 5k -> 173 errors
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* license and fixes in userbuffers-host
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* next round fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* final cpp cleanup
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* pylinting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix from linting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Turn off default async amax reduction (#148)
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* remove unused code path
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* cleanup Macros
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* fix conflict resolution bug
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* Fix gencode flags in setup (#145)

* Fix gencode flags based on cuda version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* review suggestions
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* revert append_nvcc_threads change
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change overlap config dict error message
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* simplify ub initialization
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix sanity imports
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* cpplint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix TensorFlow build
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix TE macros in public header
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* compiles with and w/o MPI
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes for python side annotations for conditional compile
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* link gdrAPI only when MPI found
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix comments for dummy var
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix linking
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* load MPI before TE
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add Py side argument checks
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* remove unused code and catch silent failures
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix cpp tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix find_lib path for tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
parent 7bb2af35
@@ -4,11 +4,16 @@
set -e

# Find TE
: ${TE_PATH:=/opt/transformerengine}
TE_LIB_PATH=`pip show transformer-engine | grep Location | cut -d ' ' -f 2`
export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH

# Find MPI
MPI_HOME=${MPI_HOME:-/usr/local/mpi}
NVTE_MPI_INCLUDE="$MPI_HOME/lib"

cd $TE_PATH/tests/cpp
-cmake -GNinja -Bbuild .
+cmake -GNinja -Bbuild -DNVTE_MPI_INCLUDE=$NVTE_MPI_INCLUDE .
cmake --build build
ctest --test-dir build -j4
@@ -14,3 +14,4 @@ filter=-build/namespaces
filter=-readability/todo
filter=-build/header_guard
filter=-build/include
filter=-build/c++11
@@ -14,3 +14,4 @@ filter=-build/namespaces
filter=-readability/todo
filter=-build/header_guard
filter=-build/include
filter=-build/c++11
@@ -14,3 +14,4 @@ filter=-build/namespaces
filter=-readability/todo
filter=-build/header_guard
filter=-build/include
filter=-build/c++11
@@ -19,7 +19,11 @@ from distutils.file_util import copy_file
path = os.path.dirname(os.path.realpath(__file__))
with open(path + "/VERSION", "r") as f:
    te_version = f.readline()

CUDA_HOME = os.environ.get("CUDA_HOME", "/usr/local/cuda")
MPI_HOME = os.environ.get("MPI_HOME", "/usr/local/mpi")
NVTE_MPI_FOUND = os.path.exists(MPI_HOME)
NVTE_MPI_INCLUDE = os.path.join(MPI_HOME, "include")

def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output(
@@ -51,7 +55,7 @@ def extra_gencodes(cc_flag):

def extra_compiler_flags():
-    return [
+    extra_flags = [
        "-O3",
        "-gencode",
        "arch=compute_70,code=sm_70",
@@ -66,6 +70,9 @@ def extra_compiler_flags():
        "--expt-extended-lambda",
        "--use_fast_math",
    ]
    if NVTE_MPI_FOUND:
        extra_flags.append("-DNVTE_MPI_FOUND")
    return extra_flags

cc_flag = []
@@ -76,12 +83,6 @@ def make_abs_path(l):
    return [os.path.join(path, p) for p in l]

-include_dirs = [
-    "transformer_engine/common/include",
-    "transformer_engine/pytorch/csrc",
-]
-include_dirs = make_abs_path(include_dirs)

pytorch_sources = [
    "transformer_engine/pytorch/csrc/extensions.cu",
    "transformer_engine/pytorch/csrc/common.cu",
@@ -100,6 +101,14 @@ supported_frameworks = {

framework = os.environ.get("NVTE_FRAMEWORK", "pytorch")

include_dirs = [
    "transformer_engine/common/include",
    "transformer_engine/pytorch/csrc",
]
if (framework in ("all", "pytorch")) and NVTE_MPI_FOUND:
    include_dirs.append(NVTE_MPI_INCLUDE)
include_dirs = make_abs_path(include_dirs)

args = sys.argv.copy()
for s in args:
    if s.startswith("--framework="):
@@ -155,10 +164,16 @@ class PyTorchBuilder(FrameworkBuilderBase):
        print("Building pyTorch extensions!")
        self.pytorch_build_extensions.run()

    def cmake_flags(self):
        if not NVTE_MPI_FOUND:
            return []
        return ["-DNVTE_MPI_FOUND=1", f"-DNVTE_MPI_INCLUDE={NVTE_MPI_INCLUDE}"]

    @staticmethod
    def install_requires():
        return ["flash-attn>=1.0.2",]

class TensorFlowBuilder(FrameworkBuilderBase):
    def cmake_flags(self):
        p = [d for d in sys.path if 'dist-packages' in d][0]
@@ -167,6 +182,7 @@ class TensorFlowBuilder(FrameworkBuilderBase):
    def run(self, extensions):
        print("Building TensorFlow extensions!")

class JaxBuilder(FrameworkBuilderBase):
    def cmake_flags(self):
        p = [d for d in sys.path if 'dist-packages' in d][0]
......
@@ -27,6 +27,12 @@ if(NOT DEFINED TE_LIB_PATH)
endif()

find_library(TE_LIB NAMES transformer_engine PATHS ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED)

if(EXISTS ${NVTE_MPI_INCLUDE})
    find_library(MPI_LIB NAMES mpi PATHS ${NVTE_MPI_INCLUDE} REQUIRED)
    message(STATUS "Found MPI library: ${MPI_LIB}")
endif()

message(STATUS "Found transformer_engine library: ${TE_LIB}")
include_directories(../../transformer_engine/common/include)
include_directories(${CMAKE_SOURCE_DIR})
......
@@ -17,7 +17,13 @@ add_executable(test_operator
               test_multi_cast_transpose.cu
               ../test_common.cu)

-target_link_libraries(test_operator PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB})
+list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB})
+if(EXISTS ${NVTE_MPI_INCLUDE})
+    list(APPEND test_operator_LINKER_LIBS ${MPI_LIB})
+endif()
+target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS})
target_compile_options(test_operator PRIVATE -O2)

include(GoogleTest)
......
@@ -5,7 +5,6 @@
"""Top level package"""

from . import common
try:
    from . import pytorch
except ImportError as e:
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

-add_library(transformer_engine SHARED
-            transformer_engine.cpp
-            transpose/cast_transpose.cu
-            transpose/transpose.cu
-            transpose/cast_transpose_fusion.cu
-            transpose/transpose_fusion.cu
-            transpose/multi_cast_transpose.cu
-            activation/gelu.cu
-            gemm/cublaslt_gemm.cu
-            layer_norm/ln_api.cpp
-            layer_norm/ln_bwd_semi_cuda_kernel.cu
-            layer_norm/ln_fwd_cuda_kernel.cu
-            rmsnorm/rmsnorm_api.cpp
-            rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
-            rmsnorm/rmsnorm_fwd_cuda_kernel.cu
-            util/cast.cu
-            fused_softmax/scaled_masked_softmax.cu
-            fused_softmax/scaled_upper_triang_masked_softmax.cu)
+set(transformer_engine_SOURCES)
+list(APPEND transformer_engine_SOURCES transformer_engine.cpp
+            transpose/cast_transpose.cu
+            transpose/transpose.cu
+            transpose/cast_transpose_fusion.cu
+            transpose/transpose_fusion.cu
+            transpose/multi_cast_transpose.cu
+            activation/gelu.cu
+            gemm/cublaslt_gemm.cu
+            layer_norm/ln_api.cpp
+            layer_norm/ln_bwd_semi_cuda_kernel.cu
+            layer_norm/ln_fwd_cuda_kernel.cu
+            rmsnorm/rmsnorm_api.cpp
+            rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu
+            rmsnorm/rmsnorm_fwd_cuda_kernel.cu
+            util/cast.cu
+            fused_softmax/scaled_masked_softmax.cu
+            fused_softmax/scaled_upper_triang_masked_softmax.cu)
+if(NVTE_MPI_FOUND)
+    list(APPEND transformer_engine_SOURCES comm_gemm_overlap/userbuffers.cu
+                comm_gemm_overlap/userbuffers-host.cpp)
+endif()
+add_library(transformer_engine SHARED ${transformer_engine_SOURCES})

target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include")

list(APPEND transformer_engine_LINKER_LIBS CUDA::cublas CUDA::cudart CUDA::nvToolsExt)
-target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS})
+if(NVTE_MPI_FOUND)
+    list(APPEND transformer_engine_LINKER_LIBS gdrapi)
+endif()
+target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS})
target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})

set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
                            fused_softmax/scaled_upper_triang_masked_softmax.cu
                            PROPERTIES
                            COMPILE_OPTIONS "--use_fast_math")
if(NVTE_MPI_FOUND)
    set_source_files_properties(comm_gemm_overlap/userbuffers.cu
                                comm_gemm_overlap/userbuffers-host.cpp
                                PROPERTIES
                                INCLUDE_DIRECTORIES ${NVTE_MPI_INCLUDE}
                                COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:-maxrregcount=64>")
endif()

set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
@@ -37,4 +37,27 @@ def _load_library():
    return ctypes.CDLL(dll_path, mode=ctypes.RTLD_GLOBAL)


def _load_mpi():
    """Load MPI shared library"""
    system = platform.system()
    if system == "Linux":
        extension = "so"
    elif system == "Darwin":
        extension = "dylib"
    elif system == "Windows":
        extension = "dll"
    else:
        raise RuntimeError(f"Unsupported operating system ({system})")
    lib_name = "libmpi." + extension
    MPI_HOME = os.environ.get("MPI_HOME", "/usr/local/mpi")
    NVTE_MPI_FOUND = os.path.exists(MPI_HOME)
    dll_path = os.path.join(MPI_HOME, "lib", lib_name)
    if NVTE_MPI_FOUND:
        return ctypes.CDLL(dll_path, mode=ctypes.RTLD_GLOBAL)
    return None

_MPI_LIB_CTYPES = _load_mpi()  # keep a separate reference so the TE library handle below is not clobbered
_TE_LIB_CTYPES = _load_library()
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <assert.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <immintrin.h>
#include <math.h>
#include <mpi.h>
#include <sched.h>
#include <stdio.h>
#include <string.h>
#include <transformer_engine/userbuffers.h>
#include <transformer_engine/logging.h>
#include <unistd.h>
#include <x86intrin.h>
#include <chrono>
#include <iostream>
static int oob_bcast(void *comm_context, void *buf, int size, int root) {
MPI_Bcast(buf, size, MPI_BYTE, root,
(reinterpret_cast<communicator *>(comm_context))->comm_inter);
return 0;
}
static int oob_barrier(void *comm_context) {
MPI_Barrier((reinterpret_cast<communicator *>(comm_context))->comm_inter);
return 0;
}
static int oob_gather(void *comm_context, int root, void *sbuf, void *rbuf, int len) {
MPI_Gather(sbuf, len, MPI_BYTE, rbuf, len, MPI_BYTE, root,
(reinterpret_cast<communicator *>(comm_context))->comm_inter);
return 0;
}
int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (const char *)b); }
#define CUDACHECK(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
exit(EXIT_FAILURE); \
} \
} while (0)
int pipe_rank(communicator *comm, int step) {
int mynode = comm->myrank / comm->nvsize;
int mylocal = comm->nvrank;
int numlocal = comm->nvsize;
int newlocal1 = mylocal + step * comm->ar_nvsize * comm->ar2_nvsize;
int newlocal = (numlocal + (newlocal1 % numlocal)) % numlocal;
int newnode = mynode;
newnode += (newlocal1 - newlocal) / numlocal * comm->num_nodes * comm->num2_nodes;
int allnodes = comm->nranks / comm->nvsize;
newnode = (allnodes + (newnode % allnodes)) % allnodes;
return newnode * numlocal + newlocal;
}
int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenodes, int tensorgpus,
int tensornodes) {
*comm = reinterpret_cast<communicator *>(malloc(sizeof(communicator)));
int myrank, nranks, cur_dev, ndev;
MPI_Comm_rank(MPI_COMM_WORLD, &myrank);
MPI_Comm_size(MPI_COMM_WORLD, &nranks);
(*comm)->nranks = nranks;
(*comm)->myrank = myrank;
(*comm)->free_region = 0;
(*comm)->launch_mode = NVTE_LAUNCH_GPU | NVTE_LAUNCH_CPU;
cudaDeviceProp device_prop;
CUDACHECK(cudaGetDevice(&cur_dev));
CUDACHECK(cudaGetDeviceCount(&ndev));
CUDACHECK(cudaGetDeviceProperties(&device_prop, cur_dev));
(*comm)->sm_arch = device_prop.major;
// (*comm)->use_rr_kernel = device_prop.major == 8;
(*comm)->use_rr_kernel = 0;
(*comm)->push = 1;
(*comm)->use_ce = 0;
(*comm)->cga_size = 2;
for (int i = 0; i < userbuffers_op_types; i++) (*comm)->basecounter[i] = 0;
(*comm)->head = 0;
(*comm)->tail = 0;
(*comm)->activeproxy = 1;
(*comm)->active_nreqs = 0;
for (int i = 0; i < userbuffers_op_types; i++) (*comm)->active_req[i].active = -1;
int ret = 0;
// split communicator
char host_name[MPI_MAX_PROCESSOR_NAME];
char(*host_names)[MPI_MAX_PROCESSOR_NAME];
int namelen, bytes, color, my_node, mylocal, numlocal, num_nodes;
int rank = (*comm)->myrank, size = (*comm)->nranks;
MPI_Get_processor_name(host_name, &namelen);
bytes = size * sizeof(char[MPI_MAX_PROCESSOR_NAME]);
host_names = (char(*)[MPI_MAX_PROCESSOR_NAME])malloc(bytes);
strcpy(host_names[rank], host_name); // NOLINT(*)
for (int n = 0; n < size; n++)
MPI_Bcast(&(host_names[n]), MPI_MAX_PROCESSOR_NAME, MPI_CHAR, n, MPI_COMM_WORLD);
qsort(host_names, size, sizeof(char[MPI_MAX_PROCESSOR_NAME]), stringCmp);
color = 0;
for (int n = 0; n < size; n++) {
if (n > 0 && strcmp(host_names[n - 1], host_names[n])) color++;
if (strcmp(host_name, host_names[n]) == 0) break;
}
free(host_names);
MPI_Comm_split(MPI_COMM_WORLD, color, rank, &(*comm)->comm_intra);
// find intranode numbers and make internode communicator
// figure out mylocal
MPI_Comm_rank((*comm)->comm_intra, &mylocal);
MPI_Comm_size((*comm)->comm_intra, &numlocal);
(*comm)->nvrank = mylocal;
(*comm)->nvsize = numlocal;
cpu_set_t cpuset;
CPU_ZERO(&cpuset);
int core;
if (mylocal == 0) core = 50;
if (mylocal == 1) core = 58;
if (mylocal == 2) core = 18;
if (mylocal == 3) core = 26;
if (mylocal == 4) core = 114;
if (mylocal == 5) core = 122;
if (mylocal == 6) core = 82;
if (mylocal == 7) core = 90;
CPU_SET(core, &cpuset);
if (!getenv("NVTE_NODOUBLE")) {
if (core > 128)
CPU_SET(core - 128, &cpuset);
else
CPU_SET(core + 128, &cpuset);
}
if (getenv("NVTE_DOPIN")) pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset);
if (ndev == numlocal) { // all visible devices
if (cur_dev != mylocal)
printf("%d: device used %d[%d] ,resetting device to %d\n", rank, cur_dev, ndev, mylocal);
CUDACHECK(cudaSetDevice(mylocal));
}
(*comm)->mydev = cur_dev;
// FIXME need to check that numlocal is multiple of pipegpus x tensorgpus
// ar1 is data
int divgpus = pipegpus * tensorgpus;
int datagpus = numlocal / divgpus;
(*comm)->ar_nvsize = datagpus;
(*comm)->ar_firstgpu = mylocal - ((mylocal / tensorgpus) % datagpus) * tensorgpus;
(*comm)->ar_nvrank = (mylocal - (*comm)->ar_firstgpu) / tensorgpus;
// ar2 is tensor
(*comm)->ar2_nvsize = tensorgpus;
(*comm)->ar2_firstgpu = mylocal - mylocal % tensorgpus;
(*comm)->ar2_nvrank = mylocal - (*comm)->ar2_firstgpu;
// ar2 has step equal to ar_nvsize
int allnodes = nranks / numlocal;
int mynode = myrank / numlocal;
int datanodes = allnodes / pipenodes / tensornodes;
int pipenodegroup_id = myrank / numlocal / (datanodes * tensornodes);
(*comm)->pipe_id = pipegpus * pipenodegroup_id + mylocal / (datagpus * tensorgpus);
CUDACHECK(cudaFree(0));
int datanodegroup_id =
myrank / numlocal / datanodes; // data reduction group node belongs, equals 0 for all if both
// pipenodes=1 and tensornodes=1
// mpi communicator only needed for SHARP which is always allreduce1/data-parallel
MPI_Comm_split(MPI_COMM_WORLD, mylocal + numlocal * datanodegroup_id, rank, &(*comm)->comm_inter);
// different rails from same group are in different subcommunicators
MPI_Comm_size((*comm)->comm_inter, &num_nodes);
MPI_Comm_rank((*comm)->comm_inter, &my_node);
(*comm)->first_node = mynode - my_node;
(*comm)->num_nodes = num_nodes;
(*comm)->my_node = my_node;
(*comm)->num2_nodes = tensornodes;
(*comm)->my2_node = (mynode / datanodes) % tensornodes;
(*comm)->first2_node = mynode - (*comm)->my2_node * datanodes;
char *ib_dev_list;
int ZIONROCE = getenv("NVTE_ZIONROCE") ? atoi(getenv("NVTE_ZIONROCE")) : 0;
int ROCE = getenv("NVTE_ROCE") ? atoi(getenv("NVTE_ROCE")) : 0;
if (ZIONROCE) ROCE = 1;
int DGX_H100 = device_prop.major == 9;
switch (mylocal) {
case 0:ib_dev_list = "mlx5_0:1"; break; // NOLINT(*)
case 1:ib_dev_list = (char*)(DGX_H100?"mlx5_3:1":"mlx5_1:1"); break; // NOLINT(*)
case 2:ib_dev_list = (char*)(ZIONROCE?"mlx5_4:1":DGX_H100?"mlx5_4:1":"mlx5_2:1"); break; // NOLINT(*)
case 3:ib_dev_list = (char*)(DGX_H100?"mlx5_5:1":"mlx5_3:1"); break; // NOLINT(*)
case 4:ib_dev_list = (char*)(DGX_H100?"mlx5_6:1":"mlx5_6:1"); break; // NOLINT(*)
case 5:ib_dev_list = (char*)(DGX_H100?"mlx5_9:1":"mlx5_7:1"); break; // NOLINT(*)
case 6:ib_dev_list = (char*)(ZIONROCE?"mlx5_10:1":DGX_H100?"mlx5_10:1":"mlx5_8:1"); break; // NOLINT(*)
case 7:ib_dev_list = (char*)(DGX_H100?"mlx5_11:1":"mlx5_9:1"); break; // NOLINT(*)
default: break;
}
(*comm)->fifo = reinterpret_cast<ub_request *>(malloc(sizeof(ub_request) * NVTE_MAX_REQUESTS));
(*comm)->nblocks = 8;
(*comm)->alignblock = 1024 * 512;
(*comm)->minblock = 1024 * 2 * 1024;
(*comm)->asyncblocks = 16;
CUDACHECK(cudaMallocHost((void **)&(*comm)->hostflags, // NOLINT(*)
(NVTE_MAX_SMS + 100) * sizeof(int)));
for (int i = 0; i < 100 + NVTE_MAX_SMS; i++) (*comm)->hostflags[i] = 0;
_mm_mfence();
sleep(1);
// init_p2p_transport();
(*comm)->ibnvsize = (*comm)->nvsize;
#define NBUF 2
#define LOCALSIZE 4 * (NVTE_REG0_OFFSET(*comm) + NVTE_REG0_FLAGS + NVTE_REG0_COMMBUFFER * NBUF)
// peer pointers + op flags + comm buffer
CUDACHECK(cudaMalloc(&(*comm)->gpu_ptrs, LOCALSIZE)); // flags and pointers, no block data yet
CUDACHECK(cudaMemset((*comm)->gpu_ptrs, 0, LOCALSIZE));
CUDACHECK(cudaDeviceSynchronize());
register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm); // will use handler 0
CUDACHECK(cudaMalloc(&(*comm)->send_id, (*comm)->nranks * sizeof(int)));
CUDACHECK(cudaMalloc(&(*comm)->recv_id, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int)));
CUDACHECK(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int)));
CUDACHECK(cudaMemset((*comm)->recv_id, 0, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int)));
(*comm)->sms = 16;
(*comm)->threads = 1024;
#define GPU_PAGE_SHIFT 16
#define GPU_PAGE_SIZE (1UL << GPU_PAGE_SHIFT)
#define GPU_PAGE_OFFSET (GPU_PAGE_SIZE - 1)
#define GPU_PAGE_MASK (~GPU_PAGE_OFFSET)
CUDACHECK(cudaMalloc(&(*comm)->flags, 2 * GPU_PAGE_SIZE));
unsigned int flag = 1;
// cuPointerSetAttribute(&flag, CU_POINTER_ATTRIBUTE_SYNC_MEMOPS, (CUdeviceptr)(*comm)->flags);
CUDACHECK(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE));
(*comm)->flags =
reinterpret_cast<int *>(((CUdeviceptr)(*comm)->flags + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
using namespace std;
(*comm)->g = gdr_open();
if ((*comm)->g == NULL) {
fprintf(stderr, "gdrcopy open failed\n");
return -1;
}
gdr_mh_t mh;
ret = gdr_pin_buffer((*comm)->g, (CUdeviceptr)(*comm)->flags, GPU_PAGE_SIZE, 0, 0, &mh);
if (ret) {
fprintf(stderr, "gdr_pin_buffer failed\n");
return -1;
}
ret = gdr_map((*comm)->g, mh, (void **)&((*comm)->map_flags), GPU_PAGE_SIZE); // NOLINT(*)
if (ret) {
fprintf(stderr, "gdr_map failed\n");
return -1;
}
sched_param param;
pthread_attr_t attr;
pthread_attr_init(&attr);
pthread_attr_getschedparam(&attr, &param);
param.sched_priority = sched_get_priority_max(SCHED_FIFO);
pthread_attr_setschedparam(&attr, &param);
if (getenv("NVTE_UBDEBUG"))
printf("%d/%d:(%d x %d): DP %d x %d TP %d x %d, DPGROUP %dx%d TPGROUP %dx%d PIPE_ID %d/%d\n",
myrank, nranks, myrank / numlocal, myrank % numlocal, (*comm)->my_node,
(*comm)->ar_nvrank, (*comm)->my2_node, (*comm)->ar2_nvrank, (*comm)->num_nodes,
(*comm)->ar_nvsize, (*comm)->num2_nodes, (*comm)->ar2_nvsize, (*comm)->pipe_id,
pipegpus * pipenodes);
fflush(NULL);
return 0;
}
int create_communicator_grouped(communicator **comm, int pipegpus, int pipenodes) {
return create_communicator_grouped2(comm, pipegpus, pipenodes, 1, 1);
}
int create_communicator(communicator **comm) {
return create_communicator_grouped2(comm, 1, 1, 1, 1);
}
void destroy_communicator(communicator *comm) {
comm->activeproxy = 0;
if (!comm->myrank && getenv("NVTE_UBDEBUG"))
printf("waiting for userbuffers proxy thread to exit()\n");
gdr_close(comm->g);
}
int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc) {
if (comm->free_region > NVTE_MAX_REGIONS) return -1;
int hndl = comm->free_region;
// printf("%d register %d size %lld\n",comm->myrank,hndl,bytes);fflush(NULL);
comm->peer_ptr[hndl] = reinterpret_cast<void **>(malloc(sizeof(void *) * (comm->nvsize)));
if (alloc) {
CUDACHECK(cudaMalloc(gpubuff, bytes));
}
assert(comm->nvsize <= 8);
cudaIpcMemHandle_t *memhndl =
reinterpret_cast<cudaIpcMemHandle_t *>(malloc(sizeof(cudaIpcMemHandle_t) * (comm->nvsize)));
CUDACHECK(cudaIpcGetMemHandle(&memhndl[comm->nvrank], *gpubuff));
MPI_Allgather(&memhndl[comm->nvrank], sizeof(cudaIpcMemHandle_t), MPI_BYTE, memhndl,
sizeof(cudaIpcMemHandle_t), MPI_BYTE, comm->comm_intra);
for (int i = 0; i < comm->nvsize; i++)
if (i != comm->nvrank)
CUDACHECK(cudaIpcOpenMemHandle((void **)&(comm->peer_ptr[hndl][i]), // NOLINT(*)
memhndl[i], cudaIpcMemLazyEnablePeerAccess));
comm->peer_ptr[hndl][comm->nvrank] = *gpubuff;
CUDACHECK(cudaDeviceSynchronize());
CUDACHECK(
cudaMemcpy(reinterpret_cast<char *>(comm->gpu_ptrs) + (hndl * comm->nvsize * sizeof(void *)),
comm->peer_ptr[hndl], comm->nvsize * sizeof(void *), cudaMemcpyHostToDevice));
CUDACHECK(cudaDeviceSynchronize());
free(memhndl);
comm->mem_ptr[hndl] = *gpubuff;
return comm->free_region++;
}
int allreduce_userbuff_inplace_gpu(const int handler, const int offset, const int elements,
const int blocksize, communicator *comm, cudaStream_t stream);
int allreduce2_userbuff_inplace_gpu(const int maxcredit, const int handler, const int offset,
const int elements, const int blocksize, communicator *comm,
cudaStream_t stream, int op);
int reducescatter2_userbuff_inplace_gpu(const int maxcredit, const int handler, const int offset,
const int elements, const int blocksize, communicator *comm,
cudaStream_t stream, int op);
int allgather2_userbuff_inplace_gpu(const int maxcredit, const int handler, const int offset,
const int elements, const int blocksize, communicator *comm,
cudaStream_t stream, int op);
void allreduce_nonsharp_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream, int op) {
if (elements < 64) NVTE_ERROR("Userbuffer comm for given config not implemented.");
// if(comm->myrank==0) fprintf(stderr,"AR2(%d) user call launch_mode=%d\n",op,comm->launch_mode);
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
int blocksize = elements * 2;
int maxcredit = 0;
const int num_nodes = op == userbuffers_allreduceop_nonsharp ? comm->num_nodes : comm->num2_nodes;
blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) /
comm->nblocks; // FIXME TUNING
blocksize *= comm->alignblock;
if (blocksize < comm->minblock) blocksize = comm->minblock;
maxcredit = (elements * 2 + blocksize - 1) / blocksize;
// if(maxcredit>4) maxcredit=4;
// if(maxcredit>4 && ar_nvsize==1) maxcredit=4;
size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit; // max size we can fit
if (blocksize > peerblock * ar_nvsize) blocksize = peerblock * ar_nvsize;
// blocksize=elements*2;
int sms = allreduce2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize, comm,
stream, op);
if (num_nodes > 1 && comm->launch_mode & NVTE_LAUNCH_CPU) {
if (!sms) return;
comm->fifo[comm->head].optype = op;
comm->fifo[comm->head].basecounter = comm->basecounter[op];
comm->fifo[comm->head].blocksize = blocksize;
comm->fifo[comm->head].maxcredit = maxcredit;
comm->fifo[comm->head].handler = handler;
comm->fifo[comm->head].offset = offset;
comm->fifo[comm->head].elements = elements;
int newhead = (comm->head + 1) & (NVTE_MAX_REQUESTS - 1);
while (newhead == comm->tail) {
}
comm->head = newhead;
comm->basecounter[op] += (elements * 2 + blocksize - 1) / blocksize;
}
}
void allreduce2_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
allreduce_nonsharp_inplace(handler, offset, elements, comm, stream,
userbuffers_allreduceop_nonsharp2);
}
void allreduce_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
if (elements < 64) NVTE_ERROR("Userbuffer comm for given config not implemented.");
allreduce_nonsharp_inplace(handler, offset, elements, comm, stream,
userbuffers_allreduceop_nonsharp);
return;
}
void reducescatter_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
if (elements < 64) NVTE_ERROR("Userbuffer comm for given config not implemented.");
int op = userbuffers_allreduceop_nonsharp;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
int blocksize = elements * 2;
int maxcredit = 0;
const int num_nodes = op == userbuffers_allreduceop_nonsharp ? comm->num_nodes : comm->num2_nodes;
blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) /
comm->nblocks; // FIXME TUNING
blocksize *= comm->alignblock;
if (blocksize < comm->minblock) blocksize = comm->minblock;
maxcredit = (elements * 2 + blocksize - 1) / blocksize;
size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit; // max size we can fit
if (blocksize > peerblock * ar_nvsize) blocksize = peerblock * ar_nvsize;
int sms = reducescatter2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize,
comm, stream, op);
if (num_nodes > 1 && comm->launch_mode & NVTE_LAUNCH_CPU) {
if (!sms) return;
comm->fifo[comm->head].optype = op;
comm->fifo[comm->head].basecounter = comm->basecounter[op];
comm->fifo[comm->head].blocksize = blocksize;
comm->fifo[comm->head].maxcredit = maxcredit;
comm->fifo[comm->head].handler = handler;
comm->fifo[comm->head].offset = offset;
comm->fifo[comm->head].elements = elements;
int newhead = (comm->head + 1) & (NVTE_MAX_REQUESTS - 1);
while (newhead == comm->tail) {
}
comm->head = newhead;
comm->basecounter[op] += (elements * 2 + blocksize - 1) / blocksize;
}
}
void allgather_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
if (elements < 64) NVTE_ERROR("Userbuffer comm for given config not implemented.");
int op = userbuffers_allreduceop_nonsharp;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
int blocksize = elements * 2;
int maxcredit = 0;
const int num_nodes = op == userbuffers_allreduceop_nonsharp ? comm->num_nodes : comm->num2_nodes;
blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) /
comm->nblocks; // FIXME TUNING
blocksize *= comm->alignblock;
if (blocksize < comm->minblock) blocksize = comm->minblock;
maxcredit = (elements * 2 + blocksize - 1) / blocksize;
size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit; // max size we can fit
if (blocksize > peerblock * ar_nvsize) blocksize = peerblock * ar_nvsize;
int sms = allgather2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize, comm,
stream, op);
}
@@ -49,6 +49,7 @@ void cublas_gemm(const Tensor *inputA,
                 size_t workspaceSize,
                 bool accumulate,
                 bool use_split_accumulator,
                 int math_sm_count,
                 cudaStream_t stream
) {
  void *A = inputA->data.dptr;
@@ -124,6 +125,13 @@ void cublas_gemm(const Tensor *inputA,
                                                   &transa, sizeof(transa)));
  NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB,
                                                   &transb, sizeof(transb)));
  // Set math SM count
  if (math_sm_count != 0) {
    NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
        operationDesc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
        &math_sm_count, sizeof(math_sm_count)));
  }

  // set fp8 attributes -- input and output types should already be set to fp8 as appropriate
  // Note: gelu fusion isn't available right now, and we don't need
@@ -227,6 +235,7 @@ void cublas_gemm(const Tensor *inputA,
  if (returnedResults == 0) throw std::runtime_error("Unable to find any suitable algorithms");

  // D = alpha * (A * B) + beta * C
  NVTE_CHECK_CUBLAS(cublasLtMatmul(handle,
                                   operationDesc,
                                   static_cast<const void*>(&one),         /* alpha */
@@ -266,6 +275,7 @@ void nvte_cublas_gemm(const NVTETensor A,
                      NVTETensor workspace,
                      bool accumulate,
                      bool use_split_accumulator,
                      int math_sm_count,
                      cudaStream_t stream) {
  NVTE_API_CALL(nvte_cublas_gemm);
  using namespace transformer_engine;
@@ -308,5 +318,6 @@ void nvte_cublas_gemm(const NVTETensor A,
              grad, wspace->data.dptr,
              wspace->data.shape[0],
              accumulate, use_split_accumulator,
              math_sm_count,
              stream);
}
@@ -36,6 +36,7 @@ extern "C" {
 *  \param[out]  workspace              Workspace tensor.
 *  \param[in]   accumulate             Whether to accumulate the result into the D matrix.
 *  \param[in]   use_split_accumulator  Whether to use split accumulator in the FP8 GEMM.
 *  \param[in]   math_sm_count          Number of GPU SMs to use (default=0: use cuBLAS heuristics)
 *  \param[in]   stream                 CUDA stream used for the operation.
 */
void nvte_cublas_gemm(const NVTETensor A,
@@ -49,6 +50,7 @@ void nvte_cublas_gemm(const NVTETensor A,
                      NVTETensor workspace,
                      bool accumulate,
                      bool use_split_accumulator,
                      int math_sm_count,
                      cudaStream_t stream
);
......
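To show where the new math_sm_count argument fits, here is a hedged caller-side sketch (an illustration added to this write-up, not code from the commit): the tensor, workspace, and stream handles A, B, D, bias, pre_gelu_out, workspace, and stream are assumed to exist already, and the argument order simply mirrors the JAX wrapper call further below. The idea is to cap the SMs the GEMM may occupy so the remainder stays free for the overlapped communication kernels.

int device = 0, total_sms = 0;
cudaGetDevice(&device);
cudaDeviceGetAttribute(&total_sms, cudaDevAttrMultiProcessorCount, device);
const int comm_sms = 16;                          // SMs reserved for communication (illustrative value)
const int math_sm_count = total_sms - comm_sms;   // pass 0 to defer to cuBLAS heuristics, per the doc above

nvte_cublas_gemm(A, B, D, bias, pre_gelu_out,
                 true,   /* transa */
                 false,  /* transb */
                 false,  /* grad */
                 workspace,
                 false,  /* accumulate */
                 false,  /* use_split_accumulator */
                 math_sm_count,
                 stream);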
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_USERBUFFERS_H_
#define TRANSFORMER_ENGINE_USERBUFFERS_H_
#include <cuda.h>
#include <mpi.h>
#include "cuda_runtime.h"
#include <pthread.h>
#include <chrono>
#include "gdrapi.h"
#include <stdexcept>
#define NVTE_MAX_REGIONS 16
#define NVTE_MAX_SMS 32
#define NVTE_MAX_OPS 32
#define NVTE_MAX_PEERS 8192
#define NVTE_MAX_REQUESTS 1024
#define NVTE_LAUNCH_GPU 1
#define NVTE_LAUNCH_CPU 2
#define NVTE_MAX_NVLINK 8
// region 0 flag offsets
#define NVTE_REG0_OPFLAGS 1024
#define NVTE_REG0_RECV (NVTE_REG0_OPFLAGS * userbuffers_op_types)
#define NVTE_REG0_SINGLENODE (2 * NVTE_MAX_NVLINK * NVTE_MAX_SMS + NVTE_MAX_OPS)
#define NVTE_REG0_OFFSET(comm) ((2 * NVTE_MAX_REGIONS) * NVTE_MAX_NVLINK \
+ NVTE_REG0_SINGLENODE * 2 + NVTE_MAX_PEERS)
#define NVTE_REG0_COMMBUFFER 0
#define NVTE_REG0_FLAGS (NVTE_REG0_RECV + NVTE_MAX_PEERS * NVTE_MAX_REGIONS)
#define NVTE_REG0_IBRS 32
#define NVTE_REG0_IBAG 512
#undef NVTE_REG0_COMMBUFFER
#define NVTE_REG0_COMMBUFFER (1024 * 1024 * 16)
// gpuflags map offsets
#define NVTE_GF_STATE 16000
#define NVTE_GF_IBSHARPDONE 0
#define NVTE_HF_NVRSDONE (userbuffers_op_types + 1)
#define NVTE_HF_NVREDUCEDONE (userbuffers_op_types + 3)
#define NVTE_MAX_SHARP 16
typedef struct ub_request {
int optype;
int blocksize;
int basecounter;
int elements;
int handler;
int handler2;
size_t offset;
size_t offset2;
int peer;
// ----execution states
int active, maxcredit;
int nblock, numblocks, unconfirmed_ib_in_flight;
} ub_request;
enum req_type {
userbuffers_allreduceop_sharp,
userbuffers_sendop,
userbuffers_allreduceop_nonsharp,
userbuffers_allreduceop_nonsharp2,
userbuffers_alltoall,
userbuffers_op_types
};
struct communicator {
int myrank, nranks; // global job communicator
int nvrank, nvsize; // single node comm_intra
int free_region;
int launch_mode;
void *gpu_ptrs;
int sms, threads;
int use_rr_kernel; // Whether to use RR (or RW) for NVLink-only kernel
int cga_size;
int push, use_ce;
void *mem_ptr[NVTE_MAX_REGIONS];
void **peer_ptr[NVTE_MAX_REGIONS];
int ar_nvsize, ar_firstgpu,
ar_nvrank; // number of gpus(and first gpu in a group) of gpus per node in reduction subgroup
// (_splitar init used) would be equal to (nvsize,0) for regular comm_create
int ar2_nvsize, ar2_firstgpu, ar2_nvrank; // with ar_nvsize as a step
int pipe_id; // which allreduce set of groups (pipeline rank in range of 0..pipeline_size)
int sm_arch;
int num_nodes, my_node,
first_node; // comm_inter communicator, per-rail allreduce (might have subset of nodes)
int num2_nodes, my2_node, first2_node; // with num_nodes as a stride
// max value for running block counters in hostflags
int basecounter[userbuffers_op_types]; // NOLINT(*)
int *hostflags;
int *flags, *map_flags;
gdr_t g;
struct sharp_coll_context *sharp_coll_context;
struct sharp_coll_comm *sharp_coll_comm;
void *mem_mr[NVTE_MAX_REGIONS];
ub_request *fifo;
volatile int activeproxy;
int nblocks, alignblock, minblock, asyncblocks, active_nreqs;
ub_request active_req[userbuffers_op_types]; // NOLINT(*)
int padding[7];
volatile int head;
int padding2[15];
volatile int tail;
MPI_Request mpihndl[NVTE_MAX_SHARP];
MPI_Comm comm_inter, // reduction group communicator (subset of the nodes) along GPU rail
comm_intra; // full intranode (all ndev GPUS)
int ibnvsize; // can be used to fake smaller or larger nvlink domain to use ib instead of nvlink
// or force MNNVL
int *send_id, *recv_id;
int mydev;
};
typedef struct communicator communicator;
int create_communicator(communicator **comm);
/* creates communicator, allocates all internal buffers if necessary */
int create_communicator_grouped(communicator **comm, int pipegpus, int pipenodes);
int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenodes, int tensorgpus,
int tensornodes);
/* creates communicator with
allreduce1 to happen in datagpus x datanodes groups,
allreduce2 to happen in tensorgpus x tensor nodes,
where num_nodes = pipenodes x tensornodes x datanodes
nvlink_size = pipegpus x tensorgpus x datagpus
*/
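/* Worked example (illustrative note added for this write-up, not part of the original header):
   with nvlink_size = 8 GPUs per node and num_nodes = 4, calling
   create_communicator_grouped2(&comm, pipegpus=1, pipenodes=2, tensorgpus=4, tensornodes=1)
   implies datagpus = 8 / (1 * 4) = 2 and datanodes = 4 / (2 * 1) = 2, so allreduce1
   (the data-parallel group) spans 2 GPUs x 2 nodes and allreduce2 (the tensor-parallel
   group) spans 4 GPUs x 1 node, consistent with the two formulas above. */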
// int check_user_buffer_registration(void* gpubuff, int bytes, communicator* comm, size_t* offset);
/*
local calls, doesnt communicate between peers
returns handler if buffer is registered already, or -1 if not.
returned offset is offset of gpubuff relative to buffer registered
*/
int pipe_rank(communicator *comm,
int step); // helper function to help walk across allreduce1 x allreduce2 groups
// data-parallel and tensor-parallel position within data and tensor
// groups would be preserved
int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm,
bool alloc = false);
/* returns handler and registers buffers. assumed to be collective i.e. you use same groups and
dont mix buffers for different operations returns -1 if cant register (too many preregistered
regions already) if alloc==true will allocate memory and fill the pointers (required for NVL
SHARP and NSO/MNNVL)
*/
void allreduce_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream = 0);
// for DP distributed optimizer, only nonSHARP multinode is implemented & calls must come in pairs
// ordered
void allgather_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream = 0);
void reducescatter_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream = 0);
void allreduce2_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream = 0);
// for TP-parallelism, only single node is implemented
void allgather2_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream = 0);
void allgather2_userbuff_inplace_sliced(const int handler, const int offset, const int elements,
communicator *comm, const int slice_id, const int nslices,
cudaStream_t stream = 0);
/*
each Rank input is
allgather2_userbuff_inplace: offset+myrank*elements
allgather2_userbuff_inplace_sliced: offset+myrank*elements*nslices+slice_id*elements
equivalent codes would be:
for(int slice=0;slice<ncslices;slice++)
allgather2_userbuff_inplace_sliced(hndl,offset, elements,comm,slice,nslices,stream);
and
allgather2_userbuff_inplace(hndl,offset, elements*nslices,comm,stream);
*/
void reducescatter2_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream = 0);
void reducescatter2_userbuff(void *output, const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream = 0);
void reducescatter2_userbuff_stridedoutput(void *output, const int handler, const int offset,
const int rowelements, const int colelements,
const int strideelements, communicator *comm,
cudaStream_t stream = 0);
/* everything should be 16byte aligned = 8 elts aligned
output is strided: row starts separated by stride elements*/
/* inplace allreduce: works only with buffers registered by previous call. offset should be same
* for all peers */
// two matching pairs, intended to work as push from sender or pull by receiver
// either way signal is a write by sender meaning
// push model: data arrived and visible at receiver(barrier enforced)
// pull model: data ready to be pulled by receiver(no barrier needed)
void userbuffers_send(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes, communicator *comm,
const int peer, cudaStream_t stream = 0);
void userbuffers_recv(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes, communicator *comm,
const int peer, cudaStream_t stream = 0);
// alltoall split send and recv to allow for overlap
// send kicks in sending data to the destination - invoke on same stream as data generation
// recv returns once data has received
// send and recv can be on different streams
void userbuffers_alltoall_send(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes, communicator *comm,
cudaStream_t stream = 0);
void userbuffers_alltoall_recv(communicator *comm, cudaStream_t stream = 0);
// void unregister_user_buffer(int handler);
void destroy_communicator(communicator *comm);
#endif // TRANSFORMER_ENGINE_USERBUFFERS_H_
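For orientation, the following is a minimal host-side usage sketch of the API declared above (an illustration added to this write-up, not code from the commit): it creates the communicator, collectively registers a GPU region, and runs an in-place tensor-parallel all-gather on it. The 2-byte element size mirrors the elements * 2 arithmetic in userbuffers-host.cpp; the buffer size and element count are arbitrary example values, and the GDRCopy/NVLink setup performed inside create_communicator still has to succeed on the machine.

#include <mpi.h>
#include <cuda_runtime.h>
#include <transformer_engine/userbuffers.h>

int main(int argc, char **argv) {
  MPI_Init(&argc, &argv);

  communicator *comm = nullptr;
  create_communicator(&comm);  // defaults: no pipeline/tensor subgrouping (1, 1, 1, 1)

  // Collectively register a region; alloc=true lets userbuffers allocate it.
  void *buf = nullptr;
  const size_t bytes = 16 * 1024 * 1024;  // example size, large enough for nvsize * elements * 2 bytes
  const int handle = register_user_buffer_collective(&buf, bytes, comm, true);

  // Each rank owns `elements` 2-byte values at offset myrank * elements;
  // the all-gather completes in place inside the registered region.
  const int elements = 65536;
  allgather2_userbuff_inplace(handle, 0 /* offset */, elements, comm);
  cudaDeviceSynchronize();

  destroy_communicator(comm);
  MPI_Finalize();
  return 0;
}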
@@ -267,7 +267,7 @@ void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque
  nvte_cublas_gemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), null_tensor.data(),
                   null_tensor.data(), (desc.transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
                   (desc.transb) ? CUBLAS_OP_T : CUBLAS_OP_N, false, wk_tensor.data(), false,
-                   desc.use_split_accumulator, stream);
+                   desc.use_split_accumulator, 0, stream);
}

void LayerNormForwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, float eps, void *input,
......
@@ -29,6 +29,9 @@ def fp8_gemm(
    use_bias: bool = False,
    use_split_accumulator: bool = False,
    D_dtype: Optional[tex.DType] = None,
    ub_algo: tex.UbufOverlapAlgo = None,
    ub: Union[tex.UbufCommOverlap, tex.UbufP2PCommOverlap] = None,
    extra_output_tensor: torch.Tensor = None,
) -> torch.Tensor:
    """TN layout GEMM with fp8 inputs."""
@@ -55,7 +58,7 @@ def fp8_gemm(
    out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype

-    _ = torch.ops.tex_ts.te_gemm_ts(
+    args = (
        A,
        A_scale_inv,
        A_fp8_tensor,
@@ -77,8 +80,29 @@ def fp8_gemm(
        workspace,
        workspace.shape[0],
        accumulate,
-        use_split_accumulator,
-    )
+        use_split_accumulator)
+    fn = torch.ops.tex_ts.te_gemm_ts
    if ub_algo is not None:
        assert ub is not None, 'ub object is None!'
        if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
            fn = ub.bulk_overlap
            args = tuple(args + (1,))
        elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
            fn = ub.bulk_overlap
            args = tuple(args + (0,))
        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
            fn = ub.split_overlap_ag
            extra_output_tensor = (
                empty_tensor if extra_output_tensor is None else extra_output_tensor
            )
            args = tuple(args + (extra_output_tensor,))
        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS:
            fn = ub.split_overlap_rs
            assert (
                extra_output_tensor is not None
            ), 'SPLIT_PIPELINED_RS requires extra output tensor'
            args = tuple(args + (True, extra_output_tensor,))
    _ = fn(*args)

    if return_output:
        if gelu:
@@ -102,6 +126,9 @@ def gemm(
    out: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    use_bias: bool = False,
    ub_algo: tex.UbufOverlapAlgo = None,
    ub: tex.UbufCommOverlap = None,
    extra_output_tensor: torch.Tensor = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
    """Non FP8 GEMM."""
@@ -142,7 +169,7 @@ def gemm(
    else:
        bias_dtype = output_dtype

-    _ = torch.ops.tex_ts.te_gemm_ts(
+    args = (
        A,
        empty_tensor,
        fp8_index,
@@ -166,6 +193,28 @@ def gemm(
        accumulate,
        False,  # use_split_accumulator
    )
    fn = torch.ops.tex_ts.te_gemm_ts
    if ub_algo is not None:
        assert ub is not None, 'ub object is None!'
        if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
            fn = ub.bulk_overlap
            args = tuple(args + (1,))
        elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
            fn = ub.bulk_overlap
            args = tuple(args + (0,))
        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
            fn = ub.split_overlap_ag
            extra_output_tensor = (
                empty_tensor if extra_output_tensor is None else extra_output_tensor
            )
            args = tuple(args + (extra_output_tensor,))
        elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS:
            fn = ub.split_overlap_rs
            assert (
                extra_output_tensor is not None
            ), 'SPLIT_PIPELINED_RS requires extra output tensor'
            args = tuple(args + (False, extra_output_tensor,))
    _ = fn(*args)

    if return_output:
        return out, grad_bias, gelu_input
@@ -283,9 +332,25 @@ def layernorm_fwd_fp8(
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
    sm_margin: int,
-    zero_centered_gamma: bool
+    zero_centered_gamma: bool,
    ln_out: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """LayerNorm with FP8 output"""
    if ln_out is not None:
        return tex.layernorm_fwd_fp8_noalloc(
            inp,
            weight,
            bias,
            eps,
            fp8_meta_tensor.scale[fp8_tensor],
            ln_out,
            fp8_meta_tensor.amax_history[0][fp8_tensor],
            fp8_meta_tensor.scale_inv[fp8_tensor],
            otype,
            sm_margin,
            zero_centered_gamma
        )

    return tex.layernorm_fwd_fp8(
        inp,
        weight,
@@ -351,8 +416,20 @@ def cast_to_fp8(
    fp8_meta_tensor: tex.FP8TensorMeta,
    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
    otype: tex.DType,
-) -> torch.Tensor:
+    out: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
    """Cast input to FP8"""
    if out is not None:
        tex.cast_to_fp8_noalloc(
            inp,
            fp8_meta_tensor.scale[fp8_tensor],
            out,
            fp8_meta_tensor.amax_history[0][fp8_tensor],
            fp8_meta_tensor.scale_inv[fp8_tensor],
            otype
        )
        return None

    return torch.ops.tex_ts.cast_to_fp8_ts(
        inp,
        fp8_meta_tensor.scale,
......
@@ -5,7 +5,9 @@
 ************************************************************************/

#include "extensions.h"
#ifdef NVTE_MPI_FOUND
#include "comm_gemm_overlap.h"
#endif  // NVTE_MPI_FOUND

void te_gemm(at::Tensor A,
             at::Tensor A_scale_inverse,
@@ -26,7 +28,8 @@ void te_gemm(at::Tensor A,
             at::Tensor workspace,
             size_t workspaceSize,
             bool accumulate,
-             bool use_split_accumulator
+             bool use_split_accumulator,
             int math_sm_count
) {
  using namespace transformer_engine;
  auto te_A = makeTransformerEngineTensor(A.data_ptr(),
@@ -70,6 +73,7 @@ void te_gemm(at::Tensor A,
                   te_workspace.data(),
                   accumulate,
                   use_split_accumulator,
                   math_sm_count,
                   at::cuda::getCurrentCUDAStream());
}
@@ -536,6 +540,67 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
}
std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps,
at::Tensor scale,
at::Tensor ln_out,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1));
DType itype = GetTransformerEngineDType(input.scalar_type());
auto mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto input_cu = makeTransformerEngineTensor(input);
auto gamma_cu = makeTransformerEngineTensor(weight);
auto beta_cu = makeTransformerEngineTensor(bias);
auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
auto mu_cu = makeTransformerEngineTensor(mu);
auto rsigma_cu = makeTransformerEngineTensor(rsigma);
transformer_engine::TensorWrapper workspace, barrier;
// This call populates workspace and barrier tensors with the required config
const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
// Fill workspace and barrier
auto workspace_data = allocateSpace(workspace.shape(),
workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(),
barrier.dtype(),
true);
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
barrier = makeTransformerEngineTensor(barrier_data.data_ptr(),
barrier.shape(),
barrier.dtype());
// Actual call to fwd kernel
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
return {ln_out, mu, rsigma};
}
at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
                                 const at::Tensor &weight,
                                 const at::Tensor &bias,
@@ -609,6 +674,61 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
}
std::vector<at::Tensor> layernorm_fwd_noalloc(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
at::Tensor ln_out,
float eps,
const int sm_margin,
const bool zero_centered_gamma
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1));
DType itype = GetTransformerEngineDType(input.scalar_type());
auto mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto input_cu = makeTransformerEngineTensor(input);
auto gamma_cu = makeTransformerEngineTensor(weight);
auto beta_cu = makeTransformerEngineTensor(bias);
auto z_cu = makeTransformerEngineTensor(ln_out);
auto mu_cu = makeTransformerEngineTensor(mu);
auto rsigma_cu = makeTransformerEngineTensor(rsigma);
transformer_engine::TensorWrapper workspace, barrier;
// This call populates workspace and barrier tensors with the required config
const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
// Fill workspace and barrier
auto workspace_data = allocateSpace(workspace.shape(),
workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(),
barrier.dtype(),
true);
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
barrier = makeTransformerEngineTensor(barrier_data.data_ptr(),
barrier.shape(),
barrier.dtype());
// Actual call to fwd kernel
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
return {ln_out, mu, rsigma};
}
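// A minimal sketch of the non-FP8 no-alloc path, assuming ln_out was obtained from a
// caller-managed buffer (e.g. a userbuffers staging tensor) with the same shape and
// dtype as the input; the helper name and eps value are illustrative assumptions.
static std::vector<at::Tensor> example_call_layernorm_fwd_noalloc(
const at::Tensor &input, const at::Tensor &weight, const at::Tensor &bias,
at::Tensor ln_out) {
TORCH_CHECK(ln_out.sizes() == input.sizes() && ln_out.scalar_type() == input.scalar_type(),
"ln_out must be preallocated with the input's shape and dtype");
return layernorm_fwd_noalloc(input, weight, bias, ln_out, /*eps=*/1e-5f,
/*sm_margin=*/0, /*zero_centered_gamma=*/false);
}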
at::Tensor layernorm_fwd_inf(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
@@ -646,6 +766,29 @@ at::Tensor cast_to_fp8(const at::Tensor &input,
}
void cast_to_fp8_noalloc(const at::Tensor &input,
const at::Tensor &scale,
at::Tensor output,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1));
auto input_cu = makeTransformerEngineTensor(input);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_fp8_quantize(input_cu.data(), output_cu.data(),
at::cuda::getCurrentCUDAStream());
return;
}
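// A minimal sketch of the no-alloc FP8 cast, assuming the caller reuses a persistent
// byte buffer (e.g. a userbuffers region) instead of allocating a new output on every
// call; the helper name and standalone metadata tensors below are illustrative only.
static void example_call_cast_to_fp8_noalloc(const at::Tensor &input, at::Tensor output) {
// Scaling metadata would normally be views into an FP8TensorMeta instance.
auto scale = at::ones({1}, at::CUDA(at::kFloat));
auto amax = at::zeros({1}, at::CUDA(at::kFloat));
auto scale_inv = at::ones({1}, at::CUDA(at::kFloat));
cast_to_fp8_noalloc(input, scale, output, amax, scale_inv,
transformer_engine::DType::kFloat8E4M3);
}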
at::Tensor cast_from_fp8(const at::Tensor &input,
const at::Tensor &scale_inv,
transformer_engine::DType itype,
@@ -878,6 +1021,17 @@ size_t get_cublasLt_version() {
}
bool userbuf_comm_available() { // TODO(ksivamani) check on python side
#ifdef NVTE_MPI_FOUND
return true;
#else
return false;
#endif
}
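// Sketch of the intended gating, assuming callers query this flag before touching any
// userbuffers functionality (the Python-side check referenced in the TODO above):
//
//   if (userbuf_comm_available()) {
//     /* safe to use ubuf::UbufCommOverlap / ubuf::UbufP2PCommOverlap */
//   } else {
//     /* fall back to non-overlapped tensor-parallel communication */
//   }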
void placeholder() {} // TODO(ksivamani) clean this up
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Softmax functions
m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD");
@@ -895,8 +1049,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Other granular functions
m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8");
m.def("layernorm_fwd_fp8_noalloc", &layernorm_fwd_fp8_noalloc, "LN FWD FP8");
m.def("layernorm_bwd", &layernorm_bwd, "LN BWD"); m.def("layernorm_bwd", &layernorm_bwd, "LN BWD");
m.def("layernorm_fwd", &layernorm_fwd, "LN FWD"); m.def("layernorm_fwd", &layernorm_fwd, "LN FWD");
m.def("layernorm_fwd_noalloc", &layernorm_fwd_noalloc, "LN FWD");
m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose"); m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose");
m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad, m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad,
"Fused Cast + Transpose + BGRAD"); "Fused Cast + Transpose + BGRAD");
...@@ -907,6 +1063,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -907,6 +1063,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose, m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose,
"Fused Multi-tensor Cast + Transpose"); "Fused Multi-tensor Cast + Transpose");
m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8"); m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8");
m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8");
m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8"); m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8");
m.def("te_gemm", &te_gemm, "CublasLt GEMM"); m.def("te_gemm", &te_gemm, "CublasLt GEMM");
m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O"); m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O");
...@@ -914,6 +1071,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -914,6 +1071,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Misc // Misc
m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version"); m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version");
m.def("userbuf_comm_available", &userbuf_comm_available, "If userbuf backend is available");
// Data structures
py::class_<transformer_engine::FP8TensorMeta>(m, "FP8TensorMeta")
@@ -922,6 +1080,31 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv)
.def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history);
#ifdef NVTE_MPI_FOUND
py::enum_<ubuf::UBOverlapAlgo>(m, "UbufOverlapAlgo")
.value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG)
.value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS)
.value("SPLIT_PIPELINED_RS", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS)
.value("SPLIT_PIPELINED_AG", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG);
py::class_<ubuf::UbufCommOverlap>(m, "UbufCommOverlap")
.def(py::init<torch::Tensor&, int, int, int, int, int, bool, int>())
.def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap)
.def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs)
.def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf)
.def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output);
py::class_<ubuf::UbufP2PCommOverlap>(m, "UbufP2PCommOverlap")
.def(py::init<torch::Tensor&, int, int, bool, int>())
.def("split_overlap_ag", &ubuf::UbufP2PCommOverlap::split_overlap_ag)
.def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf)
.def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output);
#else // NVTE_MPI_FOUND
m.def("UbufOverlapAlgo", &placeholder, "Dummy function for python side annotations");
m.def("UbufCommOverlap", &placeholder, "Dummy function for python side annotations");
m.def("UbufP2PCommOverlap", &placeholder, "Dummy function for python side annotations");
#endif // NVTE_MPI_FOUND
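// Note on the UbufOverlapAlgo values bound above (a reading based on the names, to be
// confirmed against the userbuffers implementation): the BULK_OVERLAP_AG/RS modes
// overlap a whole all-gather or reduce-scatter with an independent GEMM, while the
// SPLIT_PIPELINED_AG/RS modes split the GEMM and the collective into chunks and
// pipeline the chunks against each other.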
py::enum_<transformer_engine::DType>(m, "DType", py::module_local())
.value("kByte", transformer_engine::DType::kByte)
.value("kInt32", transformer_engine::DType::kInt32)
...
@@ -26,7 +26,8 @@ void te_gemm(at::Tensor A,
at::Tensor workspace,
size_t workspaceSize,
bool accumulate,
bool use_split_accumulator,
int math_sm_count
);
@@ -111,6 +112,19 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
const bool zero_centered_gamma
);
std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps,
at::Tensor scale,
at::Tensor ln_out,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma
);
at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
@@ -130,6 +144,15 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
const bool zero_centered_gamma
);
std::vector<at::Tensor> layernorm_fwd_noalloc(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
at::Tensor ln_out,
float eps,
const int sm_margin,
const bool zero_centered_gamma
);
at::Tensor layernorm_fwd_inf(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
@@ -145,6 +168,15 @@ at::Tensor cast_to_fp8(const at::Tensor &input,
);
void cast_to_fp8_noalloc(const at::Tensor &input,
const at::Tensor &scale,
at::Tensor output,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
);
at::Tensor cast_from_fp8(const at::Tensor &input,
const at::Tensor &scale_inv,
transformer_engine::DType itype,
...