Unverified Commit b6ce310b authored by Shangyan Zhou's avatar Shangyan Zhou Committed by GitHub
Browse files

third-party: Improvements to NVSHMEM Integration

third-party: Improvements to NVSHMEM Integration
parents 146b013d c5d22023
...@@ -146,20 +146,26 @@ template <bool kAlwaysDoPostSend> ...@@ -146,20 +146,26 @@ template <bool kAlwaysDoPostSend>
__device__ static __forceinline__ __device__ static __forceinline__
void ibgda_submit_requests(nvshmemi_ibgda_device_qp_t *qp, uint64_t base_wqe_idx, void ibgda_submit_requests(nvshmemi_ibgda_device_qp_t *qp, uint64_t base_wqe_idx,
uint32_t num_wqes, int message_idx = 0) { uint32_t num_wqes, int message_idx = 0) {
auto state = ibgda_get_state();
nvshmemi_ibgda_device_qp_management_t *mvars = &qp->mvars; nvshmemi_ibgda_device_qp_management_t *mvars = &qp->mvars;
uint64_t new_wqe_idx = base_wqe_idx + num_wqes; uint64_t new_wqe_idx = base_wqe_idx + num_wqes;
// WQE writes must be finished first // WQE writes must be finished first
__threadfence(); __threadfence();
unsigned long long int *ready_idx =
(unsigned long long int *)(state->use_async_postsend ? qp->tx_wq.prod_idx
: &mvars->tx_wq.ready_head);
// Wait for prior WQE slots to be filled first // Wait for prior WQE slots to be filled first
auto *ready_idx = reinterpret_cast<unsigned long long int*>(&mvars->tx_wq.ready_head);
while (atomicCAS(ready_idx, base_wqe_idx, new_wqe_idx) != base_wqe_idx); while (atomicCAS(ready_idx, base_wqe_idx, new_wqe_idx) != base_wqe_idx);
// Always post, not in batch // Always post, not in batch
constexpr int kNumRequestInBatch = 4; if (!state->use_async_postsend) {
if (kAlwaysDoPostSend or (message_idx + 1) % kNumRequestInBatch == 0) constexpr int kNumRequestInBatch = 4;
ibgda_post_send(qp, new_wqe_idx); if (kAlwaysDoPostSend or (message_idx + 1) % kNumRequestInBatch == 0)
ibgda_post_send(qp, new_wqe_idx);
}
} }
__device__ static __forceinline__ void __device__ static __forceinline__ void
...@@ -488,7 +494,8 @@ ibgda_poll_cq(nvshmemi_ibgda_device_cq_t *cq, uint64_t idx) { ...@@ -488,7 +494,8 @@ ibgda_poll_cq(nvshmemi_ibgda_device_cq_t *cq, uint64_t idx) {
__device__ static __forceinline__ void __device__ static __forceinline__ void
nvshmemi_ibgda_quiet(int dst_pe, int qp_id) { nvshmemi_ibgda_quiet(int dst_pe, int qp_id) {
auto qp = ibgda_get_rc(dst_pe, qp_id); auto qp = ibgda_get_rc(dst_pe, qp_id);
uint64_t prod_idx = ld_na_relaxed(qp->tx_wq.prod_idx); auto state = ibgda_get_state();
uint64_t prod_idx = state->use_async_postsend ? ld_na_relaxed(qp->tx_wq.prod_idx) : ld_na_relaxed(&qp->mvars.tx_wq.ready_head);
ibgda_poll_cq(qp->tx_wq.cq, prod_idx); ibgda_poll_cq(qp->tx_wq.cq, prod_idx);
} }
......
...@@ -7,6 +7,7 @@ ...@@ -7,6 +7,7 @@
#include "utils.cuh" #include "utils.cuh"
#ifndef DISABLE_NVSHMEM #ifndef DISABLE_NVSHMEM
#include "nvshmem.h"
#include "ibgda_device.cuh" #include "ibgda_device.cuh"
#endif #endif
......
...@@ -83,7 +83,6 @@ class Buffer: ...@@ -83,7 +83,6 @@ class Buffer:
assert num_qps_per_rank > 0 assert num_qps_per_rank > 0
os.environ['NVSHMEM_DISABLE_P2P'] = '0' if allow_nvlink_for_low_latency_mode else '1' os.environ['NVSHMEM_DISABLE_P2P'] = '0' if allow_nvlink_for_low_latency_mode else '1'
os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1' os.environ['NVSHMEM_IB_ENABLE_IBGDA'] = '1'
os.environ['NVSHMEM_IBGDA_NIC_HANDLER'] = 'gpu'
os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}' os.environ['NVSHMEM_IBGDA_NUM_RC_PER_PE'] = f'{num_qps_per_rank}'
# Make sure QP depth is always larger than the number of on-flight WRs, so that we can skip WQ slot check # Make sure QP depth is always larger than the number of on-flight WRs, so that we can skip WQ slot check
os.environ['NVSHMEM_QP_DEPTH'] = '1024' os.environ['NVSHMEM_QP_DEPTH'] = '1024'
......
import os import os
import subprocess import subprocess
import setuptools import setuptools
import importlib
import importlib.resources
from torch.utils.cpp_extension import BuildExtension, CUDAExtension from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# Wheel specific: The wheels only include the soname of the host library (libnvshmem_host.so.X)
def get_nvshmem_host_lib_name():
    """Return the file name of the NVSHMEM host library found in the `nvidia.nvshmem` package.

    Each directory directly inside the installed `nvidia.nvshmem` package is searched
    recursively for files matching `libnvshmem_host.so.*`.

    Returns:
        The bare file name (no directory part) of the matching library. When several
        files match, the shortest name is returned (ties broken alphabetically), so the
        result does not depend on directory listing order and the soname-style name
        (e.g. `libnvshmem_host.so.X`) wins over a longer, fully versioned one.

    Raises:
        ModuleNotFoundError: if no matching file exists, or (raised by
            `importlib.resources.files`) if `nvidia.nvshmem` is not installed.
    """
    candidates = set()
    for entry in importlib.resources.files('nvidia.nvshmem').iterdir():
        # Only directories can contain the library; plain files never match below.
        if not entry.is_dir():
            continue
        for lib in entry.rglob('libnvshmem_host.so.*'):
            candidates.add(lib.name)
    if not candidates:
        raise ModuleNotFoundError('libnvshmem_host.so not found')
    # Deterministic choice: listing order is filesystem-dependent, so never rely on it.
    return min(candidates, key=lambda name: (len(name), name))
if __name__ == '__main__': if __name__ == '__main__':
disable_nvshmem = False
nvshmem_dir = os.getenv('NVSHMEM_DIR', None) nvshmem_dir = os.getenv('NVSHMEM_DIR', None)
disable_nvshmem = nvshmem_dir is None nvshmem_host_lib = 'libnvshmem_host.so'
if disable_nvshmem: if nvshmem_dir is None:
print('Warning: `NVSHMEM_DIR` is not specified, all internode and low-latency features are disabled\n') try:
nvshmem_dir = importlib.util.find_spec("nvidia.nvshmem").submodule_search_locations[0]
nvshmem_host_lib = get_nvshmem_host_lib_name()
import nvidia.nvshmem as nvshmem
except (ModuleNotFoundError, AttributeError, IndexError):
print('Warning: `NVSHMEM_DIR` is not specified, and the NVSHMEM module is not installed. All internode and low-latency features are disabled\n')
disable_nvshmem = True
else: else:
assert os.path.exists(nvshmem_dir), f'Failed to find NVSHMEM: {nvshmem_dir}' disable_nvshmem = False
if not disable_nvshmem:
assert os.path.exists(nvshmem_dir), f'The specified NVSHMEM directory does not exist: {nvshmem_dir}'
cxx_flags = ['-O3', '-Wno-deprecated-declarations', '-Wno-unused-variable', cxx_flags = ['-O3', '-Wno-deprecated-declarations', '-Wno-unused-variable',
'-Wno-sign-compare', '-Wno-reorder', '-Wno-attributes'] '-Wno-sign-compare', '-Wno-reorder', '-Wno-attributes']
...@@ -29,8 +47,8 @@ if __name__ == '__main__': ...@@ -29,8 +47,8 @@ if __name__ == '__main__':
sources.extend(['csrc/kernels/internode.cu', 'csrc/kernels/internode_ll.cu']) sources.extend(['csrc/kernels/internode.cu', 'csrc/kernels/internode_ll.cu'])
include_dirs.extend([f'{nvshmem_dir}/include']) include_dirs.extend([f'{nvshmem_dir}/include'])
library_dirs.extend([f'{nvshmem_dir}/lib']) library_dirs.extend([f'{nvshmem_dir}/lib'])
nvcc_dlink.extend(['-dlink', f'-L{nvshmem_dir}/lib', '-lnvshmem']) nvcc_dlink.extend(['-dlink', f'-L{nvshmem_dir}/lib', '-lnvshmem_device'])
extra_link_args.extend(['-l:libnvshmem.a', '-l:nvshmem_bootstrap_uid.so', f'-Wl,-rpath,{nvshmem_dir}/lib']) extra_link_args.extend([f'-l:{nvshmem_host_lib}', '-l:libnvshmem_device.a', f'-Wl,-rpath,{nvshmem_dir}/lib'])
if int(os.getenv('DISABLE_SM90_FEATURES', 0)): if int(os.getenv('DISABLE_SM90_FEATURES', 0)):
# Prefer A100 # Prefer A100
......
...@@ -14,25 +14,30 @@ Hardware requirements: ...@@ -14,25 +14,30 @@ Hardware requirements:
- InfiniBand GPUDirect Async (IBGDA) support, see [IBGDA Overview](https://developer.nvidia.com/blog/improving-network-performance-of-hpc-systems-using-nvidia-magnum-io-nvshmem-and-gpudirect-async/) - InfiniBand GPUDirect Async (IBGDA) support, see [IBGDA Overview](https://developer.nvidia.com/blog/improving-network-performance-of-hpc-systems-using-nvidia-magnum-io-nvshmem-and-gpudirect-async/)
- For more detailed requirements, see [NVSHMEM Hardware Specifications](https://docs.nvidia.com/nvshmem/release-notes-install-guide/install-guide/abstract.html#hardware-requirements) - For more detailed requirements, see [NVSHMEM Hardware Specifications](https://docs.nvidia.com/nvshmem/release-notes-install-guide/install-guide/abstract.html#hardware-requirements)
Software requirements:
- NVSHMEM v3.3.9 or later
## Installation procedure ## Installation procedure
### 1. Acquiring NVSHMEM source code ### 1. Install NVSHMEM binaries
Download NVSHMEM source code from the [NVIDIA NVSHMEM OPEN SOURCE PACKAGES](https://developer.nvidia.com/downloads/assets/secure/nvshmem/nvshmem_src_cuda12-all). NVSHMEM 3.3.9 binaries are available in several formats:
- Tarballs for [x86_64](https://developer.download.nvidia.com/compute/nvshmem/redist/libnvshmem/linux-x86_64/libnvshmem-linux-x86_64-3.3.9_cuda12-archive.tar.xz) and [aarch64](https://developer.download.nvidia.com/compute/nvshmem/redist/libnvshmem/linux-sbsa/libnvshmem-linux-sbsa-3.3.9_cuda12-archive.tar.xz)
- RPM and deb packages: instructions can be found on the [NVSHMEM installer page](https://developer.nvidia.com/nvshmem-downloads?target_os=Linux)
- Conda packages through conda-forge
- pip wheels through PyPI: `pip install nvidia-nvshmem-cu12`
DeepEP is compatible with upstream NVSHMEM 3.3.9 and later.
### 2. [Optional] apply our custom patch
**NOTE: After NVSHMEM v3.3.9, it is no longer necessary to apply our patch to achieve optimal performance.** ### 2. Enable NVSHMEM IBGDA support
Navigate to your NVSHMEM source directory and apply our provided patch: NVSHMEM supports two modes with different requirements. Either of the following methods can be used to enable IBGDA support.
```bash #### 2.1 Configure NVIDIA driver
git apply /path/to/deep_ep/dir/third-party/nvshmem.patch
```
### 3. Configure NVIDIA driver (required by inter-node communication) This configuration enables traditional IBGDA support.
Enable IBGDA by modifying `/etc/modprobe.d/nvidia.conf`: Modify `/etc/modprobe.d/nvidia.conf`:
```bash ```bash
options nvidia NVreg_EnableStreamMemOPs=1 NVreg_RegistryDwords="PeerMappingOverride=1;" options nvidia NVreg_EnableStreamMemOPs=1 NVreg_RegistryDwords="PeerMappingOverride=1;"
...@@ -45,38 +50,19 @@ sudo update-initramfs -u ...@@ -45,38 +50,19 @@ sudo update-initramfs -u
sudo reboot sudo reboot
``` ```
For more detailed configurations, please refer to the [NVSHMEM Installation Guide](https://docs.nvidia.com/nvshmem/release-notes-install-guide/install-guide/abstract.html). #### 2.2 Install GDRCopy and load the gdrdrv kernel module
### 4. Build and installation This configuration enables IBGDA through asynchronous post-send operations assisted by the CPU. More information about CPU-assisted IBGDA can be found in [this blog](https://developer.nvidia.com/blog/enhancing-application-portability-and-compatibility-across-new-platforms-using-nvidia-magnum-io-nvshmem-3-0/#cpu-assisted_infiniband_gpu_direct_async%C2%A0).
It comes with a small performance penalty, but can be used when modifying the driver regkeys is not an option.
DeepEP uses NVLink for intra-node communication and IBGDA for inter-node communication. All the other features are disabled to reduce the dependencies. Download GDRCopy
GDRCopy is available as prebuilt deb and rpm packages [here](https://developer.download.nvidia.com/compute/redist/gdrcopy/), or as source code on the [GDRCopy GitHub repository](https://github.com/NVIDIA/gdrcopy).
```bash Install GDRCopy following the instructions on the [GDRCopy GitHub repository](https://github.com/NVIDIA/gdrcopy?tab=readme-ov-file#build-and-installation).
export CUDA_HOME=/path/to/cuda
# disable all features except IBGDA
export NVSHMEM_IBGDA_SUPPORT=1
export NVSHMEM_SHMEM_SUPPORT=0
export NVSHMEM_UCX_SUPPORT=0
export NVSHMEM_USE_NCCL=0
export NVSHMEM_PMIX_SUPPORT=0
export NVSHMEM_TIMEOUT_DEVICE_POLLING=0
export NVSHMEM_USE_GDRCOPY=0
export NVSHMEM_IBRC_SUPPORT=0
export NVSHMEM_BUILD_TESTS=0
export NVSHMEM_BUILD_EXAMPLES=0
export NVSHMEM_MPI_SUPPORT=0
export NVSHMEM_BUILD_HYDRA_LAUNCHER=0
export NVSHMEM_BUILD_TXZ_PACKAGE=0
export NVSHMEM_TIMEOUT_DEVICE_POLLING=0
cmake -G Ninja -S . -B build -DCMAKE_INSTALL_PREFIX=/path/to/your/dir/to/install
cmake --build build/ --target install
```
## Post-installation configuration ## Post-installation configuration
Set environment variables in your shell configuration: When not installing NVSHMEM from RPM or deb packages, set the following environment variables in your shell configuration:
```bash ```bash
export NVSHMEM_DIR=/path/to/your/dir/to/install # Use for DeepEP installation export NVSHMEM_DIR=/path/to/your/dir/to/install # Use for DeepEP installation
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment