sparse_attn.py 1.94 KB
Newer Older
Samyam Rajbhandari's avatar
Samyam Rajbhandari committed
1
2
3
"""
Copyright 2020 The Microsoft DeepSpeed Team
"""
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
import warnings
from .builder import OpBuilder


class SparseAttnBuilder(OpBuilder):
    """Op builder for the ``sparse_attn`` extension.

    Supplies the op name, source files and compiler flags, and checks
    whether the current environment can build the op.
    """
    BUILD_VAR = "DS_BUILD_SPARSE_ATTN"
    NAME = "sparse_attn"

    def __init__(self):
        """Initialise the base builder with this op's ``NAME``."""
        super().__init__(name=self.NAME)

    def absolute_name(self):
        """Return the dotted module path under which the built op lives."""
        return f'deepspeed.ops.sparse_attention.{self.NAME}_op'

    def sources(self, is_rocm_pytorch):
        """Return the list of C++ source files to compile.

        Args:
            is_rocm_pytorch: truthy to select the source under ``hip/``,
                falsy to select the default source.
        """
        if is_rocm_pytorch:
            return ['csrc/sparse_attention/hip/utils.cpp']
        return ['csrc/sparse_attention/utils.cpp']

    def cxx_args(self):
        """Return the extra C++ compiler flags for this op."""
        return ['-O2', '-fopenmp']

    def is_compatible(self):
        """Report whether this op can be built in the current environment.

        Returns:
            True only if the base-class check passes, every required
            command is found and torch is a 1.x release with minor
            version >= 5. A missing HIP version only produces a warning;
            it does not influence the result.
        """
        # cmake is a build dependency. An llvm-config requirement
        # ('llvm-config|llvm-config-9') used to be listed here as well,
        # but that check is disabled.
        required_commands = ['cmake']
        # Probe every command (no short-circuit) before combining results.
        deps_compatible = all(
            [self.command_exists(cmd) for cmd in required_commands])

        # Warn when torch reports no HIP version. getattr keeps this from
        # raising if the torch build has no `hip` attribute at all.
        # The former CUDA version gate (10.1 <= CUDA < 11) is disabled, so
        # this check is warning-only.
        if getattr(torch.version, 'hip', None) is None:
            self.warning(f"{self.NAME} cuda is not available from torch")

        version_parts = torch.__version__.split('.')
        torch_major = int(version_parts[0])
        torch_minor = int(version_parts[1])
        # Only torch 1.x (minor >= 5) is accepted; 2.x and later are rejected,
        # so the warning states both bounds.
        torch_compatible = torch_major == 1 and torch_minor >= 5
        if not torch_compatible:
            self.warning(
                f'{self.NAME} requires a torch version >= 1.5 and < 2.0 '
                f'but detected {torch_major}.{torch_minor}'
            )

        return super().is_compatible() and deps_compatible and torch_compatible