Commit 2a873392 authored by Seth Howell's avatar Seth Howell
Browse files

setup.py: Add logic for detecting library locations from NVSHMEM wheels.


Signed-off-by: default avatarSeth Howell <sethh@nvidia.com>
parent 903711c6
import os import os
import subprocess import subprocess
import setuptools import setuptools
import importlib
import importlib.resources
from torch.utils.cpp_extension import BuildExtension, CUDAExtension from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# Wheel specific: The wheels only include the soname of the host library (libnvshmem_host.so.X)
def get_nvshmem_host_lib_name():
for path in importlib.resources.files('nvidia.nvshmem').iterdir():
for file in path.rglob('libnvshmem_host.so.*'):
return file.name
raise ModuleNotFoundError('libnvshmem_host.so not found')
if __name__ == '__main__': if __name__ == '__main__':
disable_nvshmem = False
nvshmem_dir = os.getenv('NVSHMEM_DIR', None) nvshmem_dir = os.getenv('NVSHMEM_DIR', None)
disable_nvshmem = nvshmem_dir is None nvshmem_host_lib = 'libnvshmem_host.so'
if disable_nvshmem: if nvshmem_dir is None:
print('Warning: `NVSHMEM_DIR` is not specified, all internode and low-latency features are disabled\n') try:
nvshmem_dir = importlib.util.find_spec("nvidia.nvshmem").submodule_search_locations[0]
nvshmem_host_lib = get_nvshmem_host_lib_name()
import nvidia.nvshmem as nvshmem
except (ModuleNotFoundError, AttributeError, IndexError):
print('Warning: `NVSHMEM_DIR` is not specified, and the NVSHMEM module is not installed. All internode and low-latency features are disabled\n')
disable_nvshmem = True
else: else:
assert os.path.exists(nvshmem_dir), f'Failed to find NVSHMEM: {nvshmem_dir}' disable_nvshmem = False
if not disable_nvshmem:
assert os.path.exists(nvshmem_dir), f'The specified NVSHMEM directory does not exist: {nvshmem_dir}'
cxx_flags = ['-O3', '-Wno-deprecated-declarations', '-Wno-unused-variable', cxx_flags = ['-O3', '-Wno-deprecated-declarations', '-Wno-unused-variable',
'-Wno-sign-compare', '-Wno-reorder', '-Wno-attributes'] '-Wno-sign-compare', '-Wno-reorder', '-Wno-attributes']
...@@ -30,7 +49,7 @@ if __name__ == '__main__': ...@@ -30,7 +49,7 @@ if __name__ == '__main__':
include_dirs.extend([f'{nvshmem_dir}/include']) include_dirs.extend([f'{nvshmem_dir}/include'])
library_dirs.extend([f'{nvshmem_dir}/lib']) library_dirs.extend([f'{nvshmem_dir}/lib'])
nvcc_dlink.extend(['-dlink', f'-L{nvshmem_dir}/lib', '-lnvshmem_device']) nvcc_dlink.extend(['-dlink', f'-L{nvshmem_dir}/lib', '-lnvshmem_device'])
extra_link_args.extend(['-l:libnvshmem_host.so', '-l:libnvshmem_device.a', f'-Wl,-rpath,{nvshmem_dir}/lib']) extra_link_args.extend([f'-l:{nvshmem_host_lib}', '-l:libnvshmem_device.a', f'-Wl,-rpath,{nvshmem_dir}/lib'])
if int(os.getenv('DISABLE_SM90_FEATURES', 0)): if int(os.getenv('DISABLE_SM90_FEATURES', 0)):
# Prefer A100 # Prefer A100
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment