Unverified Commit f9dd37f7 authored by Alp Dener, committed by GitHub

[PyTorch] Runtime lookup for CUDA Driver API calls in Userbuffers (#970)



* removed libcuda.so link at compile time for TE/PyTorch extension
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* linting fixes
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* updated get_symbol() in TE/common/cuda_utils.h to new impl based on cudaGetDriverEntryPoint
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix duplicate quotation
Signed-off-by: Alp Dener <adener@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Alp Dener <adener@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 687697a7
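
The core of the change: instead of resolving CUDA driver symbols through a link-time dependency on libcuda.so, the extension now asks the CUDA runtime for driver entry points at runtime via cudaGetDriverEntryPoint. Below is a minimal standalone sketch of that mechanism; it is illustrative demo code, not the TE implementation, and it builds with nvcc without -lcuda (the driver library is still needed at runtime, supplied by the installed driver):

// demo.cu -- resolve a driver API call at runtime instead of at link time.
// Build: nvcc demo.cu -o demo   (note: no -lcuda)
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

int main() {
  void *fn = nullptr;
  cudaDriverEntryPointQueryResult status;
  // Ask the CUDA runtime to look up the driver symbol for us.
  cudaError_t err = cudaGetDriverEntryPoint("cuDriverGetVersion", &fn,
                                            cudaEnableDefault, &status);
  if (err != cudaSuccess || status != cudaDriverEntryPointSuccess) {
    fprintf(stderr, "could not resolve cuDriverGetVersion\n");
    return EXIT_FAILURE;
  }
  // Cast the generic pointer to the driver signature and call through it.
  auto cuDriverGetVersion_fp = reinterpret_cast<CUresult (*)(int *)>(fn);
  int version = 0;
  if (cuDriverGetVersion_fp(&version) != CUDA_SUCCESS) return EXIT_FAILURE;
  printf("CUDA driver version: %d\n", version);
  return EXIT_SUCCESS;
}

This is why the setup.py hunk below can drop the libcuda library_dirs/libraries entries entirely: no driver symbol is referenced at link time anymore.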
@@ -74,10 +74,9 @@ def setup_pytorch_extension(
     if version >= (11, 8):
         nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"])
-    # Libraries -- PyTorch CUDAExtension links to libcudart.so but not to libcuda.so
-    cuda_home, _ = cuda_path()
-    library_dirs = [cuda_home / "compat" / "lib"]
-    libraries = ["cuda"]
+    # Libraries
+    library_dirs = []
+    libraries = []
     if os.getenv("UB_MPI_BOOTSTRAP"):
         assert (
             os.getenv("MPI_HOME") is not None
......
@@ -93,7 +93,14 @@ Library &cuda_driver_lib() {
 namespace cuda_driver {

-void *get_symbol(const char *symbol) { return cuda_driver_lib().get_symbol(symbol); }
+void *get_symbol(const char *symbol) {
+  void *entry_point;
+  cudaDriverEntryPointQueryResult driver_result;
+  NVTE_CHECK_CUDA(cudaGetDriverEntryPoint(symbol, &entry_point, cudaEnableDefault, &driver_result));
+  NVTE_CHECK(driver_result == cudaDriverEntryPointSuccess,
+             "Could not find CUDA driver entry point for ", symbol);
+  return entry_point;
+}

 }  // namespace cuda_driver
......
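
NVTE_CALL_CHECK_CUDA_DRIVER, used throughout the Userbuffers hunks below, is not defined in this diff. Presumably it resolves the named symbol through get_symbol(), casts the returned void * to the matching function-pointer type, invokes it, and checks the CUresult; the sketch below shows one possible shape under that assumption (illustrative names only, not the real macro):

// Hedged sketch of a call-through-pointer helper in the spirit of
// NVTE_CALL_CHECK_CUDA_DRIVER; the real definition lives elsewhere in TE.
#include <cuda.h>
#include <cuda_runtime.h>
#include <stdexcept>
#include <string>

// Same mechanism as the get_symbol() above, minus the NVTE_* error macros.
inline void *get_symbol(const char *symbol) {
  void *entry_point = nullptr;
  cudaDriverEntryPointQueryResult driver_result;
  if (cudaGetDriverEntryPoint(symbol, &entry_point, cudaEnableDefault, &driver_result) !=
          cudaSuccess ||
      driver_result != cudaDriverEntryPointSuccess)
    throw std::runtime_error(std::string("could not resolve driver symbol ") + symbol);
  return entry_point;
}

// Resolve, cast, call, check. A real implementation would cache the lookup.
template <typename... ArgTs>
void call_driver_checked(const char *name, ArgTs... args) {
  auto fn = reinterpret_cast<CUresult (*)(ArgTs...)>(get_symbol(name));
  CUresult ret = fn(args...);
  if (ret != CUDA_SUCCESS) {
    const char *desc = nullptr;
    // Even the error-string helper must be looked up at runtime.
    reinterpret_cast<CUresult (*)(CUresult, const char **)>(get_symbol("cuGetErrorString"))(
        ret, &desc);
    throw std::runtime_error(std::string("CUDA driver call ") + name + " failed: " +
                             (desc ? desc : "unknown"));
  }
}

A dispatcher along these lines would also explain the explicit (size_t)0, (uint64_t)0, and static_cast<...> arguments at the call sites below: once the function-pointer type is derived rather than taken from a declared prototype, every argument must match the driver signature exactly.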
@@ -38,6 +38,8 @@ from transformer_engine.pytorch.module import LayerNormMLP
 from transformer_engine.pytorch.module import LayerNorm
 from transformer_engine.pytorch.module import RMSNorm
 from transformer_engine.pytorch.module import GroupedLinear
+from transformer_engine.pytorch.module import initialize_ub
+from transformer_engine.pytorch.module import destroy_ub
 from transformer_engine.pytorch.attention import DotProductAttention
 from transformer_engine.pytorch.attention import InferenceParams
 from transformer_engine.pytorch.attention import MultiheadAttention
......
@@ -16,7 +16,10 @@
 #include <chrono>
 #include <iostream>
 #include <map>
+#include <utility>
+
+#include "../util/cuda_driver.h"
 #include "ipcsocket.h"
 #include "userbuffers.h"
@@ -44,17 +47,6 @@ int stringCmp(const void *a, const void *b) { return strcmp((const char *)a, (const char *)b); }
     }                                                                  \
   } while (0)

-#define CUCHECK(cmd)                                                   \
-  do {                                                                 \
-    CUresult retval = cmd;                                             \
-    if (retval != CUDA_SUCCESS) {                                      \
-      const char *error_string;                                        \
-      cuGetErrorString(retval, &error_string);                         \
-      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, error_string); \
-      exit(EXIT_FAILURE);                                              \
-    }                                                                  \
-  } while (0);
-
 #define NVTE_UB_ERROR(x)                                               \
   do {                                                                 \
     throw std::runtime_error(std::string(__FILE__ ":") + std::to_string(__LINE__) + \
@@ -96,7 +88,7 @@ int create_communicator_grouped2(
     int numnodes, std::function<void(void **, void *, size_t, ExtComm)> ext_alloc_copy_allgather,
     std::function<void(ExtComm)> ext_barrier, std::function<void(void *)> ext_free, int pipegpus,
     int pipenodes, int tensorgpus, int tensornodes) {
-  *comm = reinterpret_cast<communicator *>(malloc(sizeof(communicator)));
+  *comm = new communicator();
   (*comm)->comm_world = EXT_COMM_WORLD;
   (*comm)->_alloc_copy_allgather = ext_alloc_copy_allgather;
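
The switch from malloc to new here is not cosmetic: communicator holds non-trivially-constructible members, and the very next lines assign std::function callbacks into it. Writing into a std::function that lives in raw malloc'd storage is undefined behavior because its constructor never ran. A minimal sketch of the distinction, using a hypothetical stand-in struct:

// Why `new communicator()` instead of malloc(): a minimal, hypothetical sketch.
#include <functional>

struct communicator_like {             // stand-in for the real communicator
  std::function<void(void *)> _free;   // non-trivial member: needs its ctor run
};

int main() {
  // Broken (undefined behavior): malloc returns raw bytes, no constructor runs,
  // so the assignment writes into an unconstructed std::function:
  //   auto *bad = (communicator_like *)malloc(sizeof(communicator_like));
  //   bad->_free = [](void *) {};

  // Correct: new runs the constructors, so the member is safe to assign.
  communicator_like *good = new communicator_like();
  good->_free = [](void *) {};
  delete good;  // pairs with `delete comm;` later in this diff
  return 0;
}

The matching cleanup change appears in destroy_communicator below, where free(comm) becomes delete comm.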
@@ -211,7 +203,9 @@ int create_communicator_grouped2(
     mcProp.size = (*comm)->mc_maxsize;
     mcProp.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;
-    CUCHECK(cuMulticastGetGranularity(&gran, &mcProp, CU_MULTICAST_GRANULARITY_RECOMMENDED));
+    NVTE_CALL_CHECK_CUDA_DRIVER(
+        cuMulticastGetGranularity, &gran, &mcProp,
+        static_cast<CUmemAllocationGranularity_flags>(CU_MULTICAST_GRANULARITY_RECOMMENDED));
     mc_maxsize = ((mc_maxsize + gran - 1) / gran) * gran;
     mcProp.size = mc_maxsize;
     (*comm)->mc_maxsize = mc_maxsize;
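
The idiom ((mc_maxsize + gran - 1) / gran) * gran rounds the requested multicast size up to the next multiple of the granularity reported by the driver. Isolated as a small helper with worked values (hypothetical numbers, assuming a 4 KiB granularity purely for illustration):

// round_up.cpp -- hypothetical helper isolating the rounding idiom above.
#include <cassert>
#include <cstddef>

size_t round_up(size_t size, size_t gran) {
  return ((size + gran - 1) / gran) * gran;  // next multiple of gran, >= size
}

int main() {
  assert(round_up(5000, 4096) == 8192);  // partial chunk rounds up to a full one
  assert(round_up(8192, 4096) == 8192);  // already aligned sizes are unchanged
  return 0;
}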
@@ -230,9 +224,12 @@ int create_communicator_grouped2(
     (*comm)->_barrier((*comm)->comm_world);
     if ((*comm)->ar2_nvrank == 0) {
-      CUCHECK(cuMulticastCreate(&(*comm)->mc_handle, &mcProp));
-      CUCHECK(cuMemExportToShareableHandle(&fd, (*comm)->mc_handle,
-                                           CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0 /*flags*/));
+      NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastCreate, &(*comm)->mc_handle, &mcProp);
+      NVTE_CALL_CHECK_CUDA_DRIVER(
+          cuMemExportToShareableHandle, reinterpret_cast<void *>(&fd), (*comm)->mc_handle,
+          static_cast<CUmemAllocationHandleType>(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR),
+          (uint64_t)0);
       for (int p = 1; p < (*comm)->ar2_nvsize; p++) {
         (*comm)->_barrier((*comm)->comm_intra);
         NCCLCHECKGOTO(ncclIpcSocketSendFd(&ipcSock, fd, p, (uint64_t)opId), ret, error);
@@ -242,23 +239,28 @@ int create_communicator_grouped2(
       NCCLCHECKGOTO(ncclIpcSocketRecvFd(&ipcSock, &fd), ret, error);
       for (int i = 0; i < (*comm)->ar2_nvsize - (*comm)->ar2_nvrank - 1; i++)
         (*comm)->_barrier((*comm)->comm_intra);
-      CUCHECK(cuMemImportFromShareableHandle(&(*comm)->mc_handle, reinterpret_cast<void *>(fd),
-                                             CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
+      NVTE_CALL_CHECK_CUDA_DRIVER(
+          cuMemImportFromShareableHandle, &(*comm)->mc_handle, reinterpret_cast<void *>(fd),
+          static_cast<CUmemAllocationHandleType>(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
     }
   error:
     NCCLCHECK(ncclIpcSocketClose(&ipcSock));
     close(fd);
-    CUCHECK(cuMulticastAddDevice((*comm)->mc_handle, (*comm)->mydev));
+    NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastAddDevice, (*comm)->mc_handle,
+                                (CUdeviceptr)(*comm)->mydev);
     CUdeviceptr mc_va;
-    CUCHECK(cuMemAddressReserve(&mc_va, mc_maxsize, 0, 0U, 0));
-    CUCHECK(cuMemMap(mc_va, mc_maxsize, 0, (*comm)->mc_handle, 0));
+    NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressReserve, &mc_va, mc_maxsize, (size_t)0, (CUdeviceptr)0U,
+                                (uint64_t)0);
+    NVTE_CALL_CHECK_CUDA_DRIVER(cuMemMap, mc_va, mc_maxsize, (size_t)0, (*comm)->mc_handle,
+                                (uint64_t)0);
     CUmemAccessDesc accessDesc = {};
     accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
     accessDesc.location.id = (*comm)->mydev;
     accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
-    CUCHECK(cuMemSetAccess(mc_va, mc_maxsize, &accessDesc, 1));
+    NVTE_CALL_CHECK_CUDA_DRIVER(cuMemSetAccess, mc_va, mc_maxsize,
+                                const_cast<CUmemAccessDesc *>(&accessDesc), (size_t)1);
     (*comm)->mc_baseptr = reinterpret_cast<void *>(mc_va);
     (*comm)->_barrier((*comm)->comm_world);
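
Taken together, this hunk is the standard CUDA virtual-memory-management sequence for the multicast region: create the handle on one rank, export it to a POSIX FD, fan the FD out over the IPC socket, import on the peers, add each device, then reserve a VA range, map the handle into it, and grant read/write access. Below is a hedged single-GPU sketch of the reserve/map/access tail; it uses an ordinary pinned device allocation rather than a multicast handle (multicast needs several GPUs) and calls the driver directly for brevity, whereas TE routes every one of these calls through NVTE_CALL_CHECK_CUDA_DRIVER:

// vmm_demo.cu -- hypothetical sketch of reserve -> map -> set-access -> teardown.
// Build: nvcc vmm_demo.cu -o vmm_demo -lcuda   (libcuda linked only for the demo)
#include <cuda.h>
#include <cstdio>

#define CHECK(x)                                                    \
  do {                                                              \
    if ((x) != CUDA_SUCCESS) {                                      \
      fprintf(stderr, "driver call failed, line %d\n", __LINE__);   \
      return 1;                                                     \
    }                                                               \
  } while (0)

int main() {
  CHECK(cuInit(0));
  CUdevice dev;
  CHECK(cuDeviceGet(&dev, 0));
  CUcontext ctx;
  CHECK(cuCtxCreate(&ctx, 0, dev));

  // Physical allocation, sized to the minimum granularity (cf. the rounding above).
  CUmemAllocationProp prop = {};
  prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
  prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  prop.location.id = 0;
  size_t gran = 0;
  CHECK(cuMemGetAllocationGranularity(&gran, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
  CUmemGenericAllocationHandle handle;
  CHECK(cuMemCreate(&handle, gran, &prop, 0));

  // Reserve a VA range, back it with the allocation, then grant RW access:
  // the same cuMemAddressReserve / cuMemMap / cuMemSetAccess triple as above.
  CUdeviceptr va;
  CHECK(cuMemAddressReserve(&va, gran, 0, 0, 0));
  CHECK(cuMemMap(va, gran, 0, handle, 0));
  CUmemAccessDesc access = {};
  access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  access.location.id = 0;
  access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
  CHECK(cuMemSetAccess(va, gran, &access, 1));

  // ... the region at `va` is now usable by kernels on device 0 ...

  CHECK(cuMemUnmap(va, gran));
  CHECK(cuMemAddressFree(va, gran));
  CHECK(cuMemRelease(handle));
  return 0;
}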
@@ -402,10 +404,11 @@ int create_communicator_mpi(communicator **comm) {
 void destroy_communicator(communicator *comm) {
   for (int hndl = 0; hndl < comm->free_region; hndl++) {
     if (comm->mem_dealloc[hndl]) {
-      cuMemAddressFree(reinterpret_cast<CUdeviceptr>(comm->ucbase_ptr[hndl]),
-                       comm->mem_size[hndl] * comm->nvsize);
+      NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressFree,
+                                  reinterpret_cast<CUdeviceptr>(comm->ucbase_ptr[hndl]),
+                                  comm->mem_size[hndl] * comm->nvsize);
       for (int rank = 0; rank < comm->nvsize; rank++) {
-        cuMemRelease(comm->uchandles[hndl][rank]);
+        NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->uchandles[hndl][rank]);
       }
       free(reinterpret_cast<void *>(comm->uchandles[hndl]));
     } else {
@@ -424,14 +427,15 @@ void destroy_communicator(communicator *comm) {
   cudaFree(reinterpret_cast<void *>(comm->recv_id));
   cudaFree(reinterpret_cast<void *>(comm->send_id));
   if (comm->use_mc) {
-    cuMemAddressFree(reinterpret_cast<CUdeviceptr>(comm->mc_baseptr), comm->mc_maxsize);
-    cuMemRelease(comm->mc_handle);
+    NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressFree, reinterpret_cast<CUdeviceptr>(comm->mc_baseptr),
+                                comm->mc_maxsize);
+    NVTE_CALL_CHECK_CUDA_DRIVER(cuMemRelease, comm->mc_handle);
   }
   if (comm->mem_dealloc[0]) {
     cudaFree(comm->gpu_ptrs);
   }
   free(comm->fifo);
-  free(comm);
+  delete comm;
 }

 void destroy_communicator_mpi(communicator *comm) {
@@ -466,7 +470,9 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
         CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;  // CU_MEM_HANDLE_TYPE_FABRIC;
     size_t granularity = 0;
-    CUCHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
+    NVTE_CALL_CHECK_CUDA_DRIVER(
+        cuMemGetAllocationGranularity, &granularity, &prop,
+        static_cast<CUmemAllocationGranularity_flags>(CU_MULTICAST_GRANULARITY_MINIMUM));
     // MPI_Allreduce MAX of granularity check
     aligned_size = (bytes + granularity - 1) / granularity * granularity;
@@ -475,18 +481,24 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
       mcProp.numDevices = nranks;
       mcProp.size = aligned_size;
       mcProp.handleTypes = prop.requestedHandleTypes;
-      CUCHECK(cuMulticastGetGranularity(&granularity, &mcProp, CU_MULTICAST_GRANULARITY_MINIMUM));
+      NVTE_CALL_CHECK_CUDA_DRIVER(
+          cuMulticastGetGranularity, &granularity, &mcProp,
+          static_cast<CUmemAllocationGranularity_flags>(CU_MULTICAST_GRANULARITY_MINIMUM));
       aligned_size = (aligned_size + granularity - 1) / granularity * granularity;
     }
     prop.location.id = comm->mydev;
     comm->uchandles[hndl] = reinterpret_cast<CUmemGenericAllocationHandle *>(
         malloc(nranks * sizeof(CUmemGenericAllocationHandle)));
-    CUCHECK(cuMemCreate(&(comm->uchandles[hndl][myrank]), aligned_size, &prop, 0));
+    NVTE_CALL_CHECK_CUDA_DRIVER(cuMemCreate, &(comm->uchandles[hndl][myrank]), aligned_size, &prop,
+                                (uint64_t)0);
     int *peerfd = reinterpret_cast<int *>(malloc(nranks * sizeof(int)));
-    CUCHECK(cuMemExportToShareableHandle(&peerfd[myrank], comm->uchandles[hndl][myrank],
-                                         CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0 /*flags*/));
+    NVTE_CALL_CHECK_CUDA_DRIVER(
+        cuMemExportToShareableHandle, reinterpret_cast<void *>(&peerfd[myrank]),
+        comm->uchandles[hndl][myrank],
+        static_cast<CUmemAllocationHandleType>(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR),
+        (uint64_t)0);
     volatile uint32_t abortFlag = 0;
     struct ncclIpcSocket ipcSock = {0};
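
In this per-rank registration path, each rank creates its own allocation, exports it as a POSIX file descriptor, and circulates the FD to its peers over ncclIpcSocketSendFd/ncclIpcSocketRecvFd; the peers then import it in the next hunk. Below is a hedged sketch of just the export/import handshake, re-importing in the same process for brevity (in TE the FD crosses process boundaries):

// fd_handshake.cu -- hypothetical demo of the FD export/import round-trip.
// Build: nvcc fd_handshake.cu -o fd_handshake -lcuda   (linked only for the demo)
#include <cuda.h>
#include <unistd.h>
#include <cstdint>
#include <cstdio>

#define CHECK(x)                                                    \
  do {                                                              \
    if ((x) != CUDA_SUCCESS) {                                      \
      fprintf(stderr, "driver call failed, line %d\n", __LINE__);   \
      return 1;                                                     \
    }                                                               \
  } while (0)

int main() {
  CHECK(cuInit(0));
  CUdevice dev;
  CHECK(cuDeviceGet(&dev, 0));
  CUcontext ctx;
  CHECK(cuCtxCreate(&ctx, 0, dev));

  CUmemAllocationProp prop = {};
  prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
  prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
  prop.location.id = 0;
  // Must be exportable as a POSIX FD, matching prop.requestedHandleTypes above.
  prop.requestedHandleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;

  size_t gran = 0;
  CHECK(cuMemGetAllocationGranularity(&gran, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
  CUmemGenericAllocationHandle handle;
  CHECK(cuMemCreate(&handle, gran, &prop, 0));

  int fd = -1;  // in TE this FD is shipped to peer ranks over the IPC socket
  CHECK(cuMemExportToShareableHandle(&fd, handle, CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR, 0));

  CUmemGenericAllocationHandle imported;  // what a peer rank would reconstruct
  CHECK(cuMemImportFromShareableHandle(
      &imported, reinterpret_cast<void *>(static_cast<uintptr_t>(fd)),
      CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
  close(fd);  // FD no longer needed once imported, mirroring close(peerfd[p]) below

  CHECK(cuMemRelease(imported));
  CHECK(cuMemRelease(handle));
  return 0;
}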
@@ -512,13 +524,15 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
     for (int p = 0; p < nranks; p++) {
       if (p != myrank)
-        CUCHECK(cuMemImportFromShareableHandle(&comm->uchandles[hndl][p],
-                                               reinterpret_cast<void *>(peerfd[p]),
-                                               CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
+        NVTE_CALL_CHECK_CUDA_DRIVER(
+            cuMemImportFromShareableHandle, &comm->uchandles[hndl][p],
+            reinterpret_cast<void *>(peerfd[p]),
+            static_cast<CUmemAllocationHandleType>(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR));
       close(peerfd[p]);
     }
     CUdeviceptr ptr;
-    CUCHECK(cuMemAddressReserve(&ptr, aligned_size * nranks, 0, 0, 0));
+    NVTE_CALL_CHECK_CUDA_DRIVER(cuMemAddressReserve, &ptr, (size_t)(aligned_size * nranks),
+                                (size_t)0, (CUdeviceptr)0, (uint64_t)0);
     comm->ucbase_ptr[hndl] = reinterpret_cast<void *>(ptr);
     CUmemAccessDesc accessDesc = {};
     accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
@@ -526,8 +540,9 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
     accessDesc.location.id = comm->mydev;
     for (int i = 0; i < nranks; i++) {
-      CUCHECK(cuMemMap(ptr + (aligned_size * i), aligned_size, 0, comm->uchandles[hndl][i], 0));
       remptrs[i] = reinterpret_cast<void *>(ptr + (aligned_size * i));
+      NVTE_CALL_CHECK_CUDA_DRIVER(cuMemMap, reinterpret_cast<CUdeviceptr>(remptrs[i]), aligned_size,
+                                  (size_t)0, comm->uchandles[hndl][i], (uint64_t)0);
       if (i == comm->nvrank) {
         if (hndl)
           *gpubuff = remptrs[i];
@@ -536,7 +551,8 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
       }
       comm->peer_ptr[hndl][i] = remptrs[i];
     }
-    CUCHECK(cuMemSetAccess(ptr, aligned_size * nranks, &accessDesc, 1));
+    NVTE_CALL_CHECK_CUDA_DRIVER(cuMemSetAccess, ptr, (size_t)(aligned_size * nranks),
+                                const_cast<CUmemAccessDesc *>(&accessDesc), (size_t)1);
     if (hndl == 0) CUDACHECK(cudaMemset(comm->gpu_ptrs, 0, aligned_size));
     CUDACHECK(
@@ -547,8 +563,9 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *
     comm->memflags[hndl] = UB_MEM_UC_CONTIG | UB_MEM_ALLOCATED;
     if (comm->use_mc && comm->mc_maxsize >= comm->mc_offset + aligned_size) {
-      CUCHECK(cuMulticastBindMem(comm->mc_handle, comm->mc_offset, comm->uchandles[hndl][myrank],
-                                 0 /*memOffset*/, aligned_size, 0));
+      NVTE_CALL_CHECK_CUDA_DRIVER(cuMulticastBindMem, comm->mc_handle, comm->mc_offset,
+                                  comm->uchandles[hndl][myrank], (size_t)0 /*memOffset*/,
+                                  aligned_size, (uint64_t)0);
       comm->memflags[hndl] |= UB_MEM_MC_CREATED;
       comm->mc_ptr[hndl] = reinterpret_cast<char *>(comm->mc_baseptr) + comm->mc_offset;
      comm->mc_offset += aligned_size;
......