Unverified commit afa1f1b0, authored by gdengk, committed by GitHub

Introduce NVSHMEM-based communication API for PyTorch (#1430)



* add nvshmem based api support
Signed-off-by: gdeng <gdeng@nvidia.com>

* fix lint and license issue
Signed-off-by: gdeng <gdeng@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* remove asset
Signed-off-by: gdeng <gdeng@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix the lib
Signed-off-by: gdeng <gdeng@nvidia.com>

* address comments
Signed-off-by: gdeng <gdeng@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: gdeng <gdeng@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent da42e212
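
Taken together, the bindings added below form a small point-to-point API. Here is a minimal end-to-end sketch of the intended lifecycle; the `transformer_engine_torch` entry points match the pybind11 registrations in this diff, while the launcher setup, shapes, and the use of an int64 tensor as the 8-byte signal slot are illustrative assumptions:

```python
# Hypothetical 2-rank usage sketch (run under torchrun); only the tex.* entry
# points are taken from this diff, everything else is assumed.
import torch
import torch.distributed as dist
import transformer_engine_torch as tex

dist.init_process_group(backend="nccl")  # NCCL is the only supported bootstrap
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Broadcasts the NVSHMEM unique ID from rank 0 over the process group.
tex.init_nvshmem_backend(dist.group.WORLD)

# Symmetric-heap allocations: identical shape/dtype on every PE.
data = tex.create_nvshmem_tensor([1024], torch.float32)
signal = tex.create_nvshmem_tensor([1], torch.int64)  # reinterpreted as uint64
signal.zero_()

if dist.get_rank() == 0:
    # Push local data into rank 1's symmetric buffer, then set its signal to 1.
    tex.nvshmem_send_on_current_stream(data, data, 1, signal)
elif dist.get_rank() == 1:
    # Block the current CUDA stream until the signal arrives, then reset it.
    tex.nvshmem_wait_on_current_stream(signal, "stream")

torch.cuda.synchronize()
dist.barrier()
tex.nvshmem_finalize()
```
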
@@ -89,6 +89,19 @@ def setup_pytorch_extension(
        cxx_flags.append("-DNVTE_UB_WITH_MPI")
        nvcc_flags.append("-DNVTE_UB_WITH_MPI")

    library_dirs = []
    libraries = []
    if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))):
        assert (
            os.getenv("NVSHMEM_HOME") is not None
        ), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
        nvshmem_home = Path(os.getenv("NVSHMEM_HOME"))
        include_dirs.append(nvshmem_home / "include")
        library_dirs.append(nvshmem_home / "lib")
        libraries.append("nvshmem_host")
        cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")
        nvcc_flags.append("-DNVTE_ENABLE_NVSHMEM")

    # Construct PyTorch CUDA extension
    sources = [str(path) for path in sources]
    include_dirs = [str(path) for path in include_dirs]
@@ -102,4 +115,6 @@ def setup_pytorch_extension(
        extra_compile_args={
            "cxx": cxx_flags,
            "nvcc": nvcc_flags,
        },
        libraries=[str(lib) for lib in libraries],
        library_dirs=[str(lib_dir) for lib_dir in library_dirs],
    )
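
The NVSHMEM path is strictly opt-in at build time: nothing is linked unless `NVTE_ENABLE_NVSHMEM=1` is exported, and the build asserts that `NVSHMEM_HOME` points at an installation. A small pre-build sanity check mirroring that gate (the `/opt/nvshmem` prefix is a placeholder):

```python
# Hypothetical environment check; mirrors the gating logic in the hunks above.
import os
from pathlib import Path

if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))):
    home = os.getenv("NVSHMEM_HOME")  # e.g. /opt/nvshmem (placeholder)
    assert home is not None, "NVSHMEM_HOME must be set when NVTE_ENABLE_NVSHMEM=1"
    for sub in ("include", "lib"):
        assert (Path(home) / sub).is_dir(), f"NVSHMEM_HOME is missing {sub}/"
```
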
@@ -64,6 +64,12 @@ def setup_common_extension() -> CMakeExtension:
        ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
        cmake_flags.append("-DNVTE_UB_WITH_MPI=ON")

    if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))):
        assert (
            os.getenv("NVSHMEM_HOME") is not None
        ), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
        cmake_flags.append("-DNVTE_ENABLE_NVSHMEM=ON")

    if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
        cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")
......
@@ -96,6 +96,8 @@ add_library(transformer_engine SHARED ${transformer_engine_SOURCES})
target_include_directories(transformer_engine PUBLIC
                           "${CMAKE_CURRENT_SOURCE_DIR}/include")

# Configure dependencies
target_link_libraries(transformer_engine PUBLIC
                      CUDA::cublas
@@ -114,6 +116,13 @@ if (NVTE_UB_WITH_MPI)
  target_compile_definitions(transformer_engine PUBLIC NVTE_UB_WITH_MPI)
endif()

option(NVTE_ENABLE_NVSHMEM "Compile with NVSHMEM library" OFF)
if (NVTE_ENABLE_NVSHMEM)
  add_subdirectory(nvshmem_api)
  target_link_libraries(transformer_engine PUBLIC nvshmemapi)
  target_include_directories(transformer_engine PUBLIC ${NVSHMEMAPI_INCLUDE_DIR})
endif()

# Hack to enable dynamic loading in cuDNN frontend
target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING)
......
@@ -16,7 +16,9 @@
    transformer_engine::is_fp8_dtype*;
    *transformer_engine::CommOverlapBase*;
    *transformer_engine::CommOverlapP2PBase*;
    *transformer_engine::CommOverlapCore*;
    *nvshmem_wait_on_stream*;
    *nvshmemi_init_thread*
  };
  local: *;
};
##########################################################################
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
##########################################################################

cmake_minimum_required(VERSION 3.18)
project(nvshmemapi LANGUAGES CXX CUDA)

# Configure dependencies
find_package(CUDAToolkit REQUIRED)

set(NVSHMEM_HOME "$ENV{NVSHMEM_HOME}" CACHE STRING "Location of NVSHMEM installation")

add_library(nvshmemapi STATIC nvshmem_waitkernel.cu)
set(NVSHMEMAPI_INCLUDE_DIR "${CMAKE_CURRENT_SOURCE_DIR}" PARENT_SCOPE)

target_link_directories(nvshmemapi PUBLIC ${NVSHMEM_HOME}/lib)
target_link_libraries(nvshmemapi PUBLIC -static-libstdc++ nvshmem_device nvshmem_host
                                        CUDA::nvml CUDA::cublas CUDA::cuda_driver)
target_include_directories(nvshmemapi PRIVATE ${NVSHMEM_HOME}/include/)
target_include_directories(nvshmemapi PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
                                             "${CMAKE_CURRENT_SOURCE_DIR}")

set_target_properties(nvshmemapi PROPERTIES CUDA_STANDARD 17
                                            POSITION_INDEPENDENT_CODE ON
                                            CUDA_SEPARABLE_COMPILATION ON)
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <cuda.h>
#include <cuda_bf16.h>
#include <nvshmem.h>

#include <cstdio>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <sstream>
#include <string>

#include "../util/logging.h"
#include "nvshmem_waitkernel.h"

// Single-thread kernel: block the stream until *wait_flag reaches wait_value,
// then reset the flag so the signal slot can be reused.
__global__ void __launch_bounds__(1)
    wait_until_on_stream_and_reset(uint64_t* wait_flag, uint64_t wait_value,
                                   uint64_t signal_reset) {
  nvshmem_uint64_wait_until(wait_flag, NVSHMEM_CMP_EQ, wait_value);
  *wait_flag = signal_reset;
}
void nvshmem_wait_on_stream(uint64_t* sig_addr, WaitKind wait_kind, cudaStream_t stream) {
  uint64_t wait_value = 1;
  uint64_t signal_reset = 0;
  cudaStream_t cur_stream = stream;

  NVTE_CHECK(wait_kind >= WaitKind::KERNEL_WAIT && wait_kind <= WaitKind::STREAM_WAIT,
             "Invalid wait kind: ", static_cast<int>(wait_kind));

  switch (wait_kind) {
    case WaitKind::KERNEL_WAIT:
      // Spin inside a tiny CUDA kernel; the same kernel performs the reset.
      wait_until_on_stream_and_reset<<<1, 1, 0, cur_stream>>>(sig_addr, wait_value, signal_reset);
      break;
    case WaitKind::NVSHMEM_WAIT:
      // Let NVSHMEM enqueue the wait, then reset the flag with a stream write.
      nvshmemx_uint64_wait_until_on_stream(sig_addr, NVSHMEM_CMP_EQ, wait_value, cur_stream);
      cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset,
                           CU_STREAM_WRITE_VALUE_DEFAULT);
      break;
    case WaitKind::STREAM_WAIT:
      // Pure CUDA driver path: wait for the value, then reset it, all on-stream.
      cuStreamWaitValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)wait_value,
                          CU_STREAM_WAIT_VALUE_GEQ);
      cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset,
                           CU_STREAM_WRITE_VALUE_DEFAULT);
      break;
  }
}
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_COMMON_NVSHMEM_WAITKERNEL_H
#define TRANSFORMER_ENGINE_COMMON_NVSHMEM_WAITKERNEL_H

#include <cuda_runtime.h>

#include <cstdint>

/*! \enum WaitKind
 *  \brief Types of wait operations that can be performed.
 *
 *  C++-only interface: WaitKind is an `enum class`, so no C linkage is offered.
 */
enum class WaitKind {
  KERNEL_WAIT = 0,  /*!< Wait using a CUDA kernel */
  NVSHMEM_WAIT = 1, /*!< Wait using NVSHMEM wait operation */
  STREAM_WAIT = 2   /*!< Wait using CUDA stream synchronization */
};

/*! \brief Wait on a signal until a certain condition is met.
 *
 *  \param[in] sig_addr  The address of the signal to wait on.
 *  \param[in] wait_kind The kind of wait to perform.
 *  \param[in] stream    The stream to wait on.
 */
void nvshmem_wait_on_stream(uint64_t* sig_addr, WaitKind wait_kind, cudaStream_t stream);

#endif  // TRANSFORMER_ENGINE_COMMON_NVSHMEM_WAITKERNEL_H
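
On the Python side these three modes are selected by the strings "kernel", "nvshmem", and "stream" (see the dispatch in nvshmem_wait_on_current_stream further down). All three block the current CUDA stream rather than the host, and all three reset the flag to 0 afterwards, so a slot is immediately reusable. A hedged sketch of picking a mode, assuming the extension imports as `transformer_engine_torch`:

```python
# Sketch only: the wait_kind strings map onto WaitKind in the header above.
import transformer_engine_torch as tex  # assumed module name

def wait_for_signal(signal, kind: str = "stream") -> None:
    """Block the current CUDA stream until a peer sets `signal` to 1."""
    assert kind in ("kernel", "nvshmem", "stream"), f"invalid wait kind: {kind}"
    tex.nvshmem_wait_on_current_stream(signal, kind)
```
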
@@ -373,6 +373,23 @@ void fused_multi_row_padding(at::Tensor input, at::Tensor output,
                             std::vector<size_t> input_row_list,
                             std::vector<size_t> padded_input_row_list);

/***************************************************************************************************
 * NVSHMEM APIs
 **************************************************************************************************/

namespace nvshmem_api {
void init_nvshmem_backend(c10d::ProcessGroup *process_group);

torch::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype);

void nvshmem_send_on_current_stream(torch::Tensor src, torch::Tensor dst, int peer,
                                    torch::Tensor signal);

void nvshmem_wait_on_current_stream(torch::Tensor signal, const std::string &wait_kind);

void nvshmem_finalize();
}  // namespace nvshmem_api

/***************************************************************************************************
 * swizzle
 **************************************************************************************************/
......
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "../extensions.h"

#ifdef NVTE_ENABLE_NVSHMEM
#include <nvshmem.h>
#include <nvshmem_api/nvshmem_waitkernel.h>
#include <nvshmemx.h>
#endif

#include <cuda.h>
#include <cuda_fp8.h>
#include <torch/cuda.h>
#include <torch/extension.h>
namespace nvshmem_api {

void init_nvshmem_backend(c10d::ProcessGroup *process_group) {
#ifdef NVTE_ENABLE_NVSHMEM
  nvshmemx_init_attr_t attr = {};
  nvshmemx_uniqueid_t id = {};

  int my_rank = process_group->getRank();
  int num_ranks = process_group->getSize();
  if (my_rank == 0) {
    nvshmemx_get_uniqueid(&id);
  }

  auto backend_is_nccl =
      (process_group->getBackendType() == c10d::ProcessGroup::BackendType::NCCL);
  NVTE_CHECK(backend_is_nccl, "Currently only NCCL bootstrap is supported for NVSHMEM");

  // Broadcast the NVSHMEM unique ID from rank 0 through the PyTorch process group.
  auto datatensor =
      torch::from_blob(reinterpret_cast<void *>(&id),
                       {static_cast<int64_t>(sizeof(nvshmemx_uniqueid_t) / sizeof(uint8_t))},
                       at::device(torch::kCPU).dtype(torch::kUInt8));
  auto datatmp = (backend_is_nccl) ? datatensor.cuda() : datatensor;

  c10d::BroadcastOptions bcast_opts;
  bcast_opts.rootRank = 0;
  std::vector<torch::Tensor> datachunk = {datatmp};
  auto work = process_group->broadcast(datachunk, bcast_opts);
  work->wait();

  if (backend_is_nccl) {
    datatensor.copy_(datatmp.cpu());
    datatmp = torch::Tensor();
  }

  nvshmemx_set_attr_uniqueid_args(my_rank, num_ranks, &id, &attr);
  nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr);

  NVTE_CHECK(my_rank == nvshmem_my_pe(), "my_rank: ", my_rank,
             " != nvshmem_my_pe(): ", nvshmem_my_pe());
  NVTE_CHECK(num_ranks == nvshmem_n_pes(), "num_ranks: ", num_ranks,
             " != nvshmem_n_pes(): ", nvshmem_n_pes());
#else
  NVTE_ERROR("init_nvshmem_backend is unavailable: TE must be compiled with ",
             "NVTE_ENABLE_NVSHMEM=1 to bootstrap NVSHMEM from PyTorch distributed process groups.");
#endif
}
void nvshmem_wait_on_current_stream(torch::Tensor signal, const std::string &wait_kind) {
#ifdef NVTE_ENABLE_NVSHMEM
  uint64_t *sig_addr = reinterpret_cast<uint64_t *>(signal.data_ptr());
  cudaStream_t cur_stream = (cudaStream_t)at::cuda::getCurrentCUDAStream();

  WaitKind wait_kind_enum = WaitKind::STREAM_WAIT;
  if (wait_kind == "kernel") {
    wait_kind_enum = WaitKind::KERNEL_WAIT;
  } else if (wait_kind == "nvshmem") {
    wait_kind_enum = WaitKind::NVSHMEM_WAIT;
  } else if (wait_kind == "stream") {
    wait_kind_enum = WaitKind::STREAM_WAIT;
  } else {
    NVTE_ERROR("Invalid wait kind: ", wait_kind);
  }
  nvshmem_wait_on_stream(sig_addr, wait_kind_enum, cur_stream);
#else
  NVTE_ERROR("nvshmem_wait_on_current_stream is unavailable: TE must be compiled with ",
             "NVTE_ENABLE_NVSHMEM=1.");
#endif
}
torch::Tensor create_nvshmem_tensor(const std::vector<int64_t> &shape, c10::ScalarType dtype) {
#ifdef NVTE_ENABLE_NVSHMEM
  auto option_gpu =
      at::TensorOptions().dtype(dtype).device(at::kCUDA).device_index(c10::cuda::current_device());
  // Accumulate in int64_t to avoid overflow for large tensors.
  auto size = torch::elementSize(dtype) *
              std::accumulate(shape.begin(), shape.end(), int64_t{1}, std::multiplies<>());
  void *ptr = nvshmem_malloc(size);
  NVTE_CHECK(ptr != nullptr, "nvshmem_malloc failed to allocate ", size, " bytes");
  return at::from_blob(
      ptr, shape, [](void *p) { nvshmem_free(p); }, option_gpu);
#else
  NVTE_ERROR("create_nvshmem_tensor is unavailable: TE must be compiled with ",
             "NVTE_ENABLE_NVSHMEM=1.");
#endif
}
void nvshmem_send_on_current_stream(torch::Tensor src, torch::Tensor dst, int peer,
                                    torch::Tensor signal) {
#ifdef NVTE_ENABLE_NVSHMEM
  void *src_ptr = reinterpret_cast<void *>(src.data_ptr());
  void *dst_ptr = reinterpret_cast<void *>(dst.data_ptr());
  uint64_t *sig_addr = reinterpret_cast<uint64_t *>(signal.data_ptr());
  auto nbytes = src.numel() * src.element_size();  // payload size in bytes
  uint64_t sigval = 1;
  at::cuda::CUDAStream cur_stream = at::cuda::getCurrentCUDAStream();
  // Put the payload into the peer's symmetric buffer, then atomically set the
  // peer's signal word to 1 once the data has landed.
  nvshmemx_putmem_signal_on_stream(dst_ptr, src_ptr, nbytes, sig_addr, sigval, NVSHMEM_SIGNAL_SET,
                                   peer, (cudaStream_t)cur_stream);
#else
  NVTE_ERROR("nvshmem_send_on_current_stream is unavailable: TE must be compiled with ",
             "NVTE_ENABLE_NVSHMEM=1.");
#endif
}
void nvshmem_finalize() {
#ifdef NVTE_ENABLE_NVSHMEM
  // Qualify with :: to call the NVSHMEM library function rather than
  // recursing into this same-named wrapper.
  ::nvshmem_finalize();
#else
  NVTE_ERROR("nvshmem_finalize is unavailable: TE must be compiled with ",
             "NVTE_ENABLE_NVSHMEM=1.");
#endif
}

}  // namespace nvshmem_api
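
Because the sender sets the peer's signal word with NVSHMEM_SIGNAL_SET only after the payload has landed, and every wait variant resets the word to 0, one signal slot per peer is enough to pipeline iterations. A hedged sketch of a ring exchange built on the two calls (module name and buffer setup assumed):

```python
# Ring-exchange sketch: rank r sends to (r + 1) % world and waits on the
# signal written by (r - 1) % world. Only the tex.* names come from this diff.
import torch.distributed as dist
import transformer_engine_torch as tex  # assumed module name

def ring_step(send_buf, recv_buf, signal):
    rank, world = dist.get_rank(), dist.get_world_size()
    dst = (rank + 1) % world
    # Write send_buf into dst's symmetric recv_buf, then set dst's signal to 1.
    tex.nvshmem_send_on_current_stream(send_buf, recv_buf, dst, signal)
    # Wait (on-stream) for our own signal from the left neighbor; resets to 0.
    tex.nvshmem_wait_on_current_stream(signal, "kernel")
```
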
@@ -234,6 +234,23 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
        "Generate partitioned indices for inputs in THD format",
        py::call_guard<py::gil_scoped_release>());

  // NVSHMEM functions
  m.def("init_nvshmem_backend", &nvshmem_api::init_nvshmem_backend,
        "Initialize NVSHMEM backend with PyTorch distributed process groups",
        py::call_guard<py::gil_scoped_release>());
  m.def("create_nvshmem_tensor", &nvshmem_api::create_nvshmem_tensor,
        "Create a tensor in NVSHMEM shared memory", py::call_guard<py::gil_scoped_release>());
  m.def("nvshmem_send_on_current_stream", &nvshmem_api::nvshmem_send_on_current_stream,
        "Asynchronously send tensor data to a remote PE using NVSHMEM on the current CUDA stream",
        py::call_guard<py::gil_scoped_release>());
  m.def("nvshmem_wait_on_current_stream", &nvshmem_api::nvshmem_wait_on_current_stream,
        "Wait for a signal value to be updated by a remote PE using NVSHMEM on the current CUDA "
        "stream",
        py::call_guard<py::gil_scoped_release>());
  m.def("nvshmem_finalize", &nvshmem_api::nvshmem_finalize,
        "Clean up and finalize the NVSHMEM communication backend and free associated resources",
        py::call_guard<py::gil_scoped_release>());

  // multi-tensor functions
  m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
        "Fused overflow check + scale for a list of contiguous tensors",
......