Unverified Commit 4ec66c77 authored by hXl3s's avatar hXl3s Committed by GitHub
Browse files

Let user limit number of architectures, to improve build time (#1126)



* Limit number of architectures build
Signed-off-by: default avatarLukasz Pierscieniewski <lukaszp@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarLukasz Pierscieniewski <lukaszp@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent 901e5d2b
......@@ -10,8 +10,9 @@ import setuptools
from .utils import (
all_files_in_dir,
cuda_version,
cuda_archs,
cuda_path,
cuda_version,
)
......@@ -48,8 +49,6 @@ def setup_pytorch_extension(
]
nvcc_flags = [
"-O3",
"-gencode",
"arch=compute_70,code=sm_70",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
......@@ -61,6 +60,11 @@ def setup_pytorch_extension(
"--use_fast_math",
]
cuda_architectures = cuda_archs()
if "70" in cuda_architectures:
nvcc_flags.extend(["-gencode", "arch=compute_70,code=sm_70"])
# Version-dependent CUDA options
try:
version = cuda_version()
......@@ -73,13 +77,14 @@ def setup_pytorch_extension(
(
"--threads",
os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"),
"-gencode",
"arch=compute_80,code=sm_80",
"-gencode",
"arch=compute_90,code=sm_90",
)
)
if "80" in cuda_architectures:
nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"])
if "90" in cuda_architectures:
nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"])
# Libraries
library_dirs = []
libraries = []
......
......@@ -6,12 +6,12 @@
import functools
import glob
import importlib
import os
import re
import shutil
import subprocess
import sys
import importlib
from pathlib import Path
from subprocess import CalledProcessError
from typing import List, Optional, Tuple, Union
......@@ -188,6 +188,11 @@ def cuda_path() -> Tuple[str, str]:
return cuda_home, nvcc_bin
@functools.lru_cache(maxsize=None)
def cuda_archs() -> str:
return os.getenv("NVTE_CUDA_ARCHS", "70;80;89;90")
def cuda_version() -> Tuple[int, ...]:
"""CUDA Toolkit version as a (major, minor) tuple."""
# Query NVCC for version info
......
......@@ -13,17 +13,17 @@ import setuptools
from wheel.bdist_wheel import bdist_wheel
from build_tools.build_ext import CMakeExtension, get_build_ext
from build_tools.te_version import te_version
from build_tools.utils import (
cuda_archs,
found_cmake,
found_ninja,
found_pybind11,
remove_dups,
get_frameworks,
install_and_import,
remove_dups,
uninstall_te_fw_packages,
)
from build_tools.te_version import te_version
frameworks = get_frameworks()
current_file_path = Path(__file__).parent.resolve()
......@@ -59,10 +59,11 @@ def setup_common_extension() -> CMakeExtension:
"""Setup CMake extension for common library"""
# Project directory root
root_path = Path(__file__).resolve().parent
return CMakeExtension(
name="transformer_engine",
cmake_path=root_path / Path("transformer_engine/common"),
cmake_flags=[],
cmake_flags=["-DCMAKE_CUDA_ARCHITECTURES={}".format(cuda_archs())],
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment