# scaled_upper_triangle_masked_softmax.py
from .builder import Builder
from .utils import append_nvcc_threads, get_cuda_cc_flag


class ScaledUpperTrainglemaskedSoftmaxBuilder(Builder):
    """Builder describing how to compile the scaled upper-triangle masked softmax op.

    The class only supplies build metadata (include directories, source
    files and compiler flags); the build itself is driven by ``Builder``.
    """

    # Name handed to the base ``Builder``.
    NAME = "scaled_upper_triangle_masked_softmax"
    # Dotted import path handed to the base ``Builder`` as ``prebuilt_import_path``.
    PREBUILT_IMPORT_PATH = "colossalai._C.scaled_upper_triangle_masked_softmax"

    def __init__(self):
        """Initialise the base builder with this op's name and prebuilt import path."""
        super().__init__(
            name=ScaledUpperTrainglemaskedSoftmaxBuilder.NAME,
            prebuilt_import_path=ScaledUpperTrainglemaskedSoftmaxBuilder.PREBUILT_IMPORT_PATH,
        )

    def include_dirs(self):
        """Return the include directories: the kernel headers and the CUDA home include dir."""
        return [self.csrc_abs_path("kernels/include"), self.get_cuda_home_include()]

    def sources_files(self):
        """Return the C++ and CUDA source files, each resolved through ``csrc_abs_path``."""
        source_names = [
            "scaled_upper_triang_masked_softmax.cpp",
            "scaled_upper_triang_masked_softmax_cuda.cu",
        ]
        return [self.csrc_abs_path(fname) for fname in source_names]

    def cxx_flags(self):
        """Return the host compiler flags: ``-O3`` followed by the version-dependent macros."""
        return ["-O3"] + self.version_dependent_macros

    def nvcc_flags(self):
        """Return the nvcc flags for the CUDA source.

        The list is ``-O3``, ``--use_fast_math``, the fixed switches below and
        whatever ``get_cuda_cc_flag()`` yields, passed through
        ``append_nvcc_threads`` before being returned.
        """
        extra_cuda_flags = [
            # ``-U`` undefines the macro that follows it.
            "-U__CUDA_NO_HALF_OPERATORS__",
            "-U__CUDA_NO_HALF_CONVERSIONS__",
            "--expt-relaxed-constexpr",
            "--expt-extended-lambda",
        ]
        extra_cuda_flags.extend(get_cuda_cc_flag())
        ret = ["-O3", "--use_fast_math"] + extra_cuda_flags
        return append_nvcc_threads(ret)