nrtdynmod.py 7.31 KB
Newer Older
dugupeiwen's avatar
dugupeiwen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
"""
Dynamically generate the NRT module
"""


from numba.core import config
from numba.core import types, cgutils
from llvmlite import ir, binding


# Integer type ``config.MACHINE_BITS`` wide; used for the size_t fields of the
# meminfo record (refcount, size) and for the atomic refcount helpers.
_word_type = ir.IntType(config.MACHINE_BITS)
# Untyped byte pointer (``i8*``) used for every pointer-valued field/argument.
_pointer_type = ir.IntType(8).as_pointer()

# LLVM view of a NRT MemInfo record.  Fields are addressed by position
# (``NRT_MemInfo_data_fast`` reads index 3), so the order is significant.
_meminfo_struct_type = ir.LiteralStructType([
    _word_type,     # size_t refct
    _pointer_type,  # dtor_function dtor
    _pointer_type,  # void *dtor_info
    _pointer_type,  # void *data
    _word_type,     # size_t size
])


# ``void (i8*)``: signature shared by NRT_incref and NRT_decref.
incref_decref_ty = ir.FunctionType(ir.VoidType(), [_pointer_type])
# ``i8* (i8*)``: signature of NRT_MemInfo_data_fast.
meminfo_data_ty = ir.FunctionType(_pointer_type, [_pointer_type])


def _define_nrt_meminfo_data(module):
    """
    Emit the body of ``NRT_MemInfo_data_fast`` into *module*.

    The generated function treats its ``i8*`` argument as a pointer to the
    meminfo struct and returns the ``data`` field (struct index 3).  Having
    the definition available in IR allows LLVM to inline the data-pointer
    lookup.
    """
    data_fn = cgutils.get_or_insert_function(module, meminfo_data_ty,
                                             "NRT_MemInfo_data_fast")
    entry = data_fn.append_basic_block()
    builder = ir.IRBuilder(entry)
    meminfo, = data_fn.args
    meminfo_struct = builder.bitcast(meminfo,
                                     _meminfo_struct_type.as_pointer())
    # Field 3 of the meminfo struct is ``void *data``.
    data_field = cgutils.gep(builder, meminfo_struct, 0, 3)
    builder.ret(builder.load(data_field))


def _define_nrt_incref(module, atomic_incr):
    """
    Emit the body of ``NRT_incref`` into *module*.

    A NULL meminfo pointer makes the generated function return immediately.
    Otherwise the pointer is reinterpreted as a pointer to its first word
    (the ``refct`` field) and incremented by calling *atomic_incr*.
    """
    incref_fn = cgutils.get_or_insert_function(module, incref_decref_ty,
                                               "NRT_incref")
    # Must remain out-of-line so that refcount pruning keeps working.
    incref_fn.attributes.add('noinline')
    builder = ir.IRBuilder(incref_fn.append_basic_block())
    meminfo, = incref_fn.args
    null = cgutils.get_null_value(meminfo.type)
    with cgutils.if_unlikely(builder,
                             builder.icmp_unsigned("==", meminfo, null)):
        builder.ret_void()

    refct_ptr = builder.bitcast(meminfo, atomic_incr.args[0].type)
    if config.DEBUG_NRT:
        # Debug trace of the (non-atomically read) current refcount.
        cgutils.printf(builder, "*** NRT_Incref %zu [%p]\n",
                       builder.load(refct_ptr), meminfo)
    builder.call(atomic_incr, [refct_ptr])
    builder.ret_void()


def _define_nrt_decref(module, atomic_decr):
    """
    Emit the body of ``NRT_decref`` into *module*.

    A NULL meminfo pointer makes the generated function return immediately.
    Otherwise the refcount (first word of the meminfo) is decremented by
    calling *atomic_decr*; if the value it returns equals zero,
    ``NRT_MemInfo_call_dtor`` is called with the meminfo pointer.
    """
    decref_fn = cgutils.get_or_insert_function(module, incref_decref_ty,
                                               "NRT_decref")
    # Must remain out-of-line so that refcount pruning keeps working.
    decref_fn.attributes.add('noinline')
    # Declaration only (no body is emitted here); presumably resolved
    # against the NRT runtime at link time.
    dtor_fnty = ir.FunctionType(ir.VoidType(), [_pointer_type])
    call_dtor_fn = ir.Function(module, dtor_fnty,
                               name="NRT_MemInfo_call_dtor")

    builder = ir.IRBuilder(decref_fn.append_basic_block())
    meminfo, = decref_fn.args
    null = cgutils.get_null_value(meminfo.type)
    with cgutils.if_unlikely(builder,
                             builder.icmp_unsigned("==", meminfo, null)):
        builder.ret_void()

    # Memory fences follow https://llvm.org/docs/Atomics.html.
    # The release fence precedes the refcount write; it is a no-op on x86
    # and lowers to lwsync on POWER.
    builder.fence("release")

    refct_ptr = builder.bitcast(meminfo, atomic_decr.args[0].type)
    if config.DEBUG_NRT:
        # Debug trace of the (non-atomically read) current refcount.
        cgutils.printf(builder, "*** NRT_Decref %zu [%p]\n",
                       builder.load(refct_ptr), meminfo)
    remaining = builder.call(atomic_decr, [refct_ptr])

    zero = ir.Constant(remaining.type, 0)
    is_last_ref = builder.icmp_unsigned("==", remaining, zero)
    with cgutils.if_unlikely(builder, is_last_ref):
        # The acquire fence follows the refcount read; it is a no-op on x86
        # and lowers to lwsync on POWER.
        builder.fence("acquire")
        builder.call(call_dtor_fn, [meminfo])
    builder.ret_void()


# Set this to True to make ``_define_atomic_inc_dec`` emit plain (non-atomic)
# load/modify/store refcount updates, e.g. to measure the overhead of atomic
# refcounts compared to non-atomic.
_disable_atomicity = False


def _define_atomic_inc_dec(module, op, ordering):
    """Define a llvm function for atomic increment/decrement in the given module.

    Argument ``op`` is the operation "add"/"sub"; it names both the
    ``atomicrmw`` operation and the matching ``ir.IRBuilder`` method.
    Argument ``ordering`` is the memory ordering.  The generated function,
    named ``nrt_atomic_<op>``, takes a pointer to a machine word, applies
    ``op`` with 1 to the pointee and returns the *new* value.

    When the module-level ``_disable_atomicity`` flag is set, a plain
    load/modify/store sequence is emitted instead.  That path also returns
    the new value, because callers such as ``NRT_decref`` compare the result
    against zero to decide whether to run the destructor.
    """
    ftype = ir.FunctionType(_word_type, [_word_type.as_pointer()])
    fn_atomic = ir.Function(module, ftype, name="nrt_atomic_{0}".format(op))

    [ptr] = fn_atomic.args
    bb = fn_atomic.append_basic_block()
    builder = ir.IRBuilder(bb)
    ONE = ir.Constant(_word_type, 1)
    apply_op = getattr(builder, op)  # builder.add / builder.sub
    if not _disable_atomicity:
        # atomicrmw yields the value *before* the update; redo the operation
        # on it so that we can return the value *after* the update.
        oldval = builder.atomic_rmw(op, ptr, ONE, ordering=ordering)
        newval = apply_op(oldval, ONE)
    else:
        oldval = builder.load(ptr)
        newval = apply_op(oldval, ONE)
        builder.store(newval, ptr)
    # Both paths must return the updated value (see docstring).
    builder.ret(newval)

    return fn_atomic


def _define_atomic_cas(module, ordering):
    """Define ``nrt_atomic_cas``, an LLVM compare-and-swap helper, in *module*.

    The generated function is a thin wrapper around the LLVM ``cmpxchg``
    instruction with signature ``i32 (word* ptr, word cmp, word repl,
    word* oldptr)``.  Unlike ``cmpxchg`` itself, it returns an ``i32`` that
    is 1 on success and 0 on failure.  It also writes the value read from
    ``ptr`` into the output pointer ``oldptr``, on both success and failure.

    Note
    ----
    On failure, the generated function behaves like an atomic load.  The loaded
    value is stored to the last argument.
    """
    word_ptr = _word_type.as_pointer()
    fnty = ir.FunctionType(ir.IntType(32),
                           [word_ptr, _word_type, _word_type, word_ptr])
    cas_fn = ir.Function(module, fnty, name="nrt_atomic_cas")

    target, expected, desired, old_out = cas_fn.args
    builder = ir.IRBuilder(cas_fn.append_basic_block())
    # cmpxchg yields an aggregate {old value, success flag}.
    result = builder.cmpxchg(target, expected, desired, ordering=ordering)
    loaded, succeeded = cgutils.unpack_tuple(builder, result, 2)
    builder.store(loaded, old_out)
    status = builder.zext(succeeded, fnty.return_type)
    builder.ret(status)

    return cas_fn


def _define_nrt_unresolved_abort(ctx, module):
    """
    Emit ``nrt_unresolved_abort`` into *module* and return it.

    The function is typed, via the context's calling convention, as taking
    no arguments and returning ``none``.  Its only action is to report a
    user ``RuntimeError`` saying that the jitted function aborted because of
    an unresolved symbol.  It should be safe to call this function with an
    incorrect number of arguments.
    """
    abort_fnty = ctx.call_conv.get_function_type(types.none, ())
    abort_fn = ir.Function(module, abort_fnty, name="nrt_unresolved_abort")
    builder = ir.IRBuilder(abort_fn.append_basic_block())
    exc_args = ("numba jitted function aborted due to unresolved symbol",)
    ctx.call_conv.return_user_exc(builder, RuntimeError, exc_args)
    return abort_fn


def create_nrt_module(ctx):
    """
    Build the LLVM IR module that defines the NRT helper functions.

    The module is created inside a fresh ``"nrt"`` code library obtained from
    *ctx*'s codegen; it is not added to the library here.

    Returns
    -------
    (ir_module, library) tuple.
    """
    library = ctx.codegen().create_library("nrt")
    ir_mod = library.create_ir_module("nrt_module")

    # Atomic primitives come first: incref/decref below call into them.
    ordering = 'monotonic'
    atomic_inc, atomic_dec = (
        _define_atomic_inc_dec(ir_mod, op, ordering=ordering)
        for op in ("add", "sub")
    )
    _define_atomic_cas(ir_mod, ordering=ordering)

    _define_nrt_meminfo_data(ir_mod)
    _define_nrt_incref(ir_mod, atomic_inc)
    _define_nrt_decref(ir_mod, atomic_dec)

    _define_nrt_unresolved_abort(ctx, ir_mod)

    return ir_mod, library


def compile_nrt_functions(ctx):
    """
    Return a finalized code library containing every LLVM NRT function.

    The IR module produced by ``create_nrt_module(ctx)`` is added to its
    library, which is then finalized, so it is built with the given target
    context.
    """
    nrt_ir, nrt_lib = create_nrt_module(ctx)
    nrt_lib.add_ir_module(nrt_ir)
    nrt_lib.finalize()
    return nrt_lib