Unverified Commit 18da4e88 authored by Kirthi Shankar Sivamani, committed by GitHub

TP communication overlap with userbuffers (#147)



* Port initial changes
Co-authored-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* readd FA include for PyTorch
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Re-enable sm_70 + cleanup
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* LICENSE, cleanup header
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* 5k -> 173 errors
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* license and fixes in userbuffers-host
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* next round fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* final cpp cleanup
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* pylinting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix from linting
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Turn off default async amax reduction (#148)
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* remove unused code path
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* cleanup Macros
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* fix conflict resolution bug
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* Fix gencode flags in setup (#145)

* Fix gencode flags based on cuda version
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* review suggestions
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* revert append_nvcc_threads change
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Change overlap config dict error message
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* simplify ub initialization
Signed-off-by: Sangkug Lym <slym@nvidia.com>

* lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix sanity imports
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* cpplint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix TensorFlow build
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix TE macros in public header
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix lint
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More fixes
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* compiles with and w/o MPI
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes for python side annotations for conditional compile
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* link gdrAPI only when MPI found
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix comments for dummy var
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix linking
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* load MPI before TE
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add Py side argument checks
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* remove unused code and catch silent failures
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix cpp tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix find_lib path for tests
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Sangkug Lym <slym@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
parent 7bb2af35
@@ -4,11 +4,16 @@
set -e
# Find TE
: ${TE_PATH:=/opt/transformerengine}
TE_LIB_PATH=`pip show transformer-engine | grep Location | cut -d ' ' -f 2`
export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH
# Find MPI
MPI_HOME=${MPI_HOME:-/usr/local/mpi}
NVTE_MPI_INCLUDE="$MPI_HOME/lib"
cd $TE_PATH/tests/cpp
cmake -GNinja -Bbuild .
cmake -GNinja -Bbuild -DNVTE_MPI_INCLUDE=$NVTE_MPI_INCLUDE .
cmake --build build
ctest --test-dir build -j4
@@ -14,3 +14,4 @@ filter=-build/namespaces
filter=-readability/todo
filter=-build/header_guard
filter=-build/include
filter=-build/c++11
@@ -14,3 +14,4 @@ filter=-build/namespaces
filter=-readability/todo
filter=-build/header_guard
filter=-build/include
filter=-build/c++11
@@ -14,3 +14,4 @@ filter=-build/namespaces
filter=-readability/todo
filter=-build/header_guard
filter=-build/include
filter=-build/c++11
@@ -19,7 +19,11 @@ from distutils.file_util import copy_file
path = os.path.dirname(os.path.realpath(__file__))
with open(path + "/VERSION", "r") as f:
    te_version = f.readline()

CUDA_HOME = os.environ.get("CUDA_HOME", "/usr/local/cuda")
MPI_HOME = os.environ.get("MPI_HOME", "/usr/local/mpi")
NVTE_MPI_FOUND = os.path.exists(MPI_HOME)
NVTE_MPI_INCLUDE = os.path.join(MPI_HOME, "include")


def get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output(
@@ -51,7 +55,7 @@ def extra_gencodes(cc_flag):

def extra_compiler_flags():
    return [
    extra_flags = [
        "-O3",
        "-gencode",
        "arch=compute_70,code=sm_70",
@@ -66,6 +70,9 @@ def extra_compiler_flags():
        "--expt-extended-lambda",
        "--use_fast_math",
    ]
    if NVTE_MPI_FOUND:
        extra_flags.append("-DNVTE_MPI_FOUND")
    return extra_flags
cc_flag = []
@@ -76,12 +83,6 @@ def make_abs_path(l):
    return [os.path.join(path, p) for p in l]

include_dirs = [
    "transformer_engine/common/include",
    "transformer_engine/pytorch/csrc",
]
include_dirs = make_abs_path(include_dirs)

pytorch_sources = [
    "transformer_engine/pytorch/csrc/extensions.cu",
    "transformer_engine/pytorch/csrc/common.cu",
@@ -100,6 +101,14 @@ supported_frameworks = {
framework = os.environ.get("NVTE_FRAMEWORK", "pytorch")

include_dirs = [
    "transformer_engine/common/include",
    "transformer_engine/pytorch/csrc",
]
if (framework in ("all", "pytorch")) and NVTE_MPI_FOUND:
    include_dirs.append(NVTE_MPI_INCLUDE)
include_dirs = make_abs_path(include_dirs)

args = sys.argv.copy()
for s in args:
    if s.startswith("--framework="):
@@ -155,10 +164,16 @@ class PyTorchBuilder(FrameworkBuilderBase):
        print("Building pyTorch extensions!")
        self.pytorch_build_extensions.run()

    def cmake_flags(self):
        if not NVTE_MPI_FOUND:
            return []
        return ["-DNVTE_MPI_FOUND=1", f"-DNVTE_MPI_INCLUDE={NVTE_MPI_INCLUDE}"]

    @staticmethod
    def install_requires():
        return ["flash-attn>=1.0.2",]

class TensorFlowBuilder(FrameworkBuilderBase):
    def cmake_flags(self):
        p = [d for d in sys.path if 'dist-packages' in d][0]
@@ -167,6 +182,7 @@ class TensorFlowBuilder(FrameworkBuilderBase):
    def run(self, extensions):
        print("Building TensorFlow extensions!")


class JaxBuilder(FrameworkBuilderBase):
    def cmake_flags(self):
        p = [d for d in sys.path if 'dist-packages' in d][0]
......
@@ -27,6 +27,12 @@ if(NOT DEFINED TE_LIB_PATH)
endif()
find_library(TE_LIB NAMES transformer_engine PATHS ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED)
if(EXISTS ${NVTE_MPI_INCLUDE})
find_library(MPI_LIB NAMES mpi PATHS ${NVTE_MPI_INCLUDE} REQUIRED)
message(STATUS "Found MPI library: ${MPI_LIB}")
endif()
message(STATUS "Found transformer_engine library: ${TE_LIB}")
include_directories(../../transformer_engine/common/include)
include_directories(${CMAKE_SOURCE_DIR})
......
@@ -17,7 +17,13 @@ add_executable(test_operator
test_multi_cast_transpose.cu
../test_common.cu)
target_link_libraries(test_operator PUBLIC CUDA::cudart GTest::gtest_main ${TE_LIB})
list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB})
if(EXISTS ${NVTE_MPI_INCLUDE})
list(APPEND test_operator_LINKER_LIBS ${MPI_LIB})
endif()
target_link_libraries(test_operator PUBLIC ${test_operator_LINKER_LIBS})
target_compile_options(test_operator PRIVATE -O2)
include(GoogleTest)
......
......@@ -5,7 +5,6 @@
"""Top level package"""
from . import common
try:
    from . import pytorch
except ImportError as e:
......
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
add_library(transformer_engine SHARED
transformer_engine.cpp
set(transformer_engine_SOURCES)
list(APPEND transformer_engine_SOURCES transformer_engine.cpp
transpose/cast_transpose.cu
transpose/transpose.cu
transpose/cast_transpose_fusion.cu
@@ -20,16 +21,35 @@ add_library(transformer_engine SHARED
fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu)
if(NVTE_MPI_FOUND)
list(APPEND transformer_engine_SOURCES comm_gemm_overlap/userbuffers.cu
comm_gemm_overlap/userbuffers-host.cpp)
endif()
add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include")
list(APPEND transformer_engine_LINKER_LIBS CUDA::cublas CUDA::cudart CUDA::nvToolsExt)
target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS})
if(NVTE_MPI_FOUND)
list(APPEND transformer_engine_LINKER_LIBS gdrapi)
endif()
target_link_libraries(transformer_engine PUBLIC ${transformer_engine_LINKER_LIBS})
target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
set_source_files_properties(fused_softmax/scaled_masked_softmax.cu
fused_softmax/scaled_upper_triang_masked_softmax.cu
PROPERTIES
COMPILE_OPTIONS "--use_fast_math")
if(NVTE_MPI_FOUND)
set_source_files_properties(comm_gemm_overlap/userbuffers.cu
comm_gemm_overlap/userbuffers-host.cpp
PROPERTIES
INCLUDE_DIRECTORIES ${NVTE_MPI_INCLUDE}
COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:-maxrregcount=64>")
endif()
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3")
@@ -37,4 +37,27 @@ def _load_library():
    return ctypes.CDLL(dll_path, mode=ctypes.RTLD_GLOBAL)


def _load_mpi():
    """Load MPI shared library"""
    system = platform.system()
    if system == "Linux":
        extension = "so"
    elif system == "Darwin":
        extension = "dylib"
    elif system == "Windows":
        extension = "dll"
    else:
        raise RuntimeError(f"Unsupported operating system ({system})")
    lib_name = "libmpi." + extension
    MPI_HOME = os.environ.get("MPI_HOME", "/usr/local/mpi")
    NVTE_MPI_FOUND = os.path.exists(MPI_HOME)
    dll_path = os.path.join(MPI_HOME, "lib", lib_name)
    if NVTE_MPI_FOUND:
        return ctypes.CDLL(dll_path, mode=ctypes.RTLD_GLOBAL)
    return None
_MPI_LIB_CTYPES = _load_mpi()  # load MPI (when present) before TE so its symbols are visible
_TE_LIB_CTYPES = _load_library()
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <assert.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <immintrin.h>
#include <math.h>
#include <mpi.h>
#include <sched.h>
#include <stdio.h>
#include <string.h>
#include <transformer_engine/userbuffers.h>
#include <transformer_engine/logging.h>
#include <unistd.h>
#include <x86intrin.h>
#include <chrono>
#include <iostream>
static int oob_bcast(void *comm_context, void *buf, int size, int root) {
MPI_Bcast(buf, size, MPI_BYTE, root,
(reinterpret_cast<communicator *>(comm_context))->comm_inter);
return 0;
}
static int oob_barrier(void *comm_context) {
MPI_Barrier((reinterpret_cast<communicator *>(comm_context))->comm_inter);
return 0;
}
static int oob_gather(void *comm_context, int root, void *sbuf, void *rbuf, int len) {
MPI_Gather(sbuf, len, MPI_BYTE, rbuf, len, MPI_BYTE, root,
(reinterpret_cast<communicator *>(comm_context))->comm_inter);
return 0;
}
int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (const char *)b); }
#define CUDACHECK(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
exit(EXIT_FAILURE); \
} \
} while (0)
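// pipe_rank: given a pipeline step, returns the global rank of the GPU that is `step`
// pipeline stages away from the caller, wrapping within the node and across nodes.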
int pipe_rank(communicator *comm, int step) {
int mynode = comm->myrank / comm->nvsize;
int mylocal = comm->nvrank;
int numlocal = comm->nvsize;
int newlocal1 = mylocal + step * comm->ar_nvsize * comm->ar2_nvsize;
int newlocal = (numlocal + (newlocal1 % numlocal)) % numlocal;
int newnode = mynode;
newnode += (newlocal1 - newlocal) / numlocal * comm->num_nodes * comm->num2_nodes;
int allnodes = comm->nranks / comm->nvsize;
newnode = (allnodes + (newnode % allnodes)) % allnodes;
return newnode * numlocal + newlocal;
}
int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenodes, int tensorgpus,
int tensornodes) {
*comm = reinterpret_cast<communicator *>(malloc(sizeof(communicator)));
int myrank, nranks, cur_dev, ndev;
MPI_Comm_rank(MPI_COMM_WORLD, &myrank);
MPI_Comm_size(MPI_COMM_WORLD, &nranks);
(*comm)->nranks = nranks;
(*comm)->myrank = myrank;
(*comm)->free_region = 0;
(*comm)->launch_mode = NVTE_LAUNCH_GPU | NVTE_LAUNCH_CPU;
cudaDeviceProp device_prop;
CUDACHECK(cudaGetDevice(&cur_dev));
CUDACHECK(cudaGetDeviceCount(&ndev));
CUDACHECK(cudaGetDeviceProperties(&device_prop, cur_dev));
(*comm)->sm_arch = device_prop.major;
// (*comm)->use_rr_kernel = device_prop.major == 8;
(*comm)->use_rr_kernel = 0;
(*comm)->push = 1;
(*comm)->use_ce = 0;
(*comm)->cga_size = 2;
for (int i = 0; i < userbuffers_op_types; i++) (*comm)->basecounter[i] = 0;
(*comm)->head = 0;
(*comm)->tail = 0;
(*comm)->activeproxy = 1;
(*comm)->active_nreqs = 0;
for (int i = 0; i < userbuffers_op_types; i++) (*comm)->active_req[i].active = -1;
int ret = 0;
// split communicator
char host_name[MPI_MAX_PROCESSOR_NAME];
char(*host_names)[MPI_MAX_PROCESSOR_NAME];
int namelen, bytes, color, my_node, mylocal, numlocal, num_nodes;
int rank = (*comm)->myrank, size = (*comm)->nranks;
MPI_Get_processor_name(host_name, &namelen);
bytes = size * sizeof(char[MPI_MAX_PROCESSOR_NAME]);
host_names = (char(*)[MPI_MAX_PROCESSOR_NAME])malloc(bytes);
strcpy(host_names[rank], host_name); // NOLINT(*)
for (int n = 0; n < size; n++)
MPI_Bcast(&(host_names[n]), MPI_MAX_PROCESSOR_NAME, MPI_CHAR, n, MPI_COMM_WORLD);
qsort(host_names, size, sizeof(char[MPI_MAX_PROCESSOR_NAME]), stringCmp);
color = 0;
for (int n = 0; n < size; n++) {
if (n > 0 && strcmp(host_names[n - 1], host_names[n])) color++;
if (strcmp(host_name, host_names[n]) == 0) break;
}
free(host_names);
MPI_Comm_split(MPI_COMM_WORLD, color, rank, &(*comm)->comm_intra);
// find intranode numbers and make internode communicator
// figure out mylocal
MPI_Comm_rank((*comm)->comm_intra, &mylocal);
MPI_Comm_size((*comm)->comm_intra, &numlocal);
(*comm)->nvrank = mylocal;
(*comm)->nvsize = numlocal;
cpu_set_t cpuset;
CPU_ZERO(&cpuset);
int core;
if (mylocal == 0) core = 50;
if (mylocal == 1) core = 58;
if (mylocal == 2) core = 18;
if (mylocal == 3) core = 26;
if (mylocal == 4) core = 114;
if (mylocal == 5) core = 122;
if (mylocal == 6) core = 82;
if (mylocal == 7) core = 90;
CPU_SET(core, &cpuset);
if (!getenv("NVTE_NODOUBLE")) {
if (core > 128)
CPU_SET(core - 128, &cpuset);
else
CPU_SET(core + 128, &cpuset);
}
if (getenv("NVTE_DOPIN")) pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset);
if (ndev == numlocal) { // all visible devices
if (cur_dev != mylocal)
printf("%d: device used %d[%d] ,resetting device to %d\n", rank, cur_dev, ndev, mylocal);
CUDACHECK(cudaSetDevice(mylocal));
}
(*comm)->mydev = cur_dev;
// FIXME need to check that numlocal is multiple of pipegpus x tensorgpus
// ar1 is data
int divgpus = pipegpus * tensorgpus;
int datagpus = numlocal / divgpus;
(*comm)->ar_nvsize = datagpus;
(*comm)->ar_firstgpu = mylocal - ((mylocal / tensorgpus) % datagpus) * tensorgpus;
(*comm)->ar_nvrank = (mylocal - (*comm)->ar_firstgpu) / tensorgpus;
// ar2 is tensor
(*comm)->ar2_nvsize = tensorgpus;
(*comm)->ar2_firstgpu = mylocal - mylocal % tensorgpus;
(*comm)->ar2_nvrank = mylocal - (*comm)->ar2_firstgpu;
// ar2 has step equal to ar_nvsize
int allnodes = nranks / numlocal;
int mynode = myrank / numlocal;
int datanodes = allnodes / pipenodes / tensornodes;
int pipenodegroup_id = myrank / numlocal / (datanodes * tensornodes);
(*comm)->pipe_id = pipegpus * pipenodegroup_id + mylocal / (datagpus * tensorgpus);
CUDACHECK(cudaFree(0));
int datanodegroup_id = myrank / numlocal / datanodes;  // data-reduction node group this rank
                                                       // belongs to; equals 0 for all ranks
                                                       // when pipenodes=1 and tensornodes=1
// mpi communicator only needed for SHARP which is always allreduce1/data-parallel
MPI_Comm_split(MPI_COMM_WORLD, mylocal + numlocal * datanodegroup_id, rank, &(*comm)->comm_inter);
// different rails from same group are in different subcommunicators
MPI_Comm_size((*comm)->comm_inter, &num_nodes);
MPI_Comm_rank((*comm)->comm_inter, &my_node);
(*comm)->first_node = mynode - my_node;
(*comm)->num_nodes = num_nodes;
(*comm)->my_node = my_node;
(*comm)->num2_nodes = tensornodes;
(*comm)->my2_node = (mynode / datanodes) % tensornodes;
(*comm)->first2_node = mynode - (*comm)->my2_node * datanodes;
char *ib_dev_list;
int ZIONROCE = getenv("NVTE_ZIONROCE") ? atoi(getenv("NVTE_ZIONROCE")) : 0;
int ROCE = getenv("NVTE_ROCE") ? atoi(getenv("NVTE_ROCE")) : 0;
if (ZIONROCE) ROCE = 1;
int DGX_H100 = device_prop.major == 9;
switch (mylocal) {
case 0:ib_dev_list = "mlx5_0:1"; break; // NOLINT(*)
case 1:ib_dev_list = (char*)(DGX_H100?"mlx5_3:1":"mlx5_1:1"); break; // NOLINT(*)
case 2:ib_dev_list = (char*)(ZIONROCE?"mlx5_4:1":DGX_H100?"mlx5_4:1":"mlx5_2:1"); break; // NOLINT(*)
case 3:ib_dev_list = (char*)(DGX_H100?"mlx5_5:1":"mlx5_3:1"); break; // NOLINT(*)
case 4:ib_dev_list = (char*)(DGX_H100?"mlx5_6:1":"mlx5_6:1"); break; // NOLINT(*)
case 5:ib_dev_list = (char*)(DGX_H100?"mlx5_9:1":"mlx5_7:1"); break; // NOLINT(*)
case 6:ib_dev_list = (char*)(ZIONROCE?"mlx5_10:1":DGX_H100?"mlx5_10:1":"mlx5_8:1"); break; // NOLINT(*)
case 7:ib_dev_list = (char*)(DGX_H100?"mlx5_11:1":"mlx5_9:1"); break; // NOLINT(*)
default: break;
}
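// ib_dev_list maps the local rank to the HCA (mlx5 device:port) used for its internode
// traffic; the NVTE_ZIONROCE and DGX H100 cases select the alternate NIC numbering.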
(*comm)->fifo = reinterpret_cast<ub_request *>(malloc(sizeof(ub_request) * NVTE_MAX_REQUESTS));
(*comm)->nblocks = 8;
(*comm)->alignblock = 1024 * 512;
(*comm)->minblock = 1024 * 2 * 1024;
(*comm)->asyncblocks = 16;
CUDACHECK(cudaMallocHost((void **)&(*comm)->hostflags, // NOLINT(*)
(NVTE_MAX_SMS + 100) * sizeof(int)));
for (int i = 0; i < 100 + NVTE_MAX_SMS; i++) (*comm)->hostflags[i] = 0;
_mm_mfence();
sleep(1);
// init_p2p_transport();
(*comm)->ibnvsize = (*comm)->nvsize;
#define NBUF 2
#define LOCALSIZE 4 * (NVTE_REG0_OFFSET(*comm) + NVTE_REG0_FLAGS + NVTE_REG0_COMMBUFFER * NBUF)
// peer pointers + op flags + comm buffer
CUDACHECK(cudaMalloc(&(*comm)->gpu_ptrs, LOCALSIZE)); // flags and pointers, no block data yet
CUDACHECK(cudaMemset((*comm)->gpu_ptrs, 0, LOCALSIZE));
CUDACHECK(cudaDeviceSynchronize());
register_user_buffer_collective(&((*comm)->gpu_ptrs), LOCALSIZE, *comm); // will use handler 0
CUDACHECK(cudaMalloc(&(*comm)->send_id, (*comm)->nranks * sizeof(int)));
CUDACHECK(cudaMalloc(&(*comm)->recv_id, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int)));
CUDACHECK(cudaMemset((*comm)->send_id, 0, (*comm)->nranks * sizeof(int)));
CUDACHECK(cudaMemset((*comm)->recv_id, 0, NVTE_MAX_REGIONS * (*comm)->nranks * sizeof(int)));
(*comm)->sms = 16;
(*comm)->threads = 1024;
#define GPU_PAGE_SHIFT 16
#define GPU_PAGE_SIZE (1UL << GPU_PAGE_SHIFT)
#define GPU_PAGE_OFFSET (GPU_PAGE_SIZE - 1)
#define GPU_PAGE_MASK (~GPU_PAGE_OFFSET)
CUDACHECK(cudaMalloc(&(*comm)->flags, 2 * GPU_PAGE_SIZE));
unsigned int flag = 1;
// cuPointerSetAttribute(&flag, CU_POINTER_ATTRIBUTE_SYNC_MEMOPS, (CUdeviceptr)(*comm)->flags);
CUDACHECK(cudaMemset((*comm)->flags, 0, 2 * GPU_PAGE_SIZE));
(*comm)->flags =
reinterpret_cast<int *>(((CUdeviceptr)(*comm)->flags + GPU_PAGE_SIZE - 1) & GPU_PAGE_MASK);
using namespace std;
(*comm)->g = gdr_open();
if ((*comm)->g == NULL) {
fprintf(stderr, "gdrcopy open failed\n");
return -1;
}
gdr_mh_t mh;
ret = gdr_pin_buffer((*comm)->g, (CUdeviceptr)(*comm)->flags, GPU_PAGE_SIZE, 0, 0, &mh);
if (ret) {
fprintf(stderr, "gdr_pin_buffer failed\n");
return -1;
}
ret = gdr_map((*comm)->g, mh, (void **)&((*comm)->map_flags), GPU_PAGE_SIZE); // NOLINT(*)
if (ret) {
fprintf(stderr, "gdr_map failed\n");
return -1;
}
sched_param param;
pthread_attr_t attr;
pthread_attr_init(&attr);
pthread_attr_getschedparam(&attr, &param);
param.sched_priority = sched_get_priority_max(SCHED_FIFO);
pthread_attr_setschedparam(&attr, &param);
if (getenv("NVTE_UBDEBUG"))
printf("%d/%d:(%d x %d): DP %d x %d TP %d x %d, DPGROUP %dx%d TPGROUP %dx%d PIPE_ID %d/%d\n",
myrank, nranks, myrank / numlocal, myrank % numlocal, (*comm)->my_node,
(*comm)->ar_nvrank, (*comm)->my2_node, (*comm)->ar2_nvrank, (*comm)->num_nodes,
(*comm)->ar_nvsize, (*comm)->num2_nodes, (*comm)->ar2_nvsize, (*comm)->pipe_id,
pipegpus * pipenodes);
fflush(NULL);
return 0;
}
int create_communicator_grouped(communicator **comm, int pipegpus, int pipenodes) {
return create_communicator_grouped2(comm, pipegpus, pipenodes, 1, 1);
}
int create_communicator(communicator **comm) {
return create_communicator_grouped2(comm, 1, 1, 1, 1);
}
void destroy_communicator(communicator *comm) {
comm->activeproxy = 0;
if (!comm->myrank && getenv("NVTE_UBDEBUG"))
printf("waiting for userbuffers proxy thread to exit()\n");
gdr_close(comm->g);
}
int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc) {
if (comm->free_region > NVTE_MAX_REGIONS) return -1;
int hndl = comm->free_region;
// printf("%d register %d size %lld\n",comm->myrank,hndl,bytes);fflush(NULL);
comm->peer_ptr[hndl] = reinterpret_cast<void **>(malloc(sizeof(void *) * (comm->nvsize)));
if (alloc) {
CUDACHECK(cudaMalloc(gpubuff, bytes));
}
assert(comm->nvsize <= 8);
cudaIpcMemHandle_t *memhndl =
reinterpret_cast<cudaIpcMemHandle_t *>(malloc(sizeof(cudaIpcMemHandle_t) * (comm->nvsize)));
CUDACHECK(cudaIpcGetMemHandle(&memhndl[comm->nvrank], *gpubuff));
MPI_Allgather(&memhndl[comm->nvrank], sizeof(cudaIpcMemHandle_t), MPI_BYTE, memhndl,
sizeof(cudaIpcMemHandle_t), MPI_BYTE, comm->comm_intra);
for (int i = 0; i < comm->nvsize; i++)
if (i != comm->nvrank)
CUDACHECK(cudaIpcOpenMemHandle((void **)&(comm->peer_ptr[hndl][i]), // NOLINT(*)
memhndl[i], cudaIpcMemLazyEnablePeerAccess));
comm->peer_ptr[hndl][comm->nvrank] = *gpubuff;
CUDACHECK(cudaDeviceSynchronize());
CUDACHECK(
cudaMemcpy(reinterpret_cast<char *>(comm->gpu_ptrs) + (hndl * comm->nvsize * sizeof(void *)),
comm->peer_ptr[hndl], comm->nvsize * sizeof(void *), cudaMemcpyHostToDevice));
CUDACHECK(cudaDeviceSynchronize());
free(memhndl);
comm->mem_ptr[hndl] = *gpubuff;
return comm->free_region++;
}
int allreduce_userbuff_inplace_gpu(const int handler, const int offset, const int elements,
const int blocksize, communicator *comm, cudaStream_t stream);
int allreduce2_userbuff_inplace_gpu(const int maxcredit, const int handler, const int offset,
const int elements, const int blocksize, communicator *comm,
cudaStream_t stream, int op);
int reducescatter2_userbuff_inplace_gpu(const int maxcredit, const int handler, const int offset,
const int elements, const int blocksize, communicator *comm,
cudaStream_t stream, int op);
int allgather2_userbuff_inplace_gpu(const int maxcredit, const int handler, const int offset,
const int elements, const int blocksize, communicator *comm,
cudaStream_t stream, int op);
void allreduce_nonsharp_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream, int op) {
if (elements < 64) NVTE_ERROR("Userbuffer comm for given config not implemented.");
// if(comm->myrank==0) fprintf(stderr,"AR2(%d) user call launch_mode=%d\n",op,comm->launch_mode);
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
int blocksize = elements * 2;
int maxcredit = 0;
const int num_nodes = op == userbuffers_allreduceop_nonsharp ? comm->num_nodes : comm->num2_nodes;
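// Block size is roughly 1/nblocks of the fp16 message (elements * 2 bytes), rounded up to a
// multiple of alignblock and clamped to at least minblock; maxcredit is the resulting number
// of blocks, and blocksize is further capped at ar_nvsize times the per-credit share of the
// NVTE_REG0_COMMBUFFER staging space.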
blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) /
comm->nblocks; // FIXME TUNING
blocksize *= comm->alignblock;
if (blocksize < comm->minblock) blocksize = comm->minblock;
maxcredit = (elements * 2 + blocksize - 1) / blocksize;
// if(maxcredit>4) maxcredit=4;
// if(maxcredit>4 && ar_nvsize==1) maxcredit=4;
size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit; // max size we can fit
if (blocksize > peerblock * ar_nvsize) blocksize = peerblock * ar_nvsize;
// blocksize=elements*2;
int sms = allreduce2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize, comm,
stream, op);
if (num_nodes > 1 && comm->launch_mode & NVTE_LAUNCH_CPU) {
if (!sms) return;
comm->fifo[comm->head].optype = op;
comm->fifo[comm->head].basecounter = comm->basecounter[op];
comm->fifo[comm->head].blocksize = blocksize;
comm->fifo[comm->head].maxcredit = maxcredit;
comm->fifo[comm->head].handler = handler;
comm->fifo[comm->head].offset = offset;
comm->fifo[comm->head].elements = elements;
int newhead = (comm->head + 1) & (NVTE_MAX_REQUESTS - 1);
while (newhead == comm->tail) {
}
comm->head = newhead;
comm->basecounter[op] += (elements * 2 + blocksize - 1) / blocksize;
}
}
void allreduce2_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
allreduce_nonsharp_inplace(handler, offset, elements, comm, stream,
userbuffers_allreduceop_nonsharp2);
}
void allreduce_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
if (elements < 64) NVTE_ERROR("Userbuffer comm for given config not implemented.");
allreduce_nonsharp_inplace(handler, offset, elements, comm, stream,
userbuffers_allreduceop_nonsharp);
return;
}
void reducescatter_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
if (elements < 64) NVTE_ERROR("Userbuffer comm for given config not implemented.");
int op = userbuffers_allreduceop_nonsharp;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
int blocksize = elements * 2;
int maxcredit = 0;
const int num_nodes = op == userbuffers_allreduceop_nonsharp ? comm->num_nodes : comm->num2_nodes;
blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) /
comm->nblocks; // FIXME TUNING
blocksize *= comm->alignblock;
if (blocksize < comm->minblock) blocksize = comm->minblock;
maxcredit = (elements * 2 + blocksize - 1) / blocksize;
size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit; // max size we can fit
if (blocksize > peerblock * ar_nvsize) blocksize = peerblock * ar_nvsize;
int sms = reducescatter2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize,
comm, stream, op);
if (num_nodes > 1 && comm->launch_mode & NVTE_LAUNCH_CPU) {
if (!sms) return;
comm->fifo[comm->head].optype = op;
comm->fifo[comm->head].basecounter = comm->basecounter[op];
comm->fifo[comm->head].blocksize = blocksize;
comm->fifo[comm->head].maxcredit = maxcredit;
comm->fifo[comm->head].handler = handler;
comm->fifo[comm->head].offset = offset;
comm->fifo[comm->head].elements = elements;
int newhead = (comm->head + 1) & (NVTE_MAX_REQUESTS - 1);
while (newhead == comm->tail) {
}
comm->head = newhead;
comm->basecounter[op] += (elements * 2 + blocksize - 1) / blocksize;
}
}
void allgather_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
if (elements < 64) NVTE_ERROR("Userbuffer comm for given config not implemented.");
int op = userbuffers_allreduceop_nonsharp;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
int blocksize = elements * 2;
int maxcredit = 0;
const int num_nodes = op == userbuffers_allreduceop_nonsharp ? comm->num_nodes : comm->num2_nodes;
blocksize = (comm->nblocks - 1 + (comm->alignblock - 1 + elements * 2) / comm->alignblock) /
comm->nblocks; // FIXME TUNING
blocksize *= comm->alignblock;
if (blocksize < comm->minblock) blocksize = comm->minblock;
maxcredit = (elements * 2 + blocksize - 1) / blocksize;
size_t peerblock = sizeof(int) * NVTE_REG0_COMMBUFFER / maxcredit; // max size we can fit
if (blocksize > peerblock * ar_nvsize) blocksize = peerblock * ar_nvsize;
int sms = allgather2_userbuff_inplace_gpu(maxcredit, handler, offset, elements, blocksize, comm,
stream, op);
}
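A minimal host-side usage sketch of the collectives above (illustrative only, not part of the
diff): it assumes MPI_Init has already been called by the caller, that `stream` is an existing
cudaStream_t, and that the buffer size and element count are placeholders.

  // Create a communicator over all ranks, register a buffer, run an in-place fp16 allreduce.
  communicator *comm = nullptr;
  if (create_communicator(&comm) != 0) { /* handle failure */ }

  void *buf = nullptr;                    // allocated inside register_user_buffer_collective
  const size_t bytes = 64 * 1024 * 1024;
  const int handler = register_user_buffer_collective(&buf, bytes, comm, /*alloc=*/true);

  const int elements = bytes / 2;         // number of fp16 elements in the registered region
  allreduce_userbuff_inplace(handler, /*offset=*/0, elements, comm, stream);

  destroy_communicator(comm);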
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cuda_runtime.h>
#if __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
#define half nv_bfloat16
#else
#include <cuda_fp16.h>
#endif
#include <assert.h>
#include <stdio.h>
#include <transformer_engine/userbuffers.h>
#define MAX_THREADS 1024
#define TIMEOUT 200000000000ull
#define CUDACHECK(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
exit(EXIT_FAILURE); \
} \
} while (0)
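// Read-write ("rw") allreduce: each of the first RANKS threads handshakes with one peer via
// flags; all threads then load each 16-byte line from every rank, sum it as packed halves,
// and write the result back into every rank's buffer, so the data is fully reduced in place.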
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_rw(const int op, const int flagoffset, const int firstrank,
const int myrank, const int gpustep, const int lineoffset,
const int numlines, void **commbuff, const int handleridx) {
__shared__ int4 *userptr[RANKS];
int *flagptr, physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
// if(blockIdx.x==0 && threadIdx.x==0) printf("%d/%d(phys %d gpustep %d firstrank %d):RRkernel(d)
// start, size %lld\n",myrank,RANKS,gpustep*myrank+firstrank,gpustep,firstrank,numlines*16ull);
if (threadIdx.x < RANKS) {
physgpu = myrank * gpustep + firstrank;
targetgpu = threadIdx.x * gpustep + firstrank;
const int blockflagoffset = NVTE_MAX_NVLINK * 2 * blockIdx.x;
myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - NVTE_MAX_OPS; // +op;
reduce_id = (*reduceidptr) + 1;
flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
myptr += blockflagoffset;
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64();
while (*flag < reduce_id) {
if (clock64() - s > TIMEOUT) {
printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag);
break;
}
}
reduce_id++;
}
__syncthreads();
int warp = blockIdx.x + (threadIdx.x >> 5);
int dest[RANKS];
#pragma unroll
for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1);
__syncthreads();
for (int line = threadIdx.x + blockDim.x * (myrank + RANKS * blockIdx.x); line < numlines;
line += blockDim.x * gridDim.x * RANKS) {
int4 val[RANKS];
#pragma unroll
for (int i = 0; i < RANKS; i++) {
// int dest = (i+myrank+warp)&(RANKS-1);
val[i] = userptr[dest[i]][lineoffset + line];
}
int4 sum = val[0];
half *s = reinterpret_cast<half *>(&sum);
#pragma unroll
for (int i = 1; i < RANKS; i++) {
half *x = reinterpret_cast<half *>(&val[i]);
#pragma unroll
for (int j = 0; j < 8; j++) s[j] += x[j];
}
#pragma unroll
for (int i = 0; i < RANKS; i++) {
// int dest = (i+myrank+warp)&(RANKS-1);
userptr[dest[i]][lineoffset + line] = sum;
}
}
__syncthreads();
if (threadIdx.x == 0) __threadfence_system();
__syncthreads();
if (threadIdx.x < RANKS) {
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&myptr[targetgpu];
clock_t s = clock64();
while (*flag < reduce_id) {
if (clock64() - s > 2ull * TIMEOUT) {
printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag);
break;
}
}
}
if (threadIdx.x == 0 && blockIdx.x == 0) *reduceidptr = reduce_id;
} // fp16 inplace reduce kernel (Volta,Hopper)
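// Read-reduce ("rr") allreduce: each rank reduces its assigned lines into its own buffer,
// waits on peer flags, then gathers the shards reduced by the other ranks with plain reads.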
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_rr(const int op, const int flagoffset, const int firstrank,
const int myrank, const int gpustep, const int lineoffset,
const int numlines, void **commbuff, const int handleridx) {
__shared__ int4 *userptr[RANKS];
int *flagptr, physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
if (threadIdx.x < RANKS) {
physgpu = myrank * gpustep + firstrank;
targetgpu = threadIdx.x * gpustep + firstrank;
const int blockflagoffset = NVTE_MAX_NVLINK * 2 * blockIdx.x;
myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - NVTE_MAX_OPS; // +op;
reduce_id = (*reduceidptr) + 1;
flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
myptr += blockflagoffset;
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64();
while (*flag < reduce_id) {
if (clock64() - s > TIMEOUT) {
printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag);
break;
}
}
reduce_id++;
}
__syncthreads();
int warp = blockIdx.x + (threadIdx.x >> 5);
int dest[RANKS];
#pragma unroll
for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1);
__syncthreads();
for (int line = threadIdx.x + blockDim.x * (myrank + RANKS * blockIdx.x); line < numlines;
line += blockDim.x * gridDim.x * RANKS) {
int4 val[RANKS];
#pragma unroll
for (int i = 0; i < RANKS; i++) {
val[i] = userptr[dest[i]][lineoffset + line];
}
int4 sum = val[0];
half *s = reinterpret_cast<half *>(&sum);
#pragma unroll
for (int i = 1; i < RANKS; i++) {
half *x = reinterpret_cast<half *>(&val[i]);
#pragma unroll
for (int j = 0; j < 8; j++) s[j] += x[j];
}
userptr[myrank][lineoffset + line] = sum;
}
__syncthreads();
if (threadIdx.x == 0) __threadfence();
__syncthreads();
if (threadIdx.x < RANKS) {
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&myptr[targetgpu];
clock_t s = clock64();
while (*flag < reduce_id) {
if (clock64() - s > 2ull * TIMEOUT) {
printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag);
break;
}
}
}
int skipmy = 0;
#pragma unroll
for (int i = 0; i < RANKS; i++) {
int dst = (i + warp + myrank) & (RANKS - 1);
if (dst == myrank) {
skipmy++;
continue;
}
dest[i - skipmy] = dst;
}
__syncthreads();
for (int line = threadIdx.x + blockDim.x * RANKS * blockIdx.x; line < numlines;
line += blockDim.x * gridDim.x * RANKS) {
int4 val[RANKS - 1];
#pragma unroll
for (int i = 0; i < RANKS - 1; i++) {
val[i] = userptr[dest[i]][lineoffset + line + blockDim.x * dest[i]];
}
#pragma unroll
for (int i = 0; i < RANKS - 1; i++) {
userptr[myrank][lineoffset + line + blockDim.x * dest[i]] = val[i];
}
}
if (threadIdx.x == 0 && blockIdx.x == 0) *reduceidptr = reduce_id;
} // fp16 inplace reduce kernel (Ampere)
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_rr_rs(const int op, const int flagoffset, const int firstrank,
const int myrank, const int gpustep,
const int mylineoffset, const int totallines,
void **commbuff, const int handleridx) {
__shared__ int4 *userptr[RANKS];
int *flagptr, physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
if (threadIdx.x < RANKS) {
physgpu = myrank * gpustep + firstrank;
targetgpu = threadIdx.x * gpustep + firstrank;
const int blockflagoffset = NVTE_MAX_NVLINK * 2 * blockIdx.x;
myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - NVTE_MAX_OPS; // +op;
reduce_id = (*reduceidptr) + 1;
flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
myptr += blockflagoffset;
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64();
while (*flag < reduce_id) {
if (clock64() - s > TIMEOUT) {
printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag);
break;
}
}
}
__syncthreads();
int warp = blockIdx.x + (threadIdx.x >> 5);
int dest[RANKS];
#pragma unroll
for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1);
__syncthreads();
for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines;
line += blockDim.x * gridDim.x) {
int4 val[RANKS];
#pragma unroll
for (int i = 0; i < RANKS; i++) {
val[i] = userptr[dest[i]][mylineoffset + line];
}
int4 sum = val[0];
half *s = reinterpret_cast<half *>(&sum);
#pragma unroll
for (int i = 1; i < RANKS; i++) {
half *x = reinterpret_cast<half *>(&val[i]);
#pragma unroll
for (int j = 0; j < 8; j++) s[j] += x[j];
}
userptr[myrank][mylineoffset + line] = sum;
}
if (threadIdx.x == 0 && blockIdx.x == 0) *reduceidptr = reduce_id;
} // fp16 inplace reduce-scatter kernel
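// Out-of-place variant below: same reduce-scatter, but the reduced lines are written to
// `outbuf` with a row stride of `skiplines` instead of back into the registered buffer.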
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_rr_rs_oop(const int op, const int flagoffset,
const int firstrank, const int myrank,
const int gpustep, const int mylineoffset,
const int totallines, const int rowlines,
const int skiplines, void **commbuff,
const int handleridx, void *outbuf) {
__shared__ int4 *userptr[RANKS];
int *flagptr, physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
if (threadIdx.x < RANKS) {
physgpu = myrank * gpustep + firstrank;
targetgpu = threadIdx.x * gpustep + firstrank;
const int blockflagoffset = NVTE_MAX_NVLINK * 2 * blockIdx.x;
myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - NVTE_MAX_OPS; // +op;
reduce_id = (*reduceidptr) + 1;
flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
myptr += blockflagoffset;
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64();
while (*flag < reduce_id) {
if (clock64() - s > TIMEOUT) {
printf("NVONLY RSBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag);
break;
}
}
}
__syncthreads();
int warp = blockIdx.x + (threadIdx.x >> 5);
int dest[RANKS];
#pragma unroll
for (int i = 0; i < RANKS; i++) dest[i] = (i + myrank + warp) & (RANKS - 1);
__syncthreads();
for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines;
line += blockDim.x * gridDim.x) {
int4 val[RANKS];
#pragma unroll
for (int i = 0; i < RANKS; i++) {
val[i] = userptr[dest[i]][mylineoffset + line];
}
int4 sum = val[0];
half *s = reinterpret_cast<half *>(&sum);
#pragma unroll
for (int i = 1; i < RANKS; i++) {
half *x = reinterpret_cast<half *>(&val[i]);
#pragma unroll
for (int j = 0; j < 8; j++) s[j] += x[j];
}
(reinterpret_cast<int4 *>(outbuf))[(line / rowlines) * skiplines + (line % rowlines)] = sum;
}
if (threadIdx.x == 0 && blockIdx.x == 0) *reduceidptr = reduce_id;
} // fp16 reduce-scatter kernel (out of place)
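// Allgather ("rr_ag") variant: no reduction; after the flag handshake each rank pulls the
// other ranks' shards into its own buffer.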
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_rr_ag(const int op, const int flagoffset, const int firstrank,
const int myrank, const int gpustep,
const int mylineoffset, const int totallines,
void **commbuff, const int handleridx) {
__shared__ int4 *userptr[RANKS];
int *flagptr, physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
if (threadIdx.x < RANKS) {
physgpu = myrank * gpustep + firstrank;
targetgpu = threadIdx.x * gpustep + firstrank;
const int blockflagoffset = NVTE_MAX_NVLINK * 2 * blockIdx.x;
myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - NVTE_MAX_OPS; // +op;
reduce_id = (*reduceidptr) + 1;
flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
myptr += blockflagoffset;
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&(myptr[targetgpu]);
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
clock_t s = clock64();
}
int warp = blockIdx.x + (threadIdx.x >> 5);
int dest[RANKS];
int skipmy = 0;
#pragma unroll
for (int i = 0; i < RANKS; i++) {
int dst = (i + warp + myrank) & (RANKS - 1);
if (dst == myrank) {
skipmy++;
continue;
}
dest[i - skipmy] = dst;
}
__syncthreads();
for (int line = threadIdx.x + blockDim.x * blockIdx.x; line < totallines;
line += blockDim.x * gridDim.x) {
int4 val[RANKS - 1];
#pragma unroll
for (int i = 0; i < RANKS - 1; i++) {
val[i] = userptr[dest[i]][mylineoffset + line + totallines * dest[i]];
}
#pragma unroll
for (int i = 0; i < RANKS - 1; i++) {
userptr[myrank][mylineoffset + line + totallines * dest[i]] = val[i];
}
}
if (threadIdx.x == 0 && blockIdx.x == 0) *reduceidptr = reduce_id;
}  // fp16 inplace allgather kernel (Ampere)
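// Push ("rw") allgather: each rank writes its own shard into every peer's buffer with an
// unrolled copy loop, then completes the flag handshake.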
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_rw_ag(const int op, const int flagoffset, const int firstrank,
const int myrank, const int gpustep,
const int mylineoffset, const int totallines,
void **commbuff, const int handleridx) {
__shared__ int4 *userptr[RANKS];
int *flagptr, physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
int4 *localptr;
if (threadIdx.x < RANKS) {
physgpu = myrank * gpustep + firstrank;
targetgpu = threadIdx.x * gpustep + firstrank;
const int blockflagoffset = NVTE_MAX_NVLINK * 2 * blockIdx.x;
myptr = (reinterpret_cast<int *>(commbuff[physgpu])) + flagoffset;
reduceidptr = myptr - NVTE_MAX_OPS; // +op;
reduce_id = (*reduceidptr) + 1;
flagptr = (reinterpret_cast<int *>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
myptr += blockflagoffset;
userptr[threadIdx.x] = reinterpret_cast<int4 *>(commbuff[targetgpu + handleridx]);
reduce_id++;
}
__syncthreads();
localptr = userptr[myrank];
int warp = blockIdx.x + (threadIdx.x >> 5);
int dest[RANKS - 1];
int skipmy = 0;
#pragma unroll
for (int i = 0; i < RANKS; i++) {
int dst = (i + warp + myrank) & (RANKS - 1);
if (dst == myrank) {
skipmy++;
continue;
}
dest[i - skipmy] = dst;
}
#define UNROLLAG 4
__syncthreads();
const int loop_step0 = blockDim.x * gridDim.x;
const int loop_step = loop_step0 * UNROLLAG;
const int start_elem = threadIdx.x + blockDim.x * blockIdx.x;
const int end_elem = max(start_elem, totallines);
const int aligned_elem = ((end_elem - start_elem) / loop_step) * loop_step;
const int end_aligned = start_elem + aligned_elem;
for (int line = start_elem; line < end_aligned; line += loop_step) {
int4 val[UNROLLAG];
#pragma unroll
for (int j = 0; j < UNROLLAG; j++) val[j] = localptr[mylineoffset + line + loop_step0 * j];
#pragma unroll
for (int j = 0; j < UNROLLAG; j++)
#pragma unroll
for (int i = 0; i < RANKS - 1; i++) {
userptr[dest[i]][mylineoffset + line + j * loop_step0] = val[j];
}
}
for (int line = end_aligned; line < end_elem; line += loop_step0) {
int4 sum = localptr[mylineoffset + line];
#pragma unroll
for (int i = 0; i < RANKS - 1; i++) {
userptr[dest[i]][mylineoffset + line] = sum;
}
}
__syncthreads();
if (threadIdx.x == 0) __threadfence_system();
__syncthreads();
if (threadIdx.x < RANKS) {
flagptr[physgpu] = reduce_id;
volatile int *flag = (volatile int *)&myptr[targetgpu];
clock_t s = clock64();
while (*flag < reduce_id) {
if (clock64() - s > 2ull * TIMEOUT) {
printf("NVONLY AGBAR:SM %d [%d]:expecting %d got %d\n", blockIdx.x, threadIdx.x, reduce_id,
*flag);
break;
}
}
}
if (threadIdx.x == 0 && blockIdx.x == 0) *reduceidptr = reduce_id;
} // fp16 inplace allgather kernel (Volta,Hopper)
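// Blocked multi-node allreduce: the first warp of each CTA drives flag/SM synchronization and
// signals the CPU proxy through hostflags, while the remaining warps perform the NVLink
// reduce-scatter and, once the internode step completes, the NVLink allgather, one block of
// peerblocklines at a time.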
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_rr_blocked(const int op, const int flagoffset,
const int firstrank, const int myrank,
const int lineoffset, const int numlines,
void **commbuff, const int handleridx,
const int peerblocklines, int *hostflags,
int *gpuflag, const int numblocks) {
const int basecounter = gpuflag[NVTE_GF_STATE + op];
#define REDUCETHREADS (blockDim.x - 32)
if (threadIdx.x < 32) {
int *flagptr;
if (threadIdx.x < RANKS) {
if (!blockIdx.x) {
flagptr = reinterpret_cast<int *>(commbuff[threadIdx.x + firstrank]);
flagptr[flagoffset + myrank + firstrank] = basecounter;
}
volatile int *flag = (volatile int *)&((reinterpret_cast<int *>(
commbuff[myrank + firstrank]))[flagoffset + threadIdx.x + firstrank]);
while (*flag < basecounter) {
}
}
__syncthreads();
int startblock = 0, endblock = numblocks;
for (int nblock = 0; nblock < endblock; nblock++) {
asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32));
if (threadIdx.x == 0) {
__threadfence();
if (blockIdx.x) gpuflag[op * NVTE_MAX_SMS * 2 + blockIdx.x] = nblock + basecounter + 1;
} else if (blockIdx.x == 0) {
int expecting = (basecounter + nblock + 1);
if (threadIdx.x < gridDim.x)
while (((volatile int *)gpuflag)[op * NVTE_MAX_SMS * 2 + threadIdx.x] < expecting) {
}
}
if (!blockIdx.x) {
asm volatile("bar.sync 15, %0;" ::"r"(32));
if (!threadIdx.x) hostflags[0] = nblock + basecounter + 1;
}
}
int cachedflag = basecounter;
#define ALLGATHERFLAG NVTE_GF_IBSHARPDONE
if (blockIdx.x == 0 && threadIdx.x < RANKS) {
while (cachedflag < basecounter + numblocks) {
int newflag = ((volatile int *)gpuflag)[ALLGATHERFLAG];
if (newflag == cachedflag) continue;
cachedflag = newflag;
flagptr[flagoffset + myrank + 32 + firstrank] = cachedflag;
}
}
if (blockIdx.x == 0 && threadIdx.x == 0) gpuflag[NVTE_GF_STATE + op] = basecounter + numblocks;
} else {
const int warp = blockIdx.x + (threadIdx.x >> 5);
int4 *userptr[RANKS];
int4 *userptrmyrank;
#pragma unroll
for (int i = 0; i < RANKS; i++)
userptr[i] = reinterpret_cast<int4 *>(
commbuff[((i + myrank + warp) & (RANKS - 1)) + handleridx + firstrank]);
userptrmyrank = reinterpret_cast<int4 *>(commbuff[myrank + handleridx + firstrank]);
__syncthreads();
int blocklineoffset = 0;
while (blocklineoffset < numlines) {
const int remainder = min(numlines - blocklineoffset, peerblocklines * RANKS);
const int blocklines = remainder / RANKS;
const int blockstart = lineoffset + blocklineoffset + blocklines * myrank;
for (int line = threadIdx.x - 32 + REDUCETHREADS * blockIdx.x; line < blocklines;
line += REDUCETHREADS * gridDim.x) {
int4 val[RANKS];
#pragma unroll
for (int i = 0; i < RANKS; i++) {
val[i] = userptr[i][blockstart + line];
}
int4 sum = val[0];
half *s = reinterpret_cast<half *>(&sum);
#pragma unroll
for (int i = 1; i < RANKS; i++) {
half *x = reinterpret_cast<half *>(&val[i]);
#pragma unroll
for (int j = 0; j < sizeof(int4) / sizeof(half); j++) s[j] += x[j];
}
userptrmyrank[blockstart + line] = sum;
} // single block loop
asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32));
blocklineoffset += peerblocklines * RANKS;
} // block loop NVLINK-REDUCESCATTER
const int nwarps = (REDUCETHREADS >> 5) / (RANKS - 1);
const int myblockDim = nwarps << 5;
const int mywarp = ((threadIdx.x - 32) >> 5) / (RANKS - 1);
const int maxthreadIdx = myblockDim * (RANKS - 1) + 32;
const int mydest = (myrank + 1 + ((threadIdx.x - 32) >> 5) % (RANKS - 1)) & (RANKS - 1);
const int mythreadIdx = (mywarp << 5) + (threadIdx.x & 31);
volatile int *flag = (volatile int *)&((reinterpret_cast<int *>(
commbuff[myrank + firstrank]))[flagoffset + mydest + 32 + firstrank]);
int4 *userptrmydest = userptr[((RANKS << 10) + mydest - myrank - warp) & (RANKS - 1)];
blocklineoffset = 0;
int gathercounter = basecounter + 1;
while (blocklineoffset < numlines) {
const int remainder = min(numlines - blocklineoffset, peerblocklines * RANKS);
const int blocklines = remainder / RANKS;
const int blockstart = lineoffset + blocklineoffset;
#define UNROLL 6
int4 *myptr = &userptrmyrank[blockstart + blocklines * mydest];
int4 *peerptr = &userptrmydest[blockstart + blocklines * mydest];
if (threadIdx.x < maxthreadIdx) {
const int start_elem = mythreadIdx + myblockDim * blockIdx.x;
const int end_elem = max(start_elem, blocklines);
const int aligned_elem = ((end_elem - start_elem) / (myblockDim * gridDim.x * UNROLL)) *
(myblockDim * gridDim.x * UNROLL);
const int end_aligned = start_elem + aligned_elem;
if (mythreadIdx == 0) {
while (*flag < gathercounter) {
}
gathercounter++;
}
asm volatile("bar.sync %0, %1;" ::"r"(1 + mydest), "r"(myblockDim));
for (int line = start_elem; line < end_aligned; line += myblockDim * gridDim.x * UNROLL) {
int4 val[UNROLL];
#pragma unroll
for (int i = 0; i < UNROLL; i++) val[i] = peerptr[line + i * myblockDim * gridDim.x];
#pragma unroll
for (int i = 0; i < UNROLL; i++) myptr[line + i * myblockDim * gridDim.x] = val[i];
}
for (int line = end_aligned; line < end_elem; line += myblockDim * gridDim.x)
myptr[line] = peerptr[line];
}
blocklineoffset += peerblocklines * RANKS;
} // block loop for NVLINK-ALLGATHER
} // worker warps else block
} // fp16 inplace reduce kernel with SHARP / in blocks
// threadfence and SMs sync to SM0
#define SMBAR(offset, block) \
asm volatile("bar.sync 13, %0;" ::"r"(blockDim.x)); \
if (threadIdx.x == 0) { \
__threadfence_system(); \
if (blockIdx.x) gpuflag[offset + blockIdx.x] = block + basecounter + 1; \
} else if (blockIdx.x == 0) { \
int expecting = (basecounter + block + 1); \
if (threadIdx.x < gridDim.x) \
while (((volatile int *)gpuflag)[offset + threadIdx.x] < expecting) { \
} \
} \
if (blockIdx.x == 0) asm volatile("bar.sync 15, %0;" ::"r"(32));
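// Two-level blocked allreduce: per CTA, the sync warp runs the intranode barrier, signals
// NVRSDONE/NVREDUCEDONE to the CPU proxy, and polls the IB reduce-scatter/allgather flags,
// while the worker warps do the NVLink reduce-scatter, the local reduction of data received
// over IB, and finally the NVLink allgather.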
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rr_blocked2(
const int op, const int maxcredit, const int headstart, const int myibrank, const int ibranks,
const int commbufoffset, const int flagoffset, const int firstrank, const int myrank,
const int gpustep, const int lineoffset, const int numlines, void **commbuff,
const int handleridx, const int peerblocklines, int *hostflags, int *gpuflag,
const int numblocks) {
const int basecounter = gpuflag[NVTE_GF_STATE + op];
if (threadIdx.x < 32) {
int *flagptr;
volatile int *localflag = (volatile int *)&(
((int *)commbuff[gpustep * myrank + firstrank])[flagoffset]); // NOLINT(*)
// initial intranode barrier - once
if (threadIdx.x < RANKS) {
if (!blockIdx.x) {
flagptr = reinterpret_cast<int *>(commbuff[gpustep * threadIdx.x + firstrank]);
flagptr[flagoffset + gpustep * myrank + firstrank] = basecounter;
}
volatile int *flag = &localflag[gpustep * threadIdx.x + firstrank];
while (*flag < basecounter) {
}
}
__syncthreads();
for (int nblock = 0; nblock < numblocks + headstart; nblock++) {
if (nblock < numblocks) {
// RS happens here
SMBAR(op * 2 * NVTE_MAX_SMS, nblock);
if (!blockIdx.x && !threadIdx.x)
hostflags[NVTE_HF_NVRSDONE + (op & 1)] = nblock + basecounter + 1;
}
if (nblock >= headstart) {
for (int ibflag = threadIdx.x; ibflag < ibranks; ibflag += 32)
if (ibflag != myibrank)
while (localflag[NVTE_REG0_IBRS + ibflag] < basecounter + nblock - headstart + 1) {
}
asm volatile("bar.sync 13, %0;" ::"r"(blockDim.x));
// REDUCE happens here
SMBAR(op * 2 * NVTE_MAX_SMS + NVTE_MAX_SMS, nblock - headstart);
if (!blockIdx.x && !threadIdx.x)
hostflags[NVTE_HF_NVREDUCEDONE + (op & 1)] = nblock + basecounter + 1 - headstart;
}
}
// final part doing NVAG based on responses from NIC-RMW:IBAG
if (blockIdx.x == 0) {
for (int nblock = 0; nblock < numblocks; nblock++) {
const int expected = basecounter + nblock + 1;
for (int ibflag = threadIdx.x; ibflag < ibranks; ibflag += 32)
if (ibflag != myibrank)
while (localflag[NVTE_REG0_IBAG + ibflag] < expected) {
}
asm volatile("bar.sync 15, %0;" ::"r"(32));
if (threadIdx.x < RANKS)
flagptr[flagoffset + gpustep * myrank + NVTE_MAX_NVLINK + firstrank] = expected;
}
}
if (blockIdx.x == 0 && threadIdx.x == 0) gpuflag[NVTE_GF_STATE + op] = basecounter + numblocks;
} else { // sync warp
// reducethreads
const int warp = blockIdx.x + (threadIdx.x >> 5);
int4 *userptr[RANKS];
int4 *userptrmyrank;
#pragma unroll
for (int i = 0; i < RANKS; i++)
userptr[i] = reinterpret_cast<int4 *>(
commbuff[((i + myrank + warp) & (RANKS - 1)) * gpustep + handleridx + firstrank]);
userptrmyrank = reinterpret_cast<int4 *>(commbuff[gpustep * myrank + handleridx + firstrank]);
int4 *internalbuf = reinterpret_cast<int4 *>(commbuff[myrank * gpustep + firstrank] +
commbufoffset * sizeof(int));
__syncthreads();
int blocklineoffset = 0, rblocklineoffset = 0;
for (int nblock = 0; nblock < numblocks + headstart; nblock++) {
// NVRS part(only first numblocks steps)
if (blocklineoffset < numlines) {
const int remainder = min(numlines - blocklineoffset, peerblocklines * RANKS);
const int blocklines = remainder / RANKS;
const int blockstart = lineoffset + blocklineoffset + blocklines * myrank;
if (RANKS > 1) {
for (int line = threadIdx.x - 32 + REDUCETHREADS * blockIdx.x; line < blocklines;
line += REDUCETHREADS * gridDim.x) {
int4 val[RANKS];
#pragma unroll
for (int i = 0; i < RANKS; i++) {
val[i] = userptr[i][blockstart + line];
}
int4 sum = val[0];
half *s = reinterpret_cast<half *>(&sum);
#pragma unroll
for (int i = 1; i < RANKS; i++) {
half *x = reinterpret_cast<half *>(&val[i]);
#pragma unroll
for (int j = 0; j < sizeof(int4) / sizeof(half); j++) s[j] += x[j];
}
userptrmyrank[blockstart + line] = sum;
} // single block loop
}
asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32));
blocklineoffset += peerblocklines * RANKS;
}
if (nblock >= headstart) {
#define UNROLLRS 2
const int remainder = min(numlines - rblocklineoffset, peerblocklines * RANKS);
const int blocklines = remainder / RANKS;
rblocklineoffset += peerblocklines * RANKS;
const int ibblocklines = blocklines / ibranks;
int4 *tempbufptr = &internalbuf[((nblock - headstart) % maxcredit) * peerblocklines];
const int tempstart = lineoffset + (nblock - headstart) * peerblocklines * RANKS +
myrank * blocklines + ibblocklines * myibrank;
asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32));
for (int line = threadIdx.x - 32 + REDUCETHREADS * blockIdx.x; line < ibblocklines;
line += REDUCETHREADS * gridDim.x) {
int4 val[UNROLLRS];
#pragma unroll
for (int i = 0; i < UNROLLRS; i++)
val[i] = i == myibrank ? userptrmyrank[tempstart + line]
: tempbufptr[i * ibblocklines + line];
int4 sum = val[0];
half *s = reinterpret_cast<half *>(&sum);
for (int i = 0; i < ibranks - UNROLLRS; i++) {
val[i % UNROLLRS] = i == myibrank ? userptrmyrank[tempstart + line]
: tempbufptr[i * ibblocklines + line];
half *x = reinterpret_cast<half *>(&val[(i + 1) % UNROLLRS]);
#pragma unroll
for (int j = 0; j < 16; j++) s[j] += x[j];
}
#pragma unroll
for (int i = 1; i < UNROLLRS; i++) {
half *x = reinterpret_cast<half *>(&val[i]);
#pragma unroll
for (int j = 0; j < 16; j++) s[j] += x[j];
}
userptrmyrank[tempstart + line] = sum;
}
asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32));
}
} // nblock loop NVLINK-REDUCESCATTER + IBREDUCE LOCAL COMPUTE
if (RANKS != 1) {
const int nwarps = (REDUCETHREADS >> 5) / (RANKS - 1);
const int myblockDim = nwarps << 5;
const int mywarp = ((threadIdx.x - 32) >> 5) / (RANKS - 1);
const int maxthreadIdx = myblockDim * (RANKS - 1) + 32;
const int mydest = (myrank + 1 + ((threadIdx.x - 32) >> 5) % (RANKS - 1)) & (RANKS - 1);
const int mythreadIdx = (mywarp << 5) + (threadIdx.x & 31);
volatile int *flag = (volatile int *)&((reinterpret_cast<int *>(
commbuff[gpustep * myrank + firstrank]))[flagoffset + gpustep * mydest + NVTE_MAX_NVLINK +
firstrank]);
int4 *userptrmydest = userptr[((RANKS << 10) + mydest - myrank - warp) & (RANKS - 1)];
blocklineoffset = 0;
int gathercounter = basecounter + 1;
while (blocklineoffset < numlines) {
const int remainder = min(numlines - blocklineoffset, peerblocklines * RANKS);
const int blocklines = remainder / RANKS;
const int blockstart = lineoffset + blocklineoffset;
#define UNROLL 6
int4 *myptr = &userptrmyrank[blockstart + blocklines * mydest];
int4 *peerptr = &userptrmydest[blockstart + blocklines * mydest];
if (threadIdx.x < maxthreadIdx) {
const int start_elem = mythreadIdx + myblockDim * blockIdx.x;
const int end_elem = max(start_elem, blocklines);
const int aligned_elem = ((end_elem - start_elem) / (myblockDim * gridDim.x * UNROLL)) *
(myblockDim * gridDim.x * UNROLL);
const int end_aligned = start_elem + aligned_elem;
if (mythreadIdx == 0) {
while (*flag < gathercounter) {
}
gathercounter++;
}
asm volatile("bar.sync %0, %1;" ::"r"(1 + mydest), "r"(myblockDim));
for (int line = start_elem; line < end_aligned; line += myblockDim * gridDim.x * UNROLL) {
int4 val[UNROLL];
#pragma unroll
for (int i = 0; i < UNROLL; i++) val[i] = peerptr[line + i * myblockDim * gridDim.x];
#pragma unroll
for (int i = 0; i < UNROLL; i++) myptr[line + i * myblockDim * gridDim.x] = val[i];
}
for (int line = end_aligned; line < end_elem; line += myblockDim * gridDim.x)
myptr[line] = peerptr[line];
}
blocklineoffset += peerblocklines * RANKS;
} // block loop for NVLINK-ALLGATHER
} // RANKS!=1
} // worker warps else block
} // fp16 inplace reduce kernel with SHARP / in blocks
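// Same blocked pipeline as above but reduce-scatter only: the sync warp drives the RS/IB credit
// flags, and the worker warps perform the NVLink reduce-scatter plus the local IB-reduce compute,
// with no all-gather phase.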
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rr_blocked2_rs(
const int op, const int maxcredit, const int headstart, const int myibrank, const int ibranks,
const int commbufoffset, const int flagoffset, const int firstrank, const int myrank,
const int gpustep, const int lineoffset, const int numlines, void **commbuff,
const int handleridx, const int peerblocklines, int *hostflags, int *gpuflag,
const int numblocks) {
const int basecounter = gpuflag[NVTE_GF_STATE + op];
if (threadIdx.x < 32) {
int *flagptr;
volatile int *localflag = (volatile int *)&(
((int *)commbuff[gpustep * myrank + firstrank])[flagoffset]); // NOLINT(*)
// initial intranode barrier - once
if (threadIdx.x < RANKS) {
if (!blockIdx.x) {
flagptr = reinterpret_cast<int *>(commbuff[gpustep * threadIdx.x + firstrank]);
flagptr[flagoffset + gpustep * myrank + firstrank] = basecounter;
}
volatile int *flag = &localflag[gpustep * threadIdx.x + firstrank];
while (*flag < basecounter) {
}
}
__syncthreads();
for (int nblock = 0; nblock < numblocks + headstart; nblock++) {
if (nblock < numblocks) {
// RS happens here
SMBAR(op * 2 * NVTE_MAX_SMS, nblock);
if (!blockIdx.x && !threadIdx.x)
hostflags[NVTE_HF_NVRSDONE + (op & 1)] = nblock + basecounter + 1;
}
if (nblock >= headstart) {
for (int ibflag = threadIdx.x; ibflag < ibranks; ibflag += 32)
if (ibflag != myibrank)
while (localflag[NVTE_REG0_IBRS + ibflag] < basecounter + nblock - headstart + 1) {
}
asm volatile("bar.sync 13, %0;" ::"r"(blockDim.x));
// REDUCE happens here
SMBAR(op * 2 * NVTE_MAX_SMS + NVTE_MAX_SMS, nblock - headstart);
}
}
} else { // sync warp
// reducethreads
const int warp = blockIdx.x + (threadIdx.x >> 5);
int4 *userptr[RANKS];
int4 *userptrmyrank;
#pragma unroll
for (int i = 0; i < RANKS; i++)
userptr[i] = reinterpret_cast<int4 *>(
commbuff[((i + myrank + warp) & (RANKS - 1)) * gpustep + handleridx + firstrank]);
userptrmyrank = reinterpret_cast<int4 *>(commbuff[gpustep * myrank + handleridx + firstrank]);
int4 *internalbuf = reinterpret_cast<int4 *>(commbuff[myrank * gpustep + firstrank] +
commbufoffset * sizeof(int));
__syncthreads();
int blocklineoffset = 0, rblocklineoffset = 0;
for (int nblock = 0; nblock < numblocks + headstart; nblock++) {
// NVRS part(only first numblocks steps)
if (blocklineoffset < numlines) {
const int remainder = min(numlines - blocklineoffset, peerblocklines * RANKS);
const int blocklines = remainder / RANKS;
const int blockstart = lineoffset + blocklineoffset + blocklines * myrank;
if (RANKS > 1) {
for (int line = threadIdx.x - 32 + REDUCETHREADS * blockIdx.x; line < blocklines;
line += REDUCETHREADS * gridDim.x) {
int4 val[RANKS];
#pragma unroll
for (int i = 0; i < RANKS; i++) {
val[i] = userptr[i][blockstart + line];
}
int4 sum = val[0];
half *s = reinterpret_cast<half *>(&sum);
#pragma unroll
for (int i = 1; i < RANKS; i++) {
half *x = reinterpret_cast<half *>(&val[i]);
#pragma unroll
for (int j = 0; j < sizeof(int4) / sizeof(half); j++) s[j] += x[j];
}
userptrmyrank[blockstart + line] = sum;
} // single block loop
}
asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32));
blocklineoffset += peerblocklines * RANKS;
}
if (nblock >= headstart) {
#define UNROLLRS 2
const int remainder = min(numlines - rblocklineoffset, peerblocklines * RANKS);
const int blocklines = remainder / RANKS;
rblocklineoffset += peerblocklines * RANKS;
const int ibblocklines = blocklines / ibranks;
int4 *tempbufptr = &internalbuf[((nblock - headstart) % maxcredit) * peerblocklines];
const int tempstart = lineoffset + (nblock - headstart) * peerblocklines * RANKS +
myrank * blocklines + ibblocklines * myibrank;
        // debug: if (threadIdx.x == 32)
        //   printf("[%d] block%d thread %d offset %d line %d ibblocklines %d ptr %lx commbufoffset %d\n",
        //          myrank, blockIdx.x, threadIdx.x, tempstart, 0, ibblocklines,
        //          (void *)&tempbufptr[(1 - myibrank) * ibblocklines], (1 - myibrank) * ibblocklines * 16);
asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32));
for (int line = threadIdx.x - 32 + REDUCETHREADS * blockIdx.x; line < ibblocklines;
line += REDUCETHREADS * gridDim.x) {
int4 val[UNROLLRS];
#pragma unroll
for (int i = 0; i < UNROLLRS; i++)
val[i] = i == myibrank ? userptrmyrank[tempstart + line]
: tempbufptr[i * ibblocklines + line];
int4 sum = val[0];
half *s = reinterpret_cast<half *>(&sum);
for (int i = 0; i < ibranks - UNROLLRS; i++) {
val[i % UNROLLRS] = i == myibrank ? userptrmyrank[tempstart + line]
: tempbufptr[i * ibblocklines + line];
half *x = reinterpret_cast<half *>(&val[(i + 1) % UNROLLRS]);
#pragma unroll
            for (int j = 0; j < sizeof(int4) / sizeof(half); j++) s[j] += x[j];
}
#pragma unroll
for (int i = 1; i < UNROLLRS; i++) {
half *x = reinterpret_cast<half *>(&val[i]);
#pragma unroll
            for (int j = 0; j < sizeof(int4) / sizeof(half); j++) s[j] += x[j];
}
userptrmyrank[tempstart + line] = sum;
}
asm volatile("bar.sync 13, %0;" ::"r"(REDUCETHREADS + 32));
}
} // nblock loop NVLINK-REDUCESCATTER + IBREDUCE LOCAL COMPUTE
} // worker warps else block
} // fp16 inplace reduce kernel with SHARP / in blocks
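// All-gather-only counterpart: the sync warp waits on the NIC-RMW IBAG completion flags and
// publishes them to the NVLink peers, while the worker warps perform the NVLink all-gather;
// the reduce-scatter phase is assumed to have been launched separately.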
template <int RANKS>
__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_rr_blocked2_ag(
const int op, const int maxcredit, const int headstart, const int myibrank, const int ibranks,
const int commbufoffset, const int flagoffset, const int firstrank, const int myrank,
const int gpustep, const int lineoffset, const int numlines, void **commbuff,
const int handleridx, const int peerblocklines, int *hostflags, int *gpuflag,
const int numblocks) {
const int basecounter = gpuflag[NVTE_GF_STATE + op];
if (threadIdx.x < 32) {
int *flagptr;
volatile int *localflag = (volatile int *)&(
((int *)commbuff[gpustep * myrank + firstrank])[flagoffset]); // NOLINT(*)
if (threadIdx.x < RANKS) {
if (!blockIdx.x) {
flagptr = reinterpret_cast<int *>(commbuff[gpustep * threadIdx.x + firstrank]);
}
}
__syncthreads();
if (!blockIdx.x && !threadIdx.x)
hostflags[NVTE_HF_NVREDUCEDONE + (op & 1)] = numblocks + basecounter;
// tell CPU proxy all blocks are done and ready for NVAG
// final part doing NVAG based on responses from NIC-RMW:IBAG
if (blockIdx.x == 0) {
for (int nblock = 0; nblock < numblocks; nblock++) {
const int expected = basecounter + nblock + 1;
for (int ibflag = threadIdx.x; ibflag < ibranks; ibflag += 32)
if (ibflag != myibrank)
while (localflag[NVTE_REG0_IBAG + ibflag] < expected) {
}
asm volatile("bar.sync 15, %0;" ::"r"(32));
if (threadIdx.x < RANKS)
flagptr[flagoffset + gpustep * myrank + NVTE_MAX_NVLINK + firstrank] = expected;
}
}
if (blockIdx.x == 0 && threadIdx.x == 0) gpuflag[NVTE_GF_STATE + op] = basecounter + numblocks;
} else { // sync warp
// reducethreads
const int warp = blockIdx.x + (threadIdx.x >> 5);
int4 *userptr[RANKS];
int4 *userptrmyrank;
#pragma unroll
for (int i = 0; i < RANKS; i++)
userptr[i] = reinterpret_cast<int4 *>(
commbuff[((i + myrank + warp) & (RANKS - 1)) * gpustep + handleridx + firstrank]);
userptrmyrank = reinterpret_cast<int4 *>(commbuff[gpustep * myrank + handleridx + firstrank]);
__syncthreads();
int blocklineoffset = 0, rblocklineoffset = 0;
if (RANKS != 1) {
const int nwarps = (REDUCETHREADS >> 5) / (RANKS - 1);
const int myblockDim = nwarps << 5;
const int mywarp = ((threadIdx.x - 32) >> 5) / (RANKS - 1);
const int maxthreadIdx = myblockDim * (RANKS - 1) + 32;
const int mydest = (myrank + 1 + ((threadIdx.x - 32) >> 5) % (RANKS - 1)) & (RANKS - 1);
const int mythreadIdx = (mywarp << 5) + (threadIdx.x & 31);
volatile int *flag = (volatile int *)&((reinterpret_cast<int *>(
commbuff[gpustep * myrank + firstrank]))[flagoffset + gpustep * mydest + NVTE_MAX_NVLINK +
firstrank]);
int4 *userptrmydest = userptr[((RANKS << 10) + mydest - myrank - warp) & (RANKS - 1)];
blocklineoffset = 0;
int gathercounter = basecounter + 1;
while (blocklineoffset < numlines) {
const int remainder = min(numlines - blocklineoffset, peerblocklines * RANKS);
const int blocklines = remainder / RANKS;
const int blockstart = lineoffset + blocklineoffset;
#define UNROLL 6
int4 *myptr = &userptrmyrank[blockstart + blocklines * mydest];
int4 *peerptr = &userptrmydest[blockstart + blocklines * mydest];
if (threadIdx.x < maxthreadIdx) {
const int start_elem = mythreadIdx + myblockDim * blockIdx.x;
const int end_elem = max(start_elem, blocklines);
const int aligned_elem = ((end_elem - start_elem) / (myblockDim * gridDim.x * UNROLL)) *
(myblockDim * gridDim.x * UNROLL);
const int end_aligned = start_elem + aligned_elem;
if (mythreadIdx == 0) {
while (*flag < gathercounter) {
}
gathercounter++;
}
asm volatile("bar.sync %0, %1;" ::"r"(1 + mydest), "r"(myblockDim));
for (int line = start_elem; line < end_aligned; line += myblockDim * gridDim.x * UNROLL) {
int4 val[UNROLL];
#pragma unroll
for (int i = 0; i < UNROLL; i++) val[i] = peerptr[line + i * myblockDim * gridDim.x];
#pragma unroll
for (int i = 0; i < UNROLL; i++) myptr[line + i * myblockDim * gridDim.x] = val[i];
}
for (int line = end_aligned; line < end_elem; line += myblockDim * gridDim.x)
myptr[line] = peerptr[line];
}
blocklineoffset += peerblocklines * RANKS;
} // block loop for NVLINK-ALLGATHER
} // RANKS!=1
} // worker warps else block
} // fp16 inplace reduce kernel with SHARP / in blocks
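// Single-GPU (no NVLink peers) path: just advances the host/GPU counters and waits for the
// IB/SHARP side to signal completion.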
__global__ void userbuffers_fp16_sum_inplace_gpu_null(const int op, int *hostflags, int *gpuflag,
int numblocks) {
const int basecounter = gpuflag[NVTE_GF_STATE + op] + numblocks;
hostflags[0] = basecounter;
gpuflag[NVTE_GF_STATE + op] = basecounter;
while (((volatile int *)gpuflag)[NVTE_GF_IBSHARPDONE] < basecounter) {
}
}
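// Kernel dispatch macros: each expands to a launch only when the NVLink group size (ar_nvsize)
// matches x, instantiating the matching kernel template and forwarding the communicator state
// as kernel arguments.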
#define callranks_block(x) \
if (comm->ar_nvsize == x) \
userbuffers_fp16_sum_inplace_gpu_rr_blocked<x><<<sms, warps * 32, 0, stream>>>( \
userbuffers_allreduceop_sharp, NVTE_REG0_OFFSET(comm), comm->ar_firstgpu, comm->ar_nvrank, \
offset / 8, elements / 8, reinterpret_cast<void **>(comm->gpu_ptrs), \
handler * comm->nvsize, blocksize / sizeof(int4) / comm->ar_nvsize, \
reinterpret_cast<int *>(comm->hostflags), comm->flags, \
(elements * 2 + blocksize - 1) / blocksize);
#define callranks2_block(x) \
if (ar_nvsize == x) { \
int numblocks = (elements * 2 + blocksize - 1) / blocksize; \
int headstart = numblocks - 1; /*<3?numblocks-1:3;*/ \
if (headstart > maxcredit) headstart = maxcredit; \
if (x == 1) headstart = maxcredit; \
if (headstart > numblocks) headstart = numblocks; \
if (headstart == 0) headstart = 1; \
userbuffers_fp16_sum_inplace_gpu_rr_blocked2<x><<<sms, warps * 32, 0, stream>>>( \
op, maxcredit, headstart, my_node, num_nodes, \
NVTE_REG0_OFFSET(comm) + NVTE_REG0_FLAGS + \
(op == userbuffers_allreduceop_nonsharp ? NVTE_REG0_COMMBUFFER : 0), \
NVTE_REG0_OFFSET(comm) + NVTE_REG0_OPFLAGS * op, ar_firstgpu, ar_nvrank, ar_step, \
offset / 8, elements / 8, reinterpret_cast<void **>(comm->gpu_ptrs), \
handler * comm->nvsize, blocksize / sizeof(int4) / ar_nvsize, \
reinterpret_cast<int *>(comm->hostflags), comm->flags, numblocks); \
}
#define callranks2_block_rs(x) \
if (ar_nvsize == x) { \
int numblocks = (elements * 2 + blocksize - 1) / blocksize; \
int headstart = numblocks - 1; /*<3?numblocks-1:3;*/ \
if (headstart > maxcredit) headstart = maxcredit; \
if (x == 1) headstart = maxcredit; \
if (headstart > numblocks) headstart = numblocks; \
if (headstart == 0) headstart = 1; \
userbuffers_fp16_sum_inplace_gpu_rr_blocked2_rs<x><<<sms, warps * 32, 0, stream>>>( \
op, maxcredit, headstart, my_node, num_nodes, \
NVTE_REG0_OFFSET(comm) + NVTE_REG0_FLAGS + \
(op == userbuffers_allreduceop_nonsharp ? NVTE_REG0_COMMBUFFER : 0), \
NVTE_REG0_OFFSET(comm) + NVTE_REG0_OPFLAGS * op, ar_firstgpu, ar_nvrank, ar_step, \
offset / 8, elements / 8, reinterpret_cast<void **>(comm->gpu_ptrs), \
handler * comm->nvsize, blocksize / sizeof(int4) / ar_nvsize, \
reinterpret_cast<int *>(comm->hostflags), comm->flags, numblocks); \
}
#define callranks2_block_ag(x) \
if (ar_nvsize == x) { \
int numblocks = (elements * 2 + blocksize - 1) / blocksize; \
int headstart = numblocks - 1; /*<3?numblocks-1:3;*/ \
if (headstart > maxcredit) headstart = maxcredit; \
if (x == 1) headstart = maxcredit; \
if (headstart > numblocks) headstart = numblocks; \
if (headstart == 0) headstart = 1; \
userbuffers_fp16_sum_inplace_gpu_rr_blocked2_ag<x><<<sms, warps * 32, 0, stream>>>( \
op, maxcredit, headstart, my_node, num_nodes, \
NVTE_REG0_OFFSET(comm) + NVTE_REG0_FLAGS + \
(op == userbuffers_allreduceop_nonsharp ? NVTE_REG0_COMMBUFFER : 0), \
NVTE_REG0_OFFSET(comm) + NVTE_REG0_OPFLAGS * op, ar_firstgpu, ar_nvrank, ar_step, \
offset / 8, elements / 8, reinterpret_cast<void **>(comm->gpu_ptrs), \
handler * comm->nvsize, blocksize / sizeof(int4) / ar_nvsize, \
reinterpret_cast<int *>(comm->hostflags), comm->flags, numblocks); \
}
#define callranks(x) \
if (ar_nvsize == x) { \
int arg1 = op - NVTE_MAX_OPS, \
arg2 = NVTE_REG0_OFFSET(comm) - \
(op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \
NVTE_MAX_OPS, \
arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg6 = offset / 8, \
arg7 = elements / 8; \
void **arg8 = reinterpret_cast<void **>(comm->gpu_ptrs); \
int arg9 = handler * comm->nvsize; \
void *kernelArgs[] = {reinterpret_cast<void *>(&arg1), reinterpret_cast<void *>(&arg2), \
reinterpret_cast<void *>(&arg3), reinterpret_cast<void *>(&arg4), \
reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6), \
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
reinterpret_cast<void *>(&arg9)}; \
CUDACHECK(cudaLaunchKernelExC( \
&cfg, \
reinterpret_cast<void *>(comm->use_rr_kernel ? userbuffers_fp16_sum_inplace_gpu_rr<x> \
: userbuffers_fp16_sum_inplace_gpu_rw<x>), \
kernelArgs)); \
}
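// Builds the cudaLaunchConfig_t used with cudaLaunchKernelExC in the callranks* macros:
// cooperative launch always, plus a thread-block cluster of comm->cga_size on sm90+ when the
// SM count divides evenly.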
#define SETUP_LAUNCH_CONFIG(sms, threads, stream) \
cudaLaunchConfig_t cfg = {sms, threads, 0, stream, NULL, 0}; \
cudaLaunchAttribute attribute_ub[2]; \
attribute_ub[1].id = cudaLaunchAttributeClusterDimension; \
attribute_ub[1].val.clusterDim.x = sms % comm->cga_size == 0 ? comm->cga_size : 1; \
attribute_ub[1].val.clusterDim.y = 1; \
attribute_ub[1].val.clusterDim.z = 1; \
attribute_ub[0].id = cudaLaunchAttributeCooperative; \
cfg.attrs = attribute_ub; \
cfg.numAttrs = comm->sm_arch >= 9 ? 2 : 1;
int allreduce_userbuff_inplace_gpu(const int handler, const int offset, const int elements,
const int blocksize, communicator *comm, cudaStream_t stream) {
// schedule GPU kernel only
// CPU/SHARP part is responsibility of caller
const int ar_step = comm->ar2_nvsize;
const int op = userbuffers_allreduceop_nonsharp;
const int ar_nvsize = comm->nvsize;
const int ar_firstgpu = comm->ar_firstgpu;
const int ar_nvrank = comm->ar_nvrank;
if (elements < 8) return 0;
  int sms = comm->sms;
int warps = comm->threads / 32;
if (warps < comm->ar_nvsize) warps = comm->ar_nvsize;
if (comm->launch_mode & NVTE_LAUNCH_GPU) {
if (comm->ar_nvsize == 1)
userbuffers_fp16_sum_inplace_gpu_null<<<1, 1, 0, stream>>>(
userbuffers_allreduceop_sharp, reinterpret_cast<int *>(comm->hostflags), comm->flags,
(elements * 2 + blocksize - 1) / blocksize);
callranks_block(2) callranks_block(4) callranks_block(8)
}
return sms;
}
int allreduce2_userbuff_inplace_gpu(const int maxcredit, const int handler, const int offset,
const int elements, const int blocksize, communicator *comm,
cudaStream_t stream, int op) {
// schedule GPU kernel only
// CPU/SHARP part is responsibility of caller
const int num_nodes = op == userbuffers_allreduceop_nonsharp ? comm->num_nodes : comm->num2_nodes;
const int my_node = op == userbuffers_allreduceop_nonsharp ? comm->my_node : comm->my2_node;
const int ar_firstgpu =
op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank;
if (elements < 8) return 0;
int sms = ar_nvsize == 1 ? 2 : comm->sms;
int warps = comm->threads / 32;
if (warps < ar_nvsize) warps = ar_nvsize;
if (num_nodes > 1) {
callranks2_block(1) callranks2_block(2) callranks2_block(4) callranks2_block(8)
} else {
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
callranks(2) callranks(4) callranks(8)
}
return sms;
}
#define callranks_ag(x) \
if (ar_nvsize == x) { \
int arg1 = op - NVTE_MAX_OPS, \
arg2 = NVTE_REG0_OFFSET(comm) - \
(op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \
NVTE_MAX_OPS, \
arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \
arg6 = offset / 8 + (comm->use_rr_kernel ? 0 : arg4 * arg7); \
void **arg8 = reinterpret_cast<void **>(comm->gpu_ptrs); \
int arg9 = handler * comm->nvsize; \
void *kernelArgs[] = {reinterpret_cast<void *>(&arg1), reinterpret_cast<void *>(&arg2), \
reinterpret_cast<void *>(&arg3), reinterpret_cast<void *>(&arg4), \
reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6), \
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
reinterpret_cast<void *>(&arg9)}; \
CUDACHECK(cudaLaunchKernelExC( \
&cfg, \
reinterpret_cast<void *>(comm->use_rr_kernel ? userbuffers_fp16_sum_inplace_gpu_rr_ag<x> \
: userbuffers_fp16_sum_inplace_gpu_rw_ag<x>), \
kernelArgs)); \
}
#define callranks_rs(x) \
if (ar_nvsize == x) { \
int arg1 = op - NVTE_MAX_OPS, \
arg2 = NVTE_REG0_OFFSET(comm) - \
(op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \
NVTE_MAX_OPS, \
arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \
arg6 = offset / 8 + arg4 * arg7; \
void **arg8 = reinterpret_cast<void **>(comm->gpu_ptrs); \
int arg9 = handler * comm->nvsize; \
void *kernelArgs[] = {reinterpret_cast<void *>(&arg1), reinterpret_cast<void *>(&arg2), \
reinterpret_cast<void *>(&arg3), reinterpret_cast<void *>(&arg4), \
reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6), \
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
reinterpret_cast<void *>(&arg9)}; \
CUDACHECK(cudaLaunchKernelExC( \
&cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_rr_rs<x>), kernelArgs)); \
}
#define callranks_rs_oop(x) \
if (ar_nvsize == x) { \
int arg1 = op - NVTE_MAX_OPS, \
arg2 = NVTE_REG0_OFFSET(comm) - \
(op == userbuffers_allreduceop_nonsharp ? 2 : 1) * NVTE_REG0_SINGLENODE + \
NVTE_MAX_OPS, \
arg3 = ar_firstgpu, arg4 = ar_nvrank, arg5 = ar_step, arg7 = elements / 8 / x, \
arg6 = offset / 8 + arg4 * arg7, arg8 = rowelements / 8, arg9 = strideelements / 8; \
void **arg10 = reinterpret_cast<void **>(comm->gpu_ptrs); \
int arg11 = handler * comm->nvsize; \
void *arg12 = output; \
void *kernelArgs[] = {reinterpret_cast<void *>(&arg1), reinterpret_cast<void *>(&arg2), \
reinterpret_cast<void *>(&arg3), reinterpret_cast<void *>(&arg4), \
reinterpret_cast<void *>(&arg5), reinterpret_cast<void *>(&arg6), \
reinterpret_cast<void *>(&arg7), reinterpret_cast<void *>(&arg8), \
reinterpret_cast<void *>(&arg9), reinterpret_cast<void *>(&arg10), \
reinterpret_cast<void *>(&arg11), reinterpret_cast<void *>(&arg12)}; \
CUDACHECK(cudaLaunchKernelExC( \
&cfg, reinterpret_cast<void *>(userbuffers_fp16_sum_inplace_gpu_rr_rs_oop<x>), \
kernelArgs)); \
}
int reducescatter2_userbuff_inplace_gpu(const int maxcredit, const int handler, const int offset,
const int elements, const int blocksize, communicator *comm,
cudaStream_t stream, int op) {
// schedule GPU kernel only
// CPU/SHARP part is responsibility of caller
const int num_nodes = op == userbuffers_allreduceop_nonsharp ? comm->num_nodes : comm->num2_nodes;
const int my_node = op == userbuffers_allreduceop_nonsharp ? comm->my_node : comm->my2_node;
const int ar_firstgpu =
op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank;
if (elements < 8) return 0;
int sms = ar_nvsize == 1 ? 2 : comm->sms;
int warps = comm->threads / 32;
if (warps < ar_nvsize) warps = ar_nvsize;
if (num_nodes > 1) {
callranks2_block_rs(1) callranks2_block_rs(2) callranks2_block_rs(4) callranks2_block_rs(8)
} else {
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
callranks_rs(2) callranks_rs(4) callranks_rs(8)
}
return sms;
}
int allgather2_userbuff_inplace_gpu(const int maxcredit, const int handler, const int offset,
const int elements, const int blocksize, communicator *comm,
cudaStream_t stream, int op) {
// schedule GPU kernel only
// CPU/SHARP part is responsibility of caller
const int num_nodes = op == userbuffers_allreduceop_nonsharp ? comm->num_nodes : comm->num2_nodes;
const int my_node = op == userbuffers_allreduceop_nonsharp ? comm->my_node : comm->my2_node;
const int ar_firstgpu =
op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank;
if (elements < 8) return 0;
int sms = ar_nvsize == 1 ? 2 : comm->sms;
int warps = comm->threads / 32;
if (warps < ar_nvsize) warps = ar_nvsize;
if (num_nodes > 1) {
callranks2_block_ag(1) callranks2_block_ag(2) callranks2_block_ag(4) callranks2_block_ag(8)
} else {
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
callranks_ag(2) callranks_ag(4) callranks_ag(8)
}
return sms;
}
void allgather2_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
const int op = userbuffers_allreduceop_nonsharp2;
const int blocksize = elements * 2;
const int ar_firstgpu =
op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank;
if (elements < 64) return;
int sms = ar_nvsize == 1 ? 2 : comm->sms;
int warps = comm->threads / 32;
if (warps < ar_nvsize) warps = ar_nvsize;
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
callranks_ag(2) callranks_ag(4) callranks_ag(8)
}
void allgather2_userbuff_inplace_sliced(const int handler, const int offset, const int elements,
communicator *comm, const int slice_id, const int nslices,
cudaStream_t stream) {
const int op = userbuffers_allreduceop_nonsharp2;
const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
int peerelements = elements / ar_nvsize;
int saverrkernel = comm->use_rr_kernel;
comm->use_rr_kernel = 0;
allgather2_userbuff_inplace(
handler, offset + ar_nvrank * peerelements * (nslices - 1) + slice_id * peerelements,
elements, comm, stream);
comm->use_rr_kernel = saverrkernel;
}
void reducescatter2_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
const int op = userbuffers_allreduceop_nonsharp2;
const int blocksize = elements * 2;
const int ar_firstgpu =
op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank;
if (elements < 64) return;
int sms = ar_nvsize == 1 ? 2 : comm->sms;
int warps = comm->threads / 32;
if (warps < ar_nvsize) warps = ar_nvsize;
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
callranks_rs(2) callranks_rs(4) callranks_rs(8)
}
void reducescatter2_userbuff_stridedoutput(void *output, const int handler, const int offset,
const int rowelements, const int colelements,
const int strideelements, communicator *comm,
cudaStream_t stream) {
const int elements = rowelements * colelements;
const int op = userbuffers_allreduceop_nonsharp2;
const int blocksize = elements * 2;
const int ar_firstgpu =
op == userbuffers_allreduceop_nonsharp ? comm->ar_firstgpu : comm->ar2_firstgpu;
const int ar_step = op == userbuffers_allreduceop_nonsharp2 ? 1 : comm->ar2_nvsize;
const int ar_nvsize = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvsize : comm->ar2_nvsize;
const int ar_nvrank = op == userbuffers_allreduceop_nonsharp ? comm->ar_nvrank : comm->ar2_nvrank;
if (elements < 64) return;
int sms = ar_nvsize == 1 ? 2 : comm->sms;
int warps = comm->threads / 32;
if (warps < ar_nvsize) warps = ar_nvsize;
SETUP_LAUNCH_CONFIG(sms, warps * 32, stream);
callranks_rs_oop(2) callranks_rs_oop(4) callranks_rs_oop(8)
}
void reducescatter2_userbuff(void *output, const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream) {
reducescatter2_userbuff_stridedoutput(output, handler, offset, elements, 1, 0, comm, stream);
}
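// Point-to-point kernels. In the pull model the sender only bumps the receiver-side flag
// (the receiver copies the data itself in kuserbuffers_pullrecv); in the push model the sender
// copies the payload and then signals (kuserbuffers_pushsend / kuserbuffers_pushrecv).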
__global__ void kuserbuffers_pullsend(int myrank, int peer, int *send_id, int *flagptr) {
atomicAdd(flagptr, 1);
}
__global__ void kuserbuffers_inc(int *id) {
const int signal_id = (*id) + 1;
*id = signal_id;
}
__global__ void kuserbuffers_proxysend(int *id, int *hostflag) {
const int signal_id = (*id) + 1;
*hostflag = signal_id;
*id = signal_id;
}
__global__ void kuserbuffers_dummy(void) {}
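// Pull-model receive: spin on the sender's flag (with a timeout report), then copy the payload
// from the peer's registered buffer into the local destination in unrolled int4 lines.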
__global__ void __launch_bounds__(MAX_THREADS)
kuserbuffers_pullrecv(int myrank, int peer, int *recv_id, int *flagptr, int4 *srcptr,
int4 *dstptr, const int lines) {
#define UNROLLCOPY 8
const int start_elem = threadIdx.x + blockDim.x * blockIdx.x;
const int end_elem = lines;
const int aligned_elem = (end_elem - start_elem) & (~(blockDim.x * gridDim.x * UNROLLCOPY - 1));
const int end_aligned = start_elem + aligned_elem;
if (threadIdx.x == 0) {
const int signal_id = (*recv_id) + 1;
volatile int *flag = (volatile int *)flagptr;
clock_t s = clock64();
while (*flag < signal_id) {
if (clock64() - s > TIMEOUT) {
printf("[%d from %d] pullrecv: expected %d, stuck with %d\n", myrank, peer, signal_id,
*flag);
break;
}
}
if (lines == 0) {
*recv_id = signal_id;
return;
} // otherwise need an extra kernel
}
__syncthreads();
if (end_elem <= start_elem) return;
for (int line = start_elem; line < end_aligned; line += blockDim.x * gridDim.x * UNROLLCOPY) {
int4 val[UNROLLCOPY];
#pragma unroll
for (int i = 0; i < UNROLLCOPY; i++) val[i] = srcptr[line + i * blockDim.x * gridDim.x];
#pragma unroll
for (int i = 0; i < UNROLLCOPY; i++) dstptr[line + i * blockDim.x * gridDim.x] = val[i];
}
for (int line = end_aligned; line < end_elem; line += blockDim.x * gridDim.x)
dstptr[line] = srcptr[line];
}
__global__ void __launch_bounds__(MAX_THREADS)
kuserbuffers_pushsend(int *send_id, int *flagptr, int4 *srcptr, int4 *dstptr, const int lines) {
if (lines) {
const int start_elem = threadIdx.x + blockDim.x * blockIdx.x;
const int end_elem = lines;
const int aligned_elem =
((end_elem - start_elem) & (~(blockDim.x * gridDim.x * UNROLLCOPY - 1)));
const int end_aligned = start_elem + aligned_elem;
if (end_elem > start_elem) {
for (int line = start_elem; line < end_aligned; line += blockDim.x * gridDim.x * UNROLLCOPY) {
int4 val[UNROLLCOPY];
#pragma unroll
for (int i = 0; i < UNROLLCOPY; i++) val[i] = srcptr[line + i * blockDim.x * gridDim.x];
#pragma unroll
for (int i = 0; i < UNROLLCOPY; i++) dstptr[line + i * blockDim.x * gridDim.x] = val[i];
}
for (int line = end_aligned; line < end_elem; line += blockDim.x * gridDim.x)
dstptr[line] = srcptr[line];
}
__syncthreads();
if (threadIdx.x) return;
__threadfence_system();
atomicAdd(flagptr, 1); // otherwise need local SM sync before sending flag
} else { // 0 bytes and 1 SM only
atomicAdd(flagptr, 1);
}
}
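// Push-model receive: the data was already pushed by the sender, so this only waits for the
// sender's flag to reach the expected counter (with a timeout report).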
__global__ void kuserbuffers_pushrecv(int myrank, int peer, int *recv_id, int *flagptr, int adder) {
const int signal_id = (*recv_id) + adder;
*recv_id = signal_id;
volatile int *flag = (volatile int *)flagptr;
if (*flag >= signal_id) return;
clock_t s = clock64();
while (*flag < signal_id) {
if (clock64() - s > TIMEOUT) {
printf("%d from %d] pushrecv: expected %d, stuck with %d\n", myrank, peer, signal_id, *flag);
return;
}
}
}
#define CUDACHECK(cmd) \
do { \
cudaError_t e = cmd; \
if (e != cudaSuccess) { \
printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, cudaGetErrorString(e)); \
exit(EXIT_FAILURE); \
} \
} while (0)
#define INTRANODE(peer) ((peer / comm->nvsize) == (comm->myrank / comm->nvsize))
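// Host-side send: internode transfers are queued to the CPU proxy fifo and/or signalled to the
// proxy via kuserbuffers_proxysend, while intranode transfers use the pull flag or a push copy
// (optionally through the copy engine) directly over NVLink.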
void userbuffers_send(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes, communicator *comm,
const int peer, cudaStream_t stream) {
int peerlocal = peer % comm->nvsize;
void *flagptr =
(comm->peer_ptr[0][peerlocal]) +
((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + comm->myrank * NVTE_MAX_REGIONS + dsthandler) *
sizeof(int));
bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0);
bool intranode = INTRANODE(peer);
if (!intranode && (comm->launch_mode & NVTE_LAUNCH_CPU)) {
comm->fifo[comm->head].optype = userbuffers_sendop;
comm->fifo[comm->head].basecounter = comm->basecounter[userbuffers_sendop];
comm->fifo[comm->head].handler = srchandler;
comm->fifo[comm->head].offset = srcoffset;
comm->fifo[comm->head].handler2 = dsthandler;
comm->fifo[comm->head].offset2 = dstoffset;
comm->fifo[comm->head].elements = bytes;
comm->fifo[comm->head].peer = peer;
int newhead = (comm->head + 1) & (NVTE_MAX_REQUESTS - 1);
while (newhead == comm->tail) {
}
comm->head = newhead;
comm->basecounter[userbuffers_sendop] += 1;
}
if (!intranode && (comm->launch_mode & NVTE_LAUNCH_GPU)) {
kuserbuffers_proxysend<<<1, 1, 0, stream>>>(&(comm->flags[NVTE_GF_STATE + userbuffers_sendop]),
comm->hostflags + userbuffers_sendop);
return;
}
if (!(comm->launch_mode & NVTE_LAUNCH_GPU)) return;
if (comm->push == 0) {
kuserbuffers_pullsend<<<1, 1, 0, stream>>>(comm->myrank, peer, &(comm->send_id[peer]),
reinterpret_cast<int *>(flagptr));
} else {
void *srcptr = (comm->mem_ptr[srchandler]) + srcoffset;
void *dstptr = (comm->peer_ptr[dsthandler][peerlocal]) + dstoffset;
if (comm->use_ce)
CUDACHECK(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
SETUP_LAUNCH_CONFIG(signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, stream);
int *arg1 = &comm->send_id[peer], *arg2 = reinterpret_cast<int *>(flagptr);
int4 *arg3 = reinterpret_cast<int4 *>(srcptr), *arg4 = reinterpret_cast<int4 *>(dstptr);
int arg5 = signalonly ? 0 : bytes / 16;
void *kernelArgs[] = {reinterpret_cast<void *>(&arg1), reinterpret_cast<void *>(&arg2),
reinterpret_cast<void *>(&arg3), reinterpret_cast<void *>(&arg4),
reinterpret_cast<void *>(&arg5)};
CUDACHECK(
cudaLaunchKernelExC(&cfg, reinterpret_cast<void *>(kuserbuffers_pushsend), kernelArgs));
}
}
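// All-to-all push kernel: one thread block per destination rank copies that rank's slice of the
// local source buffer into the destination's registered buffer, then signals completion with an
// atomicAdd on that destination's flag.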
__global__ void __launch_bounds__(MAX_THREADS)
kuserbuffers_alltoall(void **baseflagptrs, int flagoffset, int4 *basesrcptr, void **dstptrs,
size_t dstoffset, const int lines, const int myrank) {
if (blockIdx.x == myrank) return;
int4 *dstptr = reinterpret_cast<int4 *>(dstptrs[blockIdx.x] + dstoffset);
int *flagptr = reinterpret_cast<int *>(baseflagptrs[blockIdx.x] + flagoffset);
const size_t myblockoffset = blockIdx.x * lines;
int4 *srcptr = basesrcptr + myblockoffset;
dstptr += myblockoffset;
if (lines) {
const int start_elem = threadIdx.x;
const int end_elem = lines;
const int aligned_elem = ((end_elem - start_elem) & (~(blockDim.x * UNROLLCOPY - 1)));
const int end_aligned = start_elem + aligned_elem;
if (end_elem > start_elem) {
for (int line = start_elem; line < end_aligned; line += blockDim.x * UNROLLCOPY) {
int4 val[UNROLLCOPY];
#pragma unroll
for (int i = 0; i < UNROLLCOPY; i++) val[i] = srcptr[line + i * blockDim.x];
#pragma unroll
for (int i = 0; i < UNROLLCOPY; i++) dstptr[line + i * blockDim.x] = val[i];
}
for (int line = end_aligned; line < end_elem; line += blockDim.x) dstptr[line] = srcptr[line];
}
__syncthreads();
if (threadIdx.x) return;
__threadfence_system();
atomicAdd(flagptr, 1);
} else {
atomicAdd(flagptr, 1);
}
}
void userbuffers_alltoall_send(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes, communicator *comm,
cudaStream_t stream) {
if (comm->launch_mode & NVTE_LAUNCH_CPU) {
comm->fifo[comm->head].optype = userbuffers_alltoall;
comm->fifo[comm->head].basecounter = comm->basecounter[userbuffers_alltoall];
comm->fifo[comm->head].handler = srchandler;
comm->fifo[comm->head].offset = srcoffset;
comm->fifo[comm->head].handler2 = dsthandler;
comm->fifo[comm->head].offset2 = dstoffset;
comm->fifo[comm->head].elements = bytes;
int newhead = (comm->head + 1) & (NVTE_MAX_REQUESTS - 1);
while (newhead == comm->tail) {
}
comm->head = newhead;
comm->basecounter[userbuffers_alltoall] += 1;
}
if (comm->launch_mode & NVTE_LAUNCH_GPU)
kuserbuffers_proxysend<<<1, 1, 0, stream>>>(
&(comm->flags[NVTE_GF_STATE + userbuffers_alltoall]),
comm->hostflags + userbuffers_alltoall);
}
void userbuffers_recv(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes, communicator *comm,
const int peer, cudaStream_t stream) {
int peerlocal = peer % comm->nvsize;
void *flagptr =
(comm->mem_ptr[0]) +
((NVTE_REG0_OFFSET(comm) + NVTE_REG0_RECV + peer * NVTE_MAX_REGIONS + dsthandler) *
sizeof(int));
bool signalonly = (bytes / 16 == 0) || (comm->use_ce != 0);
bool intranode = INTRANODE(peer);
if (!(comm->launch_mode & NVTE_LAUNCH_GPU)) return;
if (comm->push == 0 && intranode) {
void *dstptr = (comm->mem_ptr[dsthandler]) + dstoffset;
void *srcptr = (comm->peer_ptr[srchandler][peerlocal]) + srcoffset;
kuserbuffers_pullrecv<<<signalonly ? 1 : comm->sms, signalonly ? 1 : 1024, 0, stream>>>(
comm->myrank, peer, &(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]),
reinterpret_cast<int *>(flagptr), reinterpret_cast<int4 *>(srcptr),
reinterpret_cast<int4 *>(dstptr), signalonly ? 0 : bytes / 16);
if (!signalonly)
kuserbuffers_inc<<<1, 1, 0, stream>>>(&(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]));
if (comm->use_ce) {
CUDACHECK(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream));
}
} else {
kuserbuffers_pushrecv<<<1, 1, 0, stream>>>(
comm->myrank, peer, &comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler],
reinterpret_cast<int *>(flagptr), signalonly || !intranode ? 1 : comm->sms);
}
}
void userbuffers_alltoall_recv(communicator *comm, cudaStream_t stream) {
void *flagptr =
(comm->mem_ptr[0]) +
((NVTE_REG0_OFFSET(comm) + NVTE_REG0_OPFLAGS * userbuffers_alltoall) * sizeof(int));
if (!(comm->launch_mode & NVTE_LAUNCH_GPU)) return;
kuserbuffers_pushrecv<<<1, 1, 0, stream>>>(comm->myrank, -1, reinterpret_cast<int *>(flagptr + 4),
reinterpret_cast<int *>(flagptr), comm->nranks - 1);
}
......@@ -49,6 +49,7 @@ void cublas_gemm(const Tensor *inputA,
size_t workspaceSize,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
cudaStream_t stream
) {
void *A = inputA->data.dptr;
......@@ -124,6 +125,13 @@ void cublas_gemm(const Tensor *inputA,
&transa, sizeof(transa)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_TRANSB,
&transb, sizeof(transb)));
// Set math SM count
if (math_sm_count != 0) {
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
&math_sm_count, sizeof(math_sm_count)));
}
// set fp8 attributes -- input and output types should already be set to fp8 as appropriate
// Note: gelu fusion isn't available right now, and we don't need
......@@ -227,6 +235,7 @@ void cublas_gemm(const Tensor *inputA,
if (returnedResults == 0) throw std::runtime_error("Unable to find any suitable algorithms");
// D = alpha * (A * B) + beta * C
NVTE_CHECK_CUBLAS(cublasLtMatmul(handle,
operationDesc,
static_cast<const void*>(&one), /* alpha */
......@@ -266,6 +275,7 @@ void nvte_cublas_gemm(const NVTETensor A,
NVTETensor workspace,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
cudaStream_t stream) {
NVTE_API_CALL(nvte_cublas_gemm);
using namespace transformer_engine;
......@@ -308,5 +318,6 @@ void nvte_cublas_gemm(const NVTETensor A,
grad, wspace->data.dptr,
wspace->data.shape[0],
accumulate, use_split_accumulator,
math_sm_count,
stream);
}
......@@ -36,6 +36,7 @@ extern "C" {
* \param[out] workspace Workspace tensor.
* \param[in] accumulate Whether to accumulate the result into the D matrix.
* \param[in] use_split_accumulator Whether to use split accumulator in the FP8 GEMM.
* \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics)
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_cublas_gemm(const NVTETensor A,
......@@ -49,6 +50,7 @@ void nvte_cublas_gemm(const NVTETensor A,
NVTETensor workspace,
bool accumulate,
bool use_split_accumulator,
int math_sm_count,
cudaStream_t stream
);
......
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#ifndef TRANSFORMER_ENGINE_USERBUFFERS_H_
#define TRANSFORMER_ENGINE_USERBUFFERS_H_
#include <cuda.h>
#include <mpi.h>
#include "cuda_runtime.h"
#include <pthread.h>
#include <chrono>
#include "gdrapi.h"
#include <stdexcept>
#define NVTE_MAX_REGIONS 16
#define NVTE_MAX_SMS 32
#define NVTE_MAX_OPS 32
#define NVTE_MAX_PEERS 8192
#define NVTE_MAX_REQUESTS 1024
#define NVTE_LAUNCH_GPU 1
#define NVTE_LAUNCH_CPU 2
#define NVTE_MAX_NVLINK 8
// region 0 flag offsets
#define NVTE_REG0_OPFLAGS 1024
#define NVTE_REG0_RECV (NVTE_REG0_OPFLAGS * userbuffers_op_types)
#define NVTE_REG0_SINGLENODE (2 * NVTE_MAX_NVLINK * NVTE_MAX_SMS + NVTE_MAX_OPS)
#define NVTE_REG0_OFFSET(comm) ((2 * NVTE_MAX_REGIONS) * NVTE_MAX_NVLINK \
+ NVTE_REG0_SINGLENODE * 2 + NVTE_MAX_PEERS)
#define NVTE_REG0_COMMBUFFER 0
#define NVTE_REG0_FLAGS (NVTE_REG0_RECV + NVTE_MAX_PEERS * NVTE_MAX_REGIONS)
#define NVTE_REG0_IBRS 32
#define NVTE_REG0_IBAG 512
#undef NVTE_REG0_COMMBUFFER
#define NVTE_REG0_COMMBUFFER (1024 * 1024 * 16)
// gpuflags map offsets
#define NVTE_GF_STATE 16000
#define NVTE_GF_IBSHARPDONE 0
#define NVTE_HF_NVRSDONE (userbuffers_op_types + 1)
#define NVTE_HF_NVREDUCEDONE (userbuffers_op_types + 3)
#define NVTE_MAX_SHARP 16
typedef struct ub_request {
int optype;
int blocksize;
int basecounter;
int elements;
int handler;
int handler2;
size_t offset;
size_t offset2;
int peer;
// ----execution states
int active, maxcredit;
int nblock, numblocks, unconfirmed_ib_in_flight;
} ub_request;
enum req_type {
userbuffers_allreduceop_sharp,
userbuffers_sendop,
userbuffers_allreduceop_nonsharp,
userbuffers_allreduceop_nonsharp2,
userbuffers_alltoall,
userbuffers_op_types
};
struct communicator {
int myrank, nranks; // global job communicator
int nvrank, nvsize; // single node comm_intra
int free_region;
int launch_mode;
void *gpu_ptrs;
int sms, threads;
int use_rr_kernel; // Whether to use RR (or RW) for NVLink-only kernel
int cga_size;
int push, use_ce;
void *mem_ptr[NVTE_MAX_REGIONS];
void **peer_ptr[NVTE_MAX_REGIONS];
  int ar_nvsize, ar_firstgpu,
      ar_nvrank;  // number of GPUs per node (and the first GPU) in the reduction subgroup
                  // (set when _splitar init is used); equal to (nvsize, 0) for a regular comm_create
int ar2_nvsize, ar2_firstgpu, ar2_nvrank; // with ar_nvsize as a step
int pipe_id; // which allreduce set of groups (pipeline rank in range of 0..pipeline_size)
int sm_arch;
int num_nodes, my_node,
first_node; // comm_inter communicator, per-rail allreduce (might have subset of nodes)
int num2_nodes, my2_node, first2_node; // with num_nodes as a stride
// max value for running block counters in hostflags
int basecounter[userbuffers_op_types]; // NOLINT(*)
int *hostflags;
int *flags, *map_flags;
gdr_t g;
struct sharp_coll_context *sharp_coll_context;
struct sharp_coll_comm *sharp_coll_comm;
void *mem_mr[NVTE_MAX_REGIONS];
ub_request *fifo;
volatile int activeproxy;
int nblocks, alignblock, minblock, asyncblocks, active_nreqs;
ub_request active_req[userbuffers_op_types]; // NOLINT(*)
int padding[7];
volatile int head;
int padding2[15];
volatile int tail;
MPI_Request mpihndl[NVTE_MAX_SHARP];
MPI_Comm comm_inter, // reduction group communicator (subset of the nodes) along GPU rail
      comm_intra;       // full intranode (all ndev GPUs)
int ibnvsize; // can be used to fake smaller or larger nvlink domain to use ib instead of nvlink
// or force MNNVL
int *send_id, *recv_id;
int mydev;
};
typedef struct communicator communicator;
int create_communicator(communicator **comm);
/* creates communicator, allocates all internal buffers if necessary */
int create_communicator_grouped(communicator **comm, int pipegpus, int pipenodes);
int create_communicator_grouped2(communicator **comm, int pipegpus, int pipenodes, int tensorgpus,
int tensornodes);
/* creates communicator with
allreduce1 to happen in datagpus x datanodes groups,
   allreduce2 to happen in tensorgpus x tensornodes groups,
where num_nodes = pipenodes x tensornodes x datanodes
nvlink_size = pipegpus x tensorgpus x datagpus
*/
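/* a minimal usage sketch (illustrative only; tp_size, bytes, elements and stream are
   placeholders, and the call pattern mirrors how UbufCommOverlap initializes its communicator):
     communicator *comm;
     create_communicator_grouped2(&comm, 1, 1, tp_size, 1);   // one TP group of tp_size GPUs
     void *buf = nullptr;
     int hndl = register_user_buffer_collective(&buf, bytes, comm, true);  // alloc=true: allocate and register
     allgather2_userbuff_inplace(hndl, 0, elements, comm, stream);         // offset 0
     destroy_communicator(comm);
*/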
// int check_user_buffer_registration(void* gpubuff, int bytes, communicator* comm, size_t* offset);
/*
    local calls, doesn't communicate between peers
returns handler if buffer is registered already, or -1 if not.
returned offset is offset of gpubuff relative to buffer registered
*/
int pipe_rank(communicator *comm,
              int step);  // helper function to walk across allreduce1 x allreduce2 groups
// data-parallel and tensor-parallel position within data and tensor
// groups would be preserved
int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm,
bool alloc = false);
/* returns a handler and registers the buffer. Assumed to be collective, i.e. you use the same
   groups and don't mix buffers for different operations. Returns -1 if it can't register
   (too many preregistered regions already). If alloc==true it will allocate memory and fill
   the pointers (required for NVL SHARP and NSO/MNNVL)
*/
void allreduce_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream = 0);
// for the DP distributed optimizer; only non-SHARP multinode is implemented and calls must come
// in ordered pairs
void allgather_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream = 0);
void reducescatter_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream = 0);
void allreduce2_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream = 0);
// for TP-parallelism, only single node is implemented
void allgather2_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream = 0);
void allgather2_userbuff_inplace_sliced(const int handler, const int offset, const int elements,
communicator *comm, const int slice_id, const int nslices,
cudaStream_t stream = 0);
/*
  each rank's input is
  allgather2_userbuff_inplace: offset+myrank*elements
  allgather2_userbuff_inplace_sliced: offset+myrank*elements*nslices+slice_id*elements
  equivalent code would be:
  for(int slice=0;slice<nslices;slice++)
allgather2_userbuff_inplace_sliced(hndl,offset, elements,comm,slice,nslices,stream);
and
allgather2_userbuff_inplace(hndl,offset, elements*nslices,comm,stream);
*/
void reducescatter2_userbuff_inplace(const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream = 0);
void reducescatter2_userbuff(void *output, const int handler, const int offset, const int elements,
communicator *comm, cudaStream_t stream = 0);
void reducescatter2_userbuff_stridedoutput(void *output, const int handler, const int offset,
const int rowelements, const int colelements,
const int strideelements, communicator *comm,
cudaStream_t stream = 0);
/* everything should be 16-byte aligned (= aligned to 8 elements);
   output is strided: row starts are separated by strideelements */
/* inplace allreduce: works only with buffers registered by previous call. offset should be same
* for all peers */
// two matching pairs, intended to work as a push from the sender or a pull by the receiver
// either way the signal is a write by the sender, meaning:
// push model: data has arrived and is visible at the receiver (barrier enforced)
// pull model: data is ready to be pulled by the receiver (no barrier needed)
void userbuffers_send(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes, communicator *comm,
const int peer, cudaStream_t stream = 0);
void userbuffers_recv(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes, communicator *comm,
const int peer, cudaStream_t stream = 0);
// alltoall is split into send and recv to allow for overlap
// send kicks off sending data to the destination - invoke on the same stream as data generation
// recv returns once the data has been received
// send and recv can be on different streams
void userbuffers_alltoall_send(const int srchandler, const size_t srcoffset, const int dsthandler,
const size_t dstoffset, const size_t bytes, communicator *comm,
cudaStream_t stream = 0);
void userbuffers_alltoall_recv(communicator *comm, cudaStream_t stream = 0);
// void unregister_user_buffer(int handler);
void destroy_communicator(communicator *comm);
#endif // TRANSFORMER_ENGINE_USERBUFFERS_H_
......@@ -267,7 +267,7 @@ void Gemm(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque
nvte_cublas_gemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), null_tensor.data(),
null_tensor.data(), (desc.transa) ? CUBLAS_OP_T : CUBLAS_OP_N,
(desc.transb) ? CUBLAS_OP_T : CUBLAS_OP_N, false, wk_tensor.data(), false,
desc.use_split_accumulator, stream);
desc.use_split_accumulator, 0, stream);
}
void LayerNormForwardImpl(size_t n, size_t hidden, bool zero_centered_gamma, float eps, void *input,
......
......@@ -29,6 +29,9 @@ def fp8_gemm(
use_bias: bool = False,
use_split_accumulator: bool = False,
D_dtype: Optional[tex.DType] = None,
ub_algo: tex.UbufOverlapAlgo = None,
ub: Union[tex.UbufCommOverlap, tex.UbufP2PCommOverlap] = None,
extra_output_tensor: torch.Tensor = None,
) -> torch.Tensor:
"""TN layout GEMM with fp8 inputs."""
......@@ -55,7 +58,7 @@ def fp8_gemm(
out_dtype = TE_DType[out.dtype] if D_dtype is None else D_dtype
_ = torch.ops.tex_ts.te_gemm_ts(
args = (
A,
A_scale_inv,
A_fp8_tensor,
......@@ -77,8 +80,29 @@ def fp8_gemm(
workspace,
workspace.shape[0],
accumulate,
use_split_accumulator,
use_split_accumulator)
fn = torch.ops.tex_ts.te_gemm_ts
if ub_algo is not None:
assert ub is not None, 'ub object is None!'
if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
fn = ub.bulk_overlap
args = tuple(args + (1,))
elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
fn = ub.bulk_overlap
args = tuple(args + (0,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
fn = ub.split_overlap_ag
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(args + (extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS:
fn = ub.split_overlap_rs
assert (
extra_output_tensor is not None
), 'SPLIT_PIPELINED_RS requires extra output tensor'
args = tuple(args + (True, extra_output_tensor,))
_ = fn(*args)
if return_output:
if gelu:
......@@ -102,6 +126,9 @@ def gemm(
out: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
use_bias: bool = False,
ub_algo: tex.UbufOverlapAlgo = None,
ub: tex.UbufCommOverlap = None,
extra_output_tensor: torch.Tensor = None,
) -> Tuple[Union[torch.Tensor, None], ...]:
"""Non FP8 GEMM."""
......@@ -142,7 +169,7 @@ def gemm(
else:
bias_dtype = output_dtype
_ = torch.ops.tex_ts.te_gemm_ts(
args = (
A,
empty_tensor,
fp8_index,
......@@ -166,6 +193,28 @@ def gemm(
accumulate,
False, # use_split_accumulator
)
fn = torch.ops.tex_ts.te_gemm_ts
if ub_algo is not None:
assert ub is not None, 'ub object is None!'
if ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_AG:
fn = ub.bulk_overlap
args = tuple(args + (1,))
elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS:
fn = ub.bulk_overlap
args = tuple(args + (0,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG:
fn = ub.split_overlap_ag
extra_output_tensor = (
empty_tensor if extra_output_tensor is None else extra_output_tensor
)
args = tuple(args + (extra_output_tensor,))
elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS:
fn = ub.split_overlap_rs
assert (
extra_output_tensor is not None
), 'SPLIT_PIPELINED_RS requires extra output tensor'
args = tuple(args + (False, extra_output_tensor,))
_ = fn(*args)
if return_output:
return out, grad_bias, gelu_input
......@@ -283,9 +332,25 @@ def layernorm_fwd_fp8(
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
sm_margin: int,
zero_centered_gamma: bool
zero_centered_gamma: bool,
ln_out: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""LayerNorm with FP8 output"""
if ln_out is not None:
return tex.layernorm_fwd_fp8_noalloc(
inp,
weight,
bias,
eps,
fp8_meta_tensor.scale[fp8_tensor],
ln_out,
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
otype,
sm_margin,
zero_centered_gamma
)
return tex.layernorm_fwd_fp8(
inp,
weight,
......@@ -351,8 +416,20 @@ def cast_to_fp8(
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
) -> torch.Tensor:
out: Optional[torch.Tensor] = None,
) -> Optional[torch.Tensor]:
"""Cast input to FP8"""
if out is not None:
tex.cast_to_fp8_noalloc(
inp,
fp8_meta_tensor.scale[fp8_tensor],
out,
fp8_meta_tensor.amax_history[0][fp8_tensor],
fp8_meta_tensor.scale_inv[fp8_tensor],
otype
)
return None
return torch.ops.tex_ts.cast_to_fp8_ts(
inp,
fp8_meta_tensor.scale,
......
/*************************************************************************
* Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <stdio.h>
#include <stdlib.h>
#include <torch/cuda.h>
#include <torch/custom_class.h>
#include <torch/extension.h>
#include <torch/types.h>
#include <transformer_engine/userbuffers.h>
#define HALF_BYTES 2
#define CHECK_CUDA(call) \
do { \
cudaError_t status_ = call; \
if (status_ != cudaSuccess) { \
fprintf(stderr, "CUDA Error at line %d: %s\n", __LINE__, cudaGetErrorString(status_)); \
exit(1); \
} \
} while (0)
namespace ubuf {
enum class COMM_TYPE { RS = 0, AG = 1 };
enum class UBOverlapAlgo {
BULK_OVERLAP_AG = 0,
BULK_OVERLAP_RS = 1,
SPLIT_PIPELINED_AG = 2,
SPLIT_PIPELINED_RS = 3
};
struct UbufCommOverlap : torch::CustomClassHolder {
communicator *_ub_comm;
int _tp_id;
int _tp_size;
int _num_splits;
int _math_sms;
int _ub_reg;
void *_ubuf_ptr;
torch::Tensor _ubuf;
torch::Tensor output_tensor;
at::cuda::CUDAStream _stream_comm = at::cuda::getStreamFromPool(true);
std::vector<at::cuda::CUDAStream> _stream_compute;
cudaEvent_t _start_compute, _stop_compute, _start_d2dcopy, _start_comm, _stop_comm;
UbufCommOverlap(torch::Tensor sample, int rank, int tp_size, int num_comm_sm, int comm_cga_size,
int num_splits, bool set_sm_margin, int num_max_streams) {
// Initialize userbuf communicator
create_communicator_grouped2(&_ub_comm, 1, 1, tp_size, 1);
_ub_comm->use_ce = 0;
_ub_comm->sms = num_comm_sm;
_ub_comm->cga_size = comm_cga_size;
// Allocate and register extra userbuffers
int ubuf_bytes = sample.numel() * sample.element_size();
_ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
_ub_comm, true);
_ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options());
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
for (int i = 0; i < std::min(num_max_streams, num_splits); i++) {
cudaStream_t stream;
cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1);
_stream_compute.push_back(
at::cuda::getStreamFromExternal(stream, stream_main.device_index()));
}
_num_splits = num_splits;
_tp_size = tp_size;
_tp_id = (rank % tp_size);
// Set the number of SMs for GEMM with margin
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
_math_sms = (set_sm_margin) ? prop.multiProcessorCount - num_comm_sm : prop.multiProcessorCount;
output_tensor = torch::Tensor();
// CUDA event creation
cudaEventCreateWithFlags(&_start_compute, 0);
cudaEventCreateWithFlags(&_stop_compute, 0);
cudaEventCreateWithFlags(&_start_d2dcopy, 0);
cudaEventCreateWithFlags(&_start_comm, 0);
cudaEventCreateWithFlags(&_stop_comm, 0);
}
/*
** Bulk GEMM + COMM
** This function assumes the communication input is pre-copied to _ubuf
*/
std::vector<at::Tensor> bulk_overlap(
at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B, at::Tensor B_scale_inverse,
int64_t B_fp8_tensor, transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, at::Tensor bias,
transformer_engine::DType bias_type, at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator, int comm_type) {
// Get the current userbuf offset
char *ubuf_wt_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
int comm_elements = (_ubuf.numel() / 2) * _ubuf.element_size(); // UBUF uses 2Byte element size
COMM_TYPE _comm_type = static_cast<COMM_TYPE>(comm_type);
if (_comm_type == COMM_TYPE::RS) {
ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size();
}
// Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_comm, (cudaStream_t)stream_main));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
// Communication: AG and RS
if (_comm_type == COMM_TYPE::AG) {
allgather2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm, (cudaStream_t)_stream_comm);
} else if (_comm_type == COMM_TYPE::RS) {
reducescatter2_userbuff_inplace(_ub_reg, 0, comm_elements, _ub_comm,
(cudaStream_t)_stream_comm);
} else {
NVTE_ERROR("Not supported communication type.");
}
if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor];
if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor];
assert(pre_gelu_out.numel() == 0);
te_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, D, D_scale,
D_type, D_amax, bias, bias_type, pre_gelu_out, grad, workspace, workspaceSize,
accumulate, use_split_accumulator, _math_sms);
CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0));
// Generate output tensor from userbuf data pointer
int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size;
int output_c_dim1 = _ubuf.size(1);
output_tensor = torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options());
return {D, output_tensor};
} // bulk_overlap
/*
** Split FPROP GEMM + ReduceScatter
*/
void split_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax,
at::Tensor bias, transformer_engine::DType bias_type,
at::Tensor pre_gelu_out, bool grad, at::Tensor workspace,
size_t workspaceSize, bool accumulate, bool use_split_accumulator,
bool gemm_overlap, at::Tensor rs_output) {
// Get GEMM dimensions
int m = A.size(0);
int k = A.size(1);
int n = B.size(0);
int m_chunk = m / _num_splits;
int input_a_chunk_size = m_chunk * k;
int output_chunk_size = n * m_chunk;
int workspace_size_chunk = workspaceSize / _stream_compute.size();
// Get input, output, and workspace data pointers
char *input_a_chunk_ptr = reinterpret_cast<char *>(A.data_ptr());
char *output_buf_chunk_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
char *rs_output_ptr = reinterpret_cast<char *>(rs_output.data_ptr());
int ubuf_offset = 0;
// Catch up the default torch stream
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, stream_main));
for (int i = 0; i < _stream_compute.size(); i++) {
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[i], _start_compute, 0));
}
if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor];
if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor];
assert(pre_gelu_out.numel() == 0);
if (gemm_overlap) {
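// Pipelined path: the GEMM for chunk i runs on a round-robin compute stream while the
// reduce-scatter of chunk i-1 runs on _stream_comm, so communication of the previous chunk
// hides behind the current chunk's GEMM.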
torch::Tensor input_a_chunk = torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options());
torch::Tensor output_chunk =
torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options());
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options());
at::cuda::setCurrentCUDAStream(_stream_compute[0]);
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb,
output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, _math_sms);
for (int i = 1; i < _num_splits; i++) {
input_a_chunk_ptr += input_a_chunk_size * B.element_size();
output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size();
torch::Tensor input_a_chunk =
torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options());
torch::Tensor output_chunk =
torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options());
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
{workspace_size_chunk}, workspace.options());
at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb,
output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms);
CHECK_CUDA(cudaEventRecord(
_start_comm, (cudaStream_t)_stream_compute[(i - 1) % _stream_compute.size()]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
// Communication chunk
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, (i - 1) * output_chunk_size,
m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm);
rs_output_ptr += m_chunk * _ubuf.element_size();
}
int last_compute_stream_id =
(_num_splits + _stream_compute.size() - 1) % _stream_compute.size();
CHECK_CUDA(
cudaEventRecord(_start_comm, (cudaStream_t)_stream_compute[last_compute_stream_id]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
// Communication chunk
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg,
(_num_splits - 1) * output_chunk_size, m_chunk, n, m,
_ub_comm, (cudaStream_t)_stream_comm);
} else {
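// Non-pipelined path: each chunk's reduce-scatter is issued on _stream_comm as soon as that
// chunk's GEMM completes on its compute stream.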
for (int i = 0; i < _num_splits; i++) {
torch::Tensor input_a_chunk =
torch::from_blob(input_a_chunk_ptr, {m_chunk, k}, A.options());
torch::Tensor output_chunk =
torch::from_blob(output_buf_chunk_ptr, {n, m_chunk}, _ubuf.options());
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
{workspace_size_chunk}, workspace.options());
at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
te_gemm(input_a_chunk, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb,
output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms);
CHECK_CUDA(cudaEventRecord(_start_comm,
(cudaStream_t)_stream_compute[i % _stream_compute.size()]));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_comm, 0));
// Communication chunk
reducescatter2_userbuff_stridedoutput(rs_output_ptr, _ub_reg, i * output_chunk_size,
m_chunk, n, m, _ub_comm, (cudaStream_t)_stream_comm);
rs_output_ptr += m_chunk * _ubuf.element_size();
input_a_chunk_ptr += input_a_chunk_size * B.element_size();
output_buf_chunk_ptr += output_chunk_size * _ubuf.element_size();
}
}
int last_compute_stream_id =
(_num_splits + _stream_compute.size() - 1) % _stream_compute.size();
CHECK_CUDA(
cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id]));
CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_comm, 0));
at::cuda::setCurrentCUDAStream(stream_main);
return;
} // split_overlap_rs
/*
** Helper function to copy input to _ubuf
*/
void copy_input_to_ubuf(torch::Tensor input, int comm_type) {
char *ubuf_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
COMM_TYPE _comm_type = static_cast<COMM_TYPE>(comm_type);
if (_comm_type == COMM_TYPE::AG) {
if ((input.numel() * _tp_size) != _ubuf.numel() ||
input.element_size() != _ubuf.element_size()) {
NVTE_ERROR("input and ubuf size do not match!");
}
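// For AG, each rank stages its local shard into its own (_tp_id-th) slot of the full buffer.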
ubuf_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size();
} else {
if (input.numel() != _ubuf.numel() || input.element_size() != _ubuf.element_size()) {
NVTE_ERROR("input and ubuf size do not match!");
}
}
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_d2dcopy, (cudaStream_t)stream_main));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_d2dcopy, 0));
CHECK_CUDA(cudaMemcpyAsync(ubuf_ptr, input.data_ptr(), input.numel() * input.element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm));
}
torch::Tensor &get_ubuf_output(int comm_type) {
char *ubuf_wt_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
COMM_TYPE _comm_type = static_cast<COMM_TYPE>(comm_type);
if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS) NVTE_ERROR("Invalid comm_type");
if (_comm_type == COMM_TYPE::RS)
ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size();
int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size;
int output_c_dim1 = _ubuf.size(1);
output_tensor = torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options());
return output_tensor;
}
}; // UbufCommOverlap
struct UbufP2PCommOverlap : torch::CustomClassHolder {
communicator *_ub_comm;
int _tp_id;
int _tp_size;
int _ub_reg;
int _next_rank, _prev_rank, _rank, _rank_round_tp;
int _aggregate2;
int _math_sms;
void *_ubuf_ptr;
torch::Tensor _ubuf;
std::vector<torch::Tensor> _ubufs;
at::cuda::CUDAStream _stream_comm = at::cuda::getStreamFromPool(true);
std::vector<at::cuda::CUDAStream> _stream_compute;
cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _start_accum, _stop_accum;
UbufP2PCommOverlap(torch::Tensor sample, int rank, int tp_size, bool aggregate2,
int num_max_streams) {
// Initialize userbuf communicator
create_communicator_grouped2(&_ub_comm, 1, 1, tp_size, 1);
_ub_comm->use_ce = 1;
_ub_comm->sms = 1;
_ub_comm->cga_size = 1;
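// The settings above keep the comm kernels' SM/CGA footprint minimal; with use_ce enabled the
// P2P ring exchanges are expected to go through the copy engines rather than the SMs.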
// Create workspace tensor with userbuffer
int ubuf_bytes = sample.numel() * sample.element_size();
int ubuf_chunk_bytes = ubuf_bytes / tp_size;
_ub_reg = register_user_buffer_collective(reinterpret_cast<void **>(&_ubuf_ptr), ubuf_bytes,
_ub_comm, true);
_ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options());
// Create tensor chunks for easy management
char *ubuf_byte_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
for (int i = 0; i < tp_size; i++) {
torch::Tensor ubuf_chunk = torch::from_blob(
ubuf_byte_ptr, {sample.size(0) / tp_size, sample.size(1)}, sample.options());
_ubufs.push_back(ubuf_chunk);
ubuf_byte_ptr += ubuf_chunk_bytes;
}
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
for (int i = 0; i < std::min(num_max_streams, tp_size); i++) {
cudaStream_t stream;
cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, -1);
_stream_compute.push_back(
at::cuda::getStreamFromExternal(stream, stream_main.device_index()));
}
// Use all device SMs for GEMM (no SM margin is reserved in the P2P variant)
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
_math_sms = prop.multiProcessorCount;
_tp_size = tp_size;
_aggregate2 = aggregate2;
_rank = rank;
_tp_id = (rank % tp_size);
_rank_round_tp = (rank / tp_size) * tp_size;
_next_rank = (tp_size + rank + 1) % tp_size + _rank_round_tp;
_prev_rank = (tp_size + rank - 1) % tp_size + _rank_round_tp;
// CUDA event creation
cudaEventCreateWithFlags(&_start_compute, 0);
cudaEventCreateWithFlags(&_stop_compute, 0);
cudaEventCreateWithFlags(&_start_comm, 0);
cudaEventCreateWithFlags(&_stop_comm, 0);
cudaEventCreateWithFlags(&_start_accum, 0);
cudaEventCreateWithFlags(&_stop_accum, 0);
}
/*
** Split AllGather + GEMM using P2P communication
** This function assumes input_b has been pre-copied to _ubufs[rank_id]. This is needed so that
** the AG output on each rank ends up in contiguous memory after all ring-exchange phases.
*/
torch::Tensor split_overlap_ag(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor,
transformer_engine::DType A_type, bool transa, at::Tensor B,
at::Tensor B_scale_inverse, int64_t B_fp8_tensor,
transformer_engine::DType B_type, bool transb, at::Tensor D,
at::Tensor D_scale, transformer_engine::DType D_type,
at::Tensor D_amax, at::Tensor bias,
transformer_engine::DType bias_type, at::Tensor pre_gelu_out,
bool grad, at::Tensor workspace, size_t workspaceSize,
bool accumulate, bool use_split_accumulator, at::Tensor B_copy) {
// Get GEMM dimensions between TN and NN input layouts
const int m = (transa) ? A.size(0) : A.size(1);
const int k = (transa) ? A.size(1) : A.size(0);
const int n_chunk = _ubufs[0].size(0);
// Get communication and GEMM output chunk sizes
const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size();
const int output_chunk_bytes = (n_chunk * m) * HALF_BYTES;
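// HALF_BYTES: the GEMM output D is assumed to use 2-byte (fp16/bf16) elements.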
// Get output and workspace data pointers
char *output_ptr = reinterpret_cast<char *>(D.data_ptr());
char *workspace_ptr = reinterpret_cast<char *>(workspace.data_ptr());
int workspace_size_chunk = workspaceSize / _stream_compute.size();
if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor];
if (B_scale_inverse.numel()) B_scale_inverse = B_scale_inverse[B_fp8_tensor];
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main));
assert(pre_gelu_out.numel() == 0);
if (_aggregate2) {
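// Aggregated path: after an initial neighbor exchange, two input chunks move per ring step,
// halving the number of steps (and GEMM launches) relative to the per-chunk path below.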
// Catch up the default torch stream
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0));
const int num_steps = _tp_size / 2;
char *input_b_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
// Initial 1X input chunk exchange between neighboring peers
int send_chunk_id = _tp_id;
int recv_chunk_id = (_tp_id % 2 == 0) ? _tp_id + 1 : _tp_id - 1;
int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id;
int peer_rank = (_tp_id % 2 == 0) ? _next_rank : _prev_rank;
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm, peer_rank,
(cudaStream_t)_stream_comm);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm, peer_rank,
(cudaStream_t)_stream_comm);
CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)_stream_comm));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0));
int local_rank_round2 = (_tp_id % 2 == 0) ? _tp_id : _tp_id - 1;
const int next_rank = (_tp_size + _tp_id + 2) % _tp_size + _rank_round_tp;
const int prev_rank = (_tp_size + _tp_id - 2) % _tp_size + _rank_round_tp;
// Ring exchange of 2X inputs chunks
for (int i = 0; i < num_steps; i++) {
send_chunk_id = (_tp_size + local_rank_round2 - i * 2) % _tp_size;
recv_chunk_id = (_tp_size + local_rank_round2 - i * 2 - 2) % _tp_size;
send_offset = comm_bytes * send_chunk_id;
recv_offset = comm_bytes * recv_chunk_id;
// GEMM
torch::Tensor input_b_chunk =
torch::from_blob(input_b_ptr + send_offset, {n_chunk * 2, k}, _ubuf.options());
torch::Tensor output_chunk = torch::from_blob(
output_ptr + (send_chunk_id * output_chunk_bytes), {n_chunk * 2, m}, D.options());
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
{workspace_size_chunk}, workspace.options());
at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
te_gemm(A, A_scale_inverse, A_type, transa, input_b_chunk, B_scale_inverse, B_type, transb,
output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms);
if (i < num_steps - 1) {
// P2P communication
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes * 2, _ub_comm,
next_rank, (cudaStream_t)_stream_comm);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes * 2, _ub_comm,
prev_rank, (cudaStream_t)_stream_comm);
CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm));
CHECK_CUDA(cudaStreamWaitEvent(
(cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_comm, 0));
} else if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm));
}
}
at::cuda::setCurrentCUDAStream(stream_main);
int last_compute_stream_id =
(num_steps + _stream_compute.size() - 1) % _stream_compute.size();
CHECK_CUDA(
cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id]));
} else {
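// Per-chunk path: one input chunk is exchanged per ring step; the exchange of the next chunk
// overlaps with the GEMM on the chunk this rank already holds.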
// Catch up the default torch stream
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _start_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_compute[0], _start_compute, 0));
for (int i = 0; i < _tp_size; i++) {
// Set the userbuffer id. The buffer being sent is also the input for the current GEMM chunk.
// The initial input chunk is stored in _ubufs[_tp_id] so that the AG output is contiguous
// on every rank after the ring exchanges.
int send_chunk_id = (_tp_size + _tp_id - i) % _tp_size;
int recv_chunk_id = (_tp_size + _tp_id - i - 1) % _tp_size;
int send_offset = comm_bytes * send_chunk_id;
int recv_offset = comm_bytes * recv_chunk_id;
// GEMM
torch::Tensor output_chunk = torch::from_blob(
output_ptr + (send_chunk_id * output_chunk_bytes), {n_chunk, m}, D.options());
torch::Tensor workspace_chunk =
torch::from_blob(workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk,
{workspace_size_chunk}, workspace.options());
at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]);
te_gemm(A, A_scale_inverse, A_type, transa, _ubufs[send_chunk_id], B_scale_inverse, B_type,
transb, output_chunk, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad,
workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator,
_math_sms);
if (i < _tp_size - 1) {
// P2P communication
userbuffers_send(_ub_reg, send_offset, _ub_reg, send_offset, comm_bytes, _ub_comm,
_next_rank, (cudaStream_t)_stream_comm);
userbuffers_recv(_ub_reg, recv_offset, _ub_reg, recv_offset, comm_bytes, _ub_comm,
_prev_rank, (cudaStream_t)_stream_comm);
CHECK_CUDA(cudaEventRecord(_stop_comm, (cudaStream_t)_stream_comm));
CHECK_CUDA(cudaStreamWaitEvent(
(cudaStream_t)_stream_compute[(i + 1) % _stream_compute.size()], _stop_comm, 0));
} else if (B_copy.numel() > 0) {
assert(B_copy.numel() == _ubufs[_tp_id].numel());
assert(B_copy.element_size() == _ubufs[_tp_id].element_size());
CHECK_CUDA(cudaMemcpyAsync(B_copy.data_ptr(), _ubufs[_tp_id].data_ptr(),
_ubufs[_tp_id].numel() * _ubufs[_tp_id].element_size(),
cudaMemcpyDeviceToDevice, (cudaStream_t)_stream_comm));
}
}
at::cuda::setCurrentCUDAStream(stream_main);
int last_compute_stream_id = (_tp_size + _stream_compute.size() - 1) % _stream_compute.size();
CHECK_CUDA(
cudaEventRecord(_stop_compute, (cudaStream_t)_stream_compute[last_compute_stream_id]));
}
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0));
CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_comm, _stop_compute, 0));
return D;
} // split_overlap_ag
/*
** Copy input to _ubufs[0]
*/
void copy_input_to_ubuf(torch::Tensor input, bool chunk) {
at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream();
if (chunk) {
// Copy input to the target ubuf chunk by rank offset
if (input.numel() != _ubufs[0].numel() || input.element_size() != _ubufs[0].element_size()) {
NVTE_ERROR("input and ubuf size do not match!");
}
CHECK_CUDA(cudaMemcpyAsync(_ubufs[_tp_id].data_ptr(), input.data_ptr(),
input.numel() * input.element_size(), cudaMemcpyDeviceToDevice,
(cudaStream_t)stream_main));
} else {
if (input.numel() != _ubuf.numel() || input.element_size() != _ubuf.element_size()) {
NVTE_ERROR("input and ubuf size do not match!");
}
CHECK_CUDA(cudaMemcpyAsync(_ubuf.data_ptr(), input.data_ptr(),
input.numel() * input.element_size(), cudaMemcpyDeviceToDevice,
(cudaStream_t)stream_main));
}
}
torch::Tensor get_ubuf_output(int comm_type) {
char *ubuf_wt_ptr = reinterpret_cast<char *>(_ubuf.data_ptr());
COMM_TYPE _comm_type = static_cast<COMM_TYPE>(comm_type);
if (_comm_type != COMM_TYPE::AG && _comm_type != COMM_TYPE::RS) NVTE_ERROR("Invalid comm_type");
if (_comm_type == COMM_TYPE::RS)
ubuf_wt_ptr += _ubuf.numel() / _tp_size * _tp_id * _ubuf.element_size();
int output_c_dim0 = (_comm_type == COMM_TYPE::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size;
int output_c_dim1 = _ubuf.size(1);
return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options());
}
}; // UbufP2PCommOverlap
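// A minimal usage sketch (illustrative only; the argument values and tensor variables below are
// assumptions, not taken from this file): register a sample buffer sized for the full AG output,
// stage this rank's chunk, then run the overlapped AG + GEMM.
//
//   UbufP2PCommOverlap ub(sample, rank, tp_size, /*aggregate2=*/true, /*num_max_streams=*/3);
//   ub.copy_input_to_ubuf(local_input_chunk, /*chunk=*/true);
//   at::Tensor out = ub.split_overlap_ag(
//       A, A_scale_inv, 0, A_type, /*transa=*/true, B, B_scale_inv, 0, B_type, /*transb=*/false,
//       D, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, /*grad=*/false,
//       workspace, workspace_size, /*accumulate=*/false, /*use_split_accumulator=*/false, B_copy);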
} // namespace ubuf
@@ -5,7 +5,9 @@
************************************************************************/
#include "extensions.h"
#ifdef NVTE_MPI_FOUND
#include "comm_gemm_overlap.h"
#endif // NVTE_MPI_FOUND
void te_gemm(at::Tensor A,
at::Tensor A_scale_inverse,
@@ -26,7 +28,8 @@ void te_gemm(at::Tensor A,
at::Tensor workspace,
size_t workspaceSize,
bool accumulate,
bool use_split_accumulator
bool use_split_accumulator,
int math_sm_count
) {
using namespace transformer_engine;
auto te_A = makeTransformerEngineTensor(A.data_ptr(),
......@@ -70,6 +73,7 @@ void te_gemm(at::Tensor A,
te_workspace.data(),
accumulate,
use_split_accumulator,
math_sm_count,
at::cuda::getCurrentCUDAStream());
}
@@ -536,6 +540,67 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
}
std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps,
at::Tensor scale,
at::Tensor ln_out,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1));
DType itype = GetTransformerEngineDType(input.scalar_type());
auto mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto input_cu = makeTransformerEngineTensor(input);
auto gamma_cu = makeTransformerEngineTensor(weight);
auto beta_cu = makeTransformerEngineTensor(bias);
auto z_cu = makeTransformerEngineTensor(ln_out.data_ptr(), {N, H}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
auto mu_cu = makeTransformerEngineTensor(mu);
auto rsigma_cu = makeTransformerEngineTensor(rsigma);
transformer_engine::TensorWrapper workspace, barrier;
// This call populates workspace and barrier tensors with the required config
const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
// Fill workspace and barrier
auto workspace_data = allocateSpace(workspace.shape(),
workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(),
barrier.dtype(),
true);
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
barrier = makeTransformerEngineTensor(barrier_data.data_ptr(),
barrier.shape(),
barrier.dtype());
// Actual call to fwd kernel
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
return {ln_out, mu, rsigma};
}
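// Note: the *_noalloc variants write into a caller-provided ln_out buffer instead of allocating
// one, which allows the caller (for example, the userbuffer overlap path) to direct the
// LayerNorm output into a pre-registered buffer.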
at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
@@ -609,6 +674,61 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
}
std::vector<at::Tensor> layernorm_fwd_noalloc(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
at::Tensor ln_out,
float eps,
const int sm_margin,
const bool zero_centered_gamma
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1));
DType itype = GetTransformerEngineDType(input.scalar_type());
auto mu = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto rsigma = at::empty({static_cast<int64_t>(N)}, at::CUDA(at::kFloat));
auto input_cu = makeTransformerEngineTensor(input);
auto gamma_cu = makeTransformerEngineTensor(weight);
auto beta_cu = makeTransformerEngineTensor(bias);
auto z_cu = makeTransformerEngineTensor(ln_out);
auto mu_cu = makeTransformerEngineTensor(mu);
auto rsigma_cu = makeTransformerEngineTensor(rsigma);
transformer_engine::TensorWrapper workspace, barrier;
// This call populates workspace and barrier tensors with the required config
const auto func = zero_centered_gamma ? nvte_layernorm1p_fwd : nvte_layernorm_fwd;
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
// Fill workspace and barrier
auto workspace_data = allocateSpace(workspace.shape(),
workspace.dtype());
auto barrier_data = allocateSpace(barrier.shape(),
barrier.dtype(),
true);
workspace = makeTransformerEngineTensor(workspace_data.data_ptr(),
workspace.shape(),
workspace.dtype());
barrier = makeTransformerEngineTensor(barrier_data.data_ptr(),
barrier.shape(),
barrier.dtype());
// Actual call to fwd kernel
func(input_cu.data(), gamma_cu.data(), beta_cu.data(), eps, z_cu.data(),
mu_cu.data(), rsigma_cu.data(), at::cuda::getCurrentCUDAStream(),
at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin,
workspace.data(), barrier.data());
return {ln_out, mu, rsigma};
}
at::Tensor layernorm_fwd_inf(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
@@ -646,6 +766,29 @@ at::Tensor cast_to_fp8(const at::Tensor &input,
}
void cast_to_fp8_noalloc(const at::Tensor &input,
const at::Tensor &scale,
at::Tensor output,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;
size_t N = static_cast<size_t>(input.size(0));
size_t H = static_cast<size_t>(input.size(1));
auto input_cu = makeTransformerEngineTensor(input);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {N, H}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());
nvte_fp8_quantize(input_cu.data(), output_cu.data(),
at::cuda::getCurrentCUDAStream());
return;
}
at::Tensor cast_from_fp8(const at::Tensor &input,
const at::Tensor &scale_inv,
transformer_engine::DType itype,
@@ -878,6 +1021,17 @@ size_t get_cublasLt_version() {
}
bool userbuf_comm_available() { // TODO(ksivamani) check on python side
#ifdef NVTE_MPI_FOUND
return true;
#else
return false;
#endif
}
void placeholder() {} // TODO(ksivamani) clean this up
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Softmax functions
m.def("scaled_softmax_forward", &scaled_softmax_forward, "Scaled Softmax FWD");
@@ -895,8 +1049,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Other granular functions
m.def("layernorm_fwd_fp8", &layernorm_fwd_fp8, "LN FWD FP8");
m.def("layernorm_fwd_fp8_noalloc", &layernorm_fwd_fp8_noalloc, "LN FWD FP8");
m.def("layernorm_bwd", &layernorm_bwd, "LN BWD");
m.def("layernorm_fwd", &layernorm_fwd, "LN FWD");
m.def("layernorm_fwd_noalloc", &layernorm_fwd_noalloc, "LN FWD");
m.def("fused_cast_transpose", &fused_cast_transpose, "Fused Cast + Transpose");
m.def("fused_cast_transpose_bgrad", &fused_cast_transpose_bgrad,
"Fused Cast + Transpose + BGRAD");
@@ -907,6 +1063,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose,
"Fused Multi-tensor Cast + Transpose");
m.def("cast_to_fp8", &cast_to_fp8, "Cast to FP8");
m.def("cast_to_fp8_noalloc", &cast_to_fp8_noalloc, "Cast to FP8");
m.def("cast_from_fp8", &cast_from_fp8, "Cast from FP8");
m.def("te_gemm", &te_gemm, "CublasLt GEMM");
m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O");
@@ -914,6 +1071,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Misc
m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version");
m.def("userbuf_comm_available", &userbuf_comm_available, "If userbuf backend is available");
// Data structures
py::class_<transformer_engine::FP8TensorMeta>(m, "FP8TensorMeta")
@@ -922,6 +1080,31 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
.def_readwrite("scale_inv", &transformer_engine::FP8TensorMeta::scale_inv)
.def_readwrite("amax_history", &transformer_engine::FP8TensorMeta::amax_history);
#ifdef NVTE_MPI_FOUND
py::enum_<ubuf::UBOverlapAlgo>(m, "UbufOverlapAlgo")
.value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG)
.value("BULK_OVERLAP_RS", ubuf::UBOverlapAlgo::BULK_OVERLAP_RS)
.value("SPLIT_PIPELINED_RS", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS)
.value("SPLIT_PIPELINED_AG", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG);
py::class_<ubuf::UbufCommOverlap>(m, "UbufCommOverlap")
.def(py::init<torch::Tensor&, int, int, int, int, int, bool, int>())
.def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap)
.def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs)
.def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf)
.def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output);
py::class_<ubuf::UbufP2PCommOverlap>(m, "UbufP2PCommOverlap")
.def(py::init<torch::Tensor&, int, int, bool, int>())
.def("split_overlap_ag", &ubuf::UbufP2PCommOverlap::split_overlap_ag)
.def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf)
.def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output);
#else // NVTE_MPI_FOUND
m.def("UbufOverlapAlgo", &placeholder, "Dummy function for python side annotations");
m.def("UbufCommOverlap", &placeholder, "Dummy function for python side annotations");
m.def("UbufP2PCommOverlap", &placeholder, "Dummy function for python side annotations");
#endif // NVTE_MPI_FOUND
py::enum_<transformer_engine::DType>(m, "DType", py::module_local())
.value("kByte", transformer_engine::DType::kByte)
.value("kInt32", transformer_engine::DType::kInt32)
@@ -26,7 +26,8 @@ void te_gemm(at::Tensor A,
at::Tensor workspace,
size_t workspaceSize,
bool accumulate,
bool use_split_accumulator
bool use_split_accumulator,
int math_sm_count
);
@@ -111,6 +112,19 @@ std::vector<at::Tensor> layernorm_fwd_fp8(const at::Tensor &input,
const bool zero_centered_gamma
);
std::vector<at::Tensor> layernorm_fwd_fp8_noalloc(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
float eps,
at::Tensor scale,
at::Tensor ln_out,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype,
const int sm_margin,
const bool zero_centered_gamma
);
at::Tensor layernorm_fwd_fp8_inf(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
@@ -130,6 +144,15 @@ std::vector<at::Tensor> layernorm_fwd(const at::Tensor &input,
const bool zero_centered_gamma
);
std::vector<at::Tensor> layernorm_fwd_noalloc(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
at::Tensor ln_out,
float eps,
const int sm_margin,
const bool zero_centered_gamma
);
at::Tensor layernorm_fwd_inf(const at::Tensor &input,
const at::Tensor &weight,
const at::Tensor &bias,
@@ -145,6 +168,15 @@ at::Tensor cast_from_fp8(const at::Tensor &input,
);
void cast_to_fp8_noalloc(const at::Tensor &input,
const at::Tensor &scale,
at::Tensor output,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
);
at::Tensor cast_from_fp8(const at::Tensor &input,
const at::Tensor &scale_inv,
transformer_engine::DType itype,