cub.py 3.46 KB
Newer Older
root's avatar
root committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import operator

from cupyx.jit import _cuda_types
from cupyx.jit import _cuda_typerules
from cupyx.jit import _internal_types
from cupy_backends.cuda.api import runtime as _runtime


class _ClassTemplate:

    def __init__(self, class_type):
        self._class_type = class_type
        self.__doc__ = self._class_type.__doc__

    def __getitem__(self, args):
        if isinstance(args, tuple):
            return self._class_type(*args)
        else:
            return self._class_type(args)


def _include_cub(env):
    """Add the CUB (or hipCUB) header include to the generated source.

    Uses ``hipcub/hipcub.hpp`` on HIP builds, the header under
    ``cupy/cub`` when the CUDA runtime version is below 11000
    (i.e. CUDA < 11.0), and ``cub/cub.cuh`` otherwise. Also sets the
    generated code's backend to ``'nvcc'``.
    """
    if _runtime.is_hip:
        header = 'hipcub/hipcub.hpp'
    elif _runtime.runtimeGetVersion() < 11000:
        header = 'cupy/cub/cub/cub.cuh'
    else:
        header = 'cub/cub.cuh'
    env.generated.add_code(f'#include <{header}>')
    env.generated.backend = 'nvcc'


def _get_cub_namespace():
    """Return the C++ namespace of the CUB library for the current build."""
    if _runtime.is_hip:
        return 'hipcub'
    return 'cub'


class _TempStorageType(_cuda_types.TypeBase):
    """Type of the ``TempStorage`` member of a CUB reduce class.

    Renders as ``typename <parent>::TempStorage``, where ``<parent>`` is
    the string form of the owning reduce type.
    """

    def __init__(self, parent_type):
        # Internal invariant: the owner must be a CUB reduce type.
        assert isinstance(parent_type, _CubReduceBaseType)
        self.parent_type = parent_type
        super().__init__()

    def __str__(self) -> str:
        return 'typename {}::TempStorage'.format(self.parent_type)


def _check_reduce_input(expected_ctype, value) -> None:
    """Raise ``TypeError`` unless ``value.ctype`` equals ``expected_ctype``.

    Shared by ``_CubReduceBaseType.Sum`` and ``_CubReduceBaseType.Reduce``
    so both report a mismatched operand type identically.
    """
    if value.ctype != expected_ctype:
        raise TypeError(
            f'Invalid input type {value.ctype}. '
            f'({expected_ctype} is expected.)')


class _CubReduceBaseType(_cuda_types.TypeBase):
    """Common base of the CUB collective reduce types.

    Subclasses must set ``T`` (the element ctype) and ``TempStorage``
    (the matching ``_TempStorageType``) before use. ``_instantiate``
    emits a constructor call taking a ``TempStorage`` value, and the
    resulting object exposes ``Sum`` and ``Reduce``.
    """

    def _instantiate(self, env, temp_storage) -> _internal_types.Data:
        """Emit ``<type>(<temp_storage>)`` after including the CUB header.

        Raises:
            TypeError: if ``temp_storage`` does not have this type's
                ``TempStorage`` ctype.
        """
        _include_cub(env)
        if temp_storage.ctype != self.TempStorage:
            raise TypeError(
                f'Invalid temp_storage type {temp_storage.ctype}. '
                f'({self.TempStorage} is expected.)')
        return _internal_types.Data(f'{self}({temp_storage.code})', self)

    @_internal_types.wraps_class_method
    def Sum(self, env, instance, input) -> _internal_types.Data:
        """Emit ``<instance>.Sum(<input>)``.

        Returns:
            Data whose ctype is ``input.ctype``.

        Raises:
            TypeError: if ``input`` is not of element type ``T``.
        """
        _check_reduce_input(self.T, input)
        return _internal_types.Data(
            f'{instance.code}.Sum({input.code})', input.ctype)

    @_internal_types.wraps_class_method
    def Reduce(self, env, instance, input,
               reduction_op) -> _internal_types.Data:
        """Emit ``<instance>.Reduce(<input>, <reduction_op>)``.

        ``reduction_op`` is only used through its ``code``; it is not
        type-checked here.

        Returns:
            Data whose ctype is ``input.ctype``.

        Raises:
            TypeError: if ``input`` is not of element type ``T``.
        """
        _check_reduce_input(self.T, input)
        return _internal_types.Data(
            f'{instance.code}.Reduce({input.code}, {reduction_op.code})',
            input.ctype)


class _WarpReduceType(_CubReduceBaseType):
    """Type of ``<ns>::WarpReduce<T>`` (``ns`` is ``cub`` or ``hipcub``).

    Args:
        T: element type; converted with ``_cuda_typerules.to_ctype``.
    """

    def __init__(self, T) -> None:
        element_ctype = _cuda_typerules.to_ctype(T)
        self.T = element_ctype
        # TempStorage is tied to this particular instantiation.
        self.TempStorage = _TempStorageType(self)
        super().__init__()

    def __str__(self) -> str:
        return '{}::WarpReduce<{}>'.format(_get_cub_namespace(), self.T)


class _BlockReduceType(_CubReduceBaseType):
    """Type of ``<ns>::BlockReduce<T, BLOCK_DIM_X>`` (``cub`` or ``hipcub``).

    Args:
        T: element type; converted with ``_cuda_typerules.to_ctype``.
        BLOCK_DIM_X: thread-block length along x, used as the C++ template
            argument. Any integer-like object (supporting ``__index__``,
            e.g. a NumPy integer) is accepted and stored as a Python int.

    Raises:
        TypeError: if ``BLOCK_DIM_X`` is not an integer.
        ValueError: if ``BLOCK_DIM_X`` is not positive.
    """

    def __init__(self, T, BLOCK_DIM_X: int) -> None:
        self.T = _cuda_typerules.to_ctype(T)
        # BLOCK_DIM_X is pasted verbatim into the generated C++ source, so
        # validate it here rather than emitting a template instantiation
        # that only fails (obscurely) when the kernel is compiled.
        try:
            block_dim_x = operator.index(BLOCK_DIM_X)
        except TypeError:
            raise TypeError(
                f'BLOCK_DIM_X must be an integer, got {BLOCK_DIM_X!r}'
            ) from None
        if block_dim_x <= 0:
            raise ValueError(
                f'BLOCK_DIM_X must be positive, got {block_dim_x}')
        self.BLOCK_DIM_X = block_dim_x
        self.TempStorage = _TempStorageType(self)
        super().__init__()

    def __str__(self) -> str:
        namespace = _get_cub_namespace()
        return f'{namespace}::BlockReduce<{self.T}, {self.BLOCK_DIM_X}>'


# Public class templates; subscript them with the template arguments,
# e.g. ``WarpReduce[T]`` or ``BlockReduce[T, BLOCK_DIM_X]``.
WarpReduce, BlockReduce = (
    _ClassTemplate(_WarpReduceType),
    _ClassTemplate(_BlockReduceType),
)


class _CubFunctor(_internal_types.BuiltinFunc):
    """Builtin that evaluates to the C++ expression ``<ns>::<name>()``.

    ``<ns>`` is ``cub`` or ``hipcub`` depending on the build; the
    expression is resolved once, when the functor object is created.
    """

    def __init__(self, name):
        self.fname = '{}::{}()'.format(_get_cub_namespace(), name)

    def call_const(self, env):
        # The expression is typed as Unknown, labelled 'cub_functor'.
        functor_type = _cuda_types.Unknown(label='cub_functor')
        return _internal_types.Data(self.fname, functor_type)


# CUB functor builtins, emitted as ``<ns>::Sum()`` etc. in generated code
# (presumably meant as the ``reduction_op`` of ``Reduce`` — confirm).
Sum, Max, Min = (_CubFunctor(_op) for _op in ('Sum', 'Max', 'Min'))