Unverified Commit ed3fb6b2 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

TE with threading build (#1092)



* added threading build back

* integrating threading for pytorch and paddle extensions

* added messages

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 44c8924f
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
from pathlib import Path from pathlib import Path
import setuptools import setuptools
import os
from .utils import cuda_version from .utils import cuda_version
...@@ -62,7 +63,7 @@ def setup_paddle_extension( ...@@ -62,7 +63,7 @@ def setup_paddle_extension(
print("Could not determine CUDA Toolkit version") print("Could not determine CUDA Toolkit version")
else: else:
if version >= (11, 2): if version >= (11, 2):
nvcc_flags.extend(["--threads", "4"]) nvcc_flags.extend(["--threads", os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1")])
if version >= (11, 0): if version >= (11, 0):
nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"]) nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"])
if version >= (11, 8): if version >= (11, 8):
......
...@@ -68,7 +68,7 @@ def setup_pytorch_extension( ...@@ -68,7 +68,7 @@ def setup_pytorch_extension(
print("Could not determine CUDA Toolkit version") print("Could not determine CUDA Toolkit version")
else: else:
if version >= (11, 2): if version >= (11, 2):
nvcc_flags.extend(["--threads", "4"]) nvcc_flags.extend(["--threads", os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1")])
if version >= (11, 0): if version >= (11, 0):
nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"]) nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"])
if version >= (11, 8): if version >= (11, 8):
......
...@@ -37,8 +37,8 @@ def get_max_jobs_for_parallel_build() -> int: ...@@ -37,8 +37,8 @@ def get_max_jobs_for_parallel_build() -> int:
num_jobs = 0 num_jobs = 0
# Check environment variable # Check environment variable
if os.getenv("NVTE_MAX_BUILD_JOBS"): if os.getenv("NVTE_BUILD_MAX_JOBS"):
num_jobs = int(os.getenv("NVTE_MAX_BUILD_JOBS")) num_jobs = int(os.getenv("NVTE_BUILD_MAX_JOBS"))
elif os.getenv("MAX_JOBS"): elif os.getenv("MAX_JOBS"):
num_jobs = int(os.getenv("MAX_JOBS")) num_jobs = int(os.getenv("MAX_JOBS"))
......
...@@ -14,6 +14,22 @@ set(CMAKE_CUDA_STANDARD_REQUIRED ON) ...@@ -14,6 +14,22 @@ set(CMAKE_CUDA_STANDARD_REQUIRED ON)
project(transformer_engine LANGUAGES CUDA CXX) project(transformer_engine LANGUAGES CUDA CXX)
set(BUILD_THREADS_PER_JOB $ENV{NVTE_BUILD_THREADS_PER_JOB})
if (NOT BUILD_THREADS_PER_JOB)
set(BUILD_THREADS_PER_JOB 1)
endif()
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --threads ${BUILD_THREADS_PER_JOB}")
if(DEFINED ENV{MAX_JOBS})
set(JOBS $ENV{MAX_JOBS})
elseif(DEFINED ENV{NVTE_BUILD_MAX_JOBS})
set(JOBS $ENV{NVTE_BUILD_MAX_JOBS})
else()
set(JOBS "max number of")
endif()
message(STATUS "Parallel build with ${JOBS} jobs and ${BUILD_THREADS_PER_JOB} threads per job")
if (CMAKE_BUILD_TYPE STREQUAL "Debug") if (CMAKE_BUILD_TYPE STREQUAL "Debug")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G") set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -G")
endif() endif()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment