cuda_extension.py 3.8 KB
Newer Older
1
import os
2
import time
3
from abc import abstractmethod
4
from pathlib import Path
5
6
from typing import List

7
from .base_extension import _Extension
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from .cpp_extension import _CppExtension
from .utils import check_pytorch_version, check_system_pytorch_cuda_match, set_cuda_arch_list

__all__ = ["_CudaExtension"]

# Minimum PyTorch version (major, minor) used for installation checks; both
# values are passed to check_pytorch_version() in
# _CudaExtension.assert_hardware_compatible.
MIN_PYTORCH_VERSION_MAJOR = 1
MIN_PYTORCH_VERSION_MINOR = 10


class _CudaExtension(_CppExtension):
    """Base class for extensions whose kernels are compiled with nvcc.

    Adds the CUDA-specific pieces on top of ``_CppExtension``: the nvcc flag
    hook, CUDA availability/compatibility checks, and the JIT / AOT build entry
    points that go through ``torch.utils.cpp_extension``.
    """

    @abstractmethod
    def nvcc_flags(self) -> List[str]:
        """
        This function should return a list of nvcc compilation flags for extensions.

        The result is forwarded as ``extra_cuda_cflags`` by ``build_jit`` and as
        the ``"nvcc"`` compile arguments by ``build_aot``.
        """

    def is_hardware_available(self) -> bool:
        """
        Report whether CUDA is usable from the current process.

        Returns:
            The value of ``torch.cuda.is_available()``, or ``False`` when PyTorch
            cannot be imported or the query itself raises.
        """
        # cuda extension can only be built if cuda is available
        try:
            import torch

            return torch.cuda.is_available()
        except Exception:
            # Best-effort probe: any failure means "no CUDA". Catch Exception
            # rather than using a bare except so that KeyboardInterrupt and
            # SystemExit still propagate.
            return False

    def assert_hardware_compatible(self) -> None:
        """
        Verify that the local CUDA toolkit and PyTorch can build this extension.

        Raises:
            AssertionError: if ``torch.utils.cpp_extension.CUDA_HOME`` is unset.
            Anything raised by ``check_system_pytorch_cuda_match`` or
            ``check_pytorch_version`` is propagated unchanged.
        """
        from torch.utils.cpp_extension import CUDA_HOME

        if not CUDA_HOME:
            raise AssertionError(
                "[extension] CUDA_HOME is not found. You need to export CUDA_HOME environment variable or install CUDA Toolkit first in order to build/load CUDA extensions"
            )
        check_system_pytorch_cuda_match(CUDA_HOME)
        check_pytorch_version(MIN_PYTORCH_VERSION_MAJOR, MIN_PYTORCH_VERSION_MINOR)

    def get_cuda_home_include(self) -> str:
        """
        return include path inside the cuda home.

        Raises:
            RuntimeError: if ``torch.utils.cpp_extension.CUDA_HOME`` is unset.
        """
        from torch.utils.cpp_extension import CUDA_HOME

        # Falsy check (same as assert_hardware_compatible) so that an empty
        # CUDA_HOME is rejected instead of yielding the relative path "include".
        if not CUDA_HOME:
            raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.")
        return os.path.join(CUDA_HOME, "include")

    # No return annotation: the previous ``-> None`` was wrong, since this method
    # returns whatever ``torch.utils.cpp_extension.load`` returns.
    def build_jit(self):
        """
        Compile (or reuse) the kernel at runtime and load it.

        The build happens inside the folder given by
        ``_Extension.get_jit_extension_folder_path()``, which is created if
        missing. Progress and the elapsed time are printed to stdout.

        Returns:
            The object returned by ``torch.utils.cpp_extension.load`` for this
            extension.
        """
        from torch.utils.cpp_extension import CUDA_HOME, load

        set_cuda_arch_list(CUDA_HOME)

        # get build dir
        build_directory = Path(_Extension.get_jit_extension_folder_path())
        build_directory.mkdir(parents=True, exist_ok=True)

        # check if the kernel has been built
        # NOTE(review): this only selects which log message is printed. It looks
        # for "<name>.o"; confirm that matches the artefact names the torch JIT
        # loader actually writes into the build directory.
        compiled_before = build_directory.joinpath(f"{self.name}.o").exists()

        # load the kernel
        if compiled_before:
            print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now")
        else:
            print(f"[extension] Compiling the JIT {self.name} kernel during runtime now")

        # perf_counter is monotonic, so the reported duration cannot be skewed
        # by system clock adjustments during a long compile.
        build_start = time.perf_counter()
        op_kernel = load(
            name=self.name,
            sources=self.strip_empty_entries(self.sources_files()),
            extra_include_paths=self.strip_empty_entries(self.include_dirs()),
            # Flags go through strip_empty_entries exactly as build_aot does.
            extra_cflags=self.strip_empty_entries(self.cxx_flags()),
            extra_cuda_cflags=self.strip_empty_entries(self.nvcc_flags()),
            extra_ldflags=[],
            build_directory=str(build_directory),
        )
        build_duration = time.perf_counter() - build_start

        if compiled_before:
            print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds")
        else:
            print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds")

        return op_kernel

    def build_aot(self) -> "CUDAExtension":
        """
        Describe this extension for an ahead-of-time build.

        Returns:
            A ``torch.utils.cpp_extension.CUDAExtension`` named
            ``self.prebuilt_import_path`` and configured with this extension's
            sources, include directories and cxx/nvcc flags.
        """
        from torch.utils.cpp_extension import CUDA_HOME, CUDAExtension

        set_cuda_arch_list(CUDA_HOME)
        return CUDAExtension(
            name=self.prebuilt_import_path,
            sources=self.strip_empty_entries(self.sources_files()),
            include_dirs=self.strip_empty_entries(self.include_dirs()),
            extra_compile_args={
                "cxx": self.strip_empty_entries(self.cxx_flags()),
                "nvcc": self.strip_empty_entries(self.nvcc_flags()),
            },
        )