Unverified Commit a8c83f89 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

Parallel build with limited resource (#987)



* add parallel build without pyproject
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 87bfc348
......@@ -23,6 +23,7 @@ from .utils import (
found_ninja,
get_frameworks,
cuda_path,
get_max_jobs_for_parallel_build,
)
......@@ -60,8 +61,6 @@ class CMakeExtension(setuptools.Extension):
f"-DCMAKE_INSTALL_PREFIX={install_dir}",
]
configure_command += self.cmake_flags
if found_ninja():
configure_command.append("-GNinja")
import pybind11
......@@ -73,6 +72,14 @@ class CMakeExtension(setuptools.Extension):
build_command = [_cmake_bin, "--build", build_dir]
install_command = [_cmake_bin, "--install", build_dir]
# Check whether parallel build is restricted
max_jobs = get_max_jobs_for_parallel_build()
if found_ninja():
configure_command.append("-GNinja")
build_command.append("--parallel")
if max_jobs > 0:
build_command.append(str(max_jobs))
# Run CMake commands
for command in [configure_command, build_command, install_command]:
print(f"Running command {' '.join(command)}")
......
......@@ -28,6 +28,28 @@ def debug_build_enabled() -> bool:
return False
@functools.lru_cache(maxsize=None)
def get_max_jobs_for_parallel_build() -> int:
"""Number of parallel jobs for Nina build"""
# Default: maximum parallel jobs
num_jobs = 0
# Check environment variable
if os.getenv("NVTE_MAX_BUILD_JOBS"):
num_jobs = int(os.getenv("NVTE_MAX_BUILD_JOBS"))
elif os.getenv("MAX_JOBS"):
num_jobs = int(os.getenv("MAX_JOBS"))
# Check command-line arguments
for arg in sys.argv.copy():
if arg.startswith("--parallel="):
num_jobs = int(arg.replace("--parallel=", ""))
sys.argv.remove(arg)
return num_jobs
def all_files_in_dir(path, name_extension=None):
all_files = []
for dirname, _, names in os.walk(path):
......
......@@ -14,7 +14,6 @@ set(CMAKE_CUDA_STANDARD_REQUIRED ON)
project(transformer_engine LANGUAGES CUDA CXX)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --threads 4")
if (CMAKE_BUILD_TYPE STREQUAL "Debug")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G")
endif()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment