"...lm-evaluation-harness.git" did not exist on "3a67a226e288d5f5fa66065f428c5842d8ba3ac7"
Commit 1fb0017a authored by dugupeiwen's avatar dugupeiwen
Browse files

init 0.58

parents
/**
This is a modified version of capsulethunk.h for use in llvmpy
**/
#ifndef __CAPSULETHUNK_H
#define __CAPSULETHUNK_H
#if ( (PY_VERSION_HEX < 0x02070000) \
|| ((PY_VERSION_HEX >= 0x03000000) \
&& (PY_VERSION_HEX < 0x03010000)) )
//#define Assert(X) do_assert(!!(X), #X, __FILE__, __LINE__)
#define Assert(X)
static
void do_assert(int cond, const char * msg, const char *file, unsigned line){
if (!cond) {
fprintf(stderr, "Assertion failed %s:%d\n%s\n", file, line, msg);
exit(1);
}
}
typedef void (*PyCapsule_Destructor)(PyObject *);
struct FakePyCapsule_Desc {
const char *name;
void *context;
PyCapsule_Destructor dtor;
PyObject *parent;
FakePyCapsule_Desc() : name(0), context(0), dtor(0), parent(0) {}
};
static
FakePyCapsule_Desc* get_pycobj_desc(PyObject *p){
void *desc = ((PyCObject*)p)->desc;
Assert(desc && "No desc in PyCObject");
return static_cast<FakePyCapsule_Desc*>(desc);
}
static
void pycobject_pycapsule_dtor(void *p, void *desc){
Assert(desc);
Assert(p);
FakePyCapsule_Desc *fpc_desc = static_cast<FakePyCapsule_Desc*>(desc);
Assert(fpc_desc->parent);
Assert(PyCObject_Check(fpc_desc->parent));
fpc_desc->dtor(static_cast<PyObject*>(fpc_desc->parent));
delete fpc_desc;
}
static
PyObject* PyCapsule_New(void* ptr, const char *name, PyCapsule_Destructor dtor)
{
FakePyCapsule_Desc *desc = new FakePyCapsule_Desc;
desc->name = name;
desc->context = NULL;
desc->dtor = dtor;
PyObject *p = PyCObject_FromVoidPtrAndDesc(ptr, desc,
pycobject_pycapsule_dtor);
desc->parent = p;
return p;
}
static
int PyCapsule_CheckExact(PyObject *p)
{
return PyCObject_Check(p);
}
static
void* PyCapsule_GetPointer(PyObject *p, const char *name)
{
Assert(PyCapsule_CheckExact(p));
if (strcmp(get_pycobj_desc(p)->name, name) != 0) {
    PyErr_SetString(PyExc_ValueError, "Invalid PyCapsule object");
    return NULL;
}
return PyCObject_AsVoidPtr(p);
}
static
void* PyCapsule_GetContext(PyObject *p)
{
Assert(p);
Assert(PyCapsule_CheckExact(p));
return get_pycobj_desc(p)->context;
}
static
int PyCapsule_SetContext(PyObject *p, void *context)
{
Assert(PyCapsule_CheckExact(p));
get_pycobj_desc(p)->context = context;
return 0;
}
static
const char * PyCapsule_GetName(PyObject *p)
{
// Assert(PyCapsule_CheckExact(p));
return get_pycobj_desc(p)->name;
}
#endif /* #if PY_VERSION_HEX < 0x02070000 */
#endif /* __CAPSULETHUNK_H */
"""
Utilities for getting information about Numba C extensions
"""
import os
def get_extension_libs():
"""Return the .c files in the `numba.cext` directory.
"""
libs = []
base = get_path()
for fn in os.listdir(base):
if fn.endswith('.c'):
fn = os.path.join(base, fn)
libs.append(fn)
return libs
def get_path():
"""Returns the path to the directory for `numba.cext`.
"""
return os.path.abspath(os.path.dirname(__file__))
#ifndef NUMBA_EXTENSION_HELPER_H_
#define NUMBA_EXTENSION_HELPER_H_
#include "Python.h"
#include "../_numba_common.h"
/* Define all runtime-required symbols in this C module, but do not
export them outside the shared library if possible. */
#define NUMBA_EXPORT_FUNC(_rettype) VISIBILITY_HIDDEN _rettype
#define NUMBA_EXPORT_DATA(_vartype) VISIBILITY_HIDDEN _vartype
/* Use to declare a symbol as exported (global). */
#define NUMBA_GLOBAL_FUNC(_rettype) VISIBILITY_GLOBAL _rettype
NUMBA_EXPORT_FUNC(Py_ssize_t)
aligned_size(Py_ssize_t sz);
#include "dictobject.h"
#include "listobject.h"
#endif // end NUMBA_EXTENSION_HELPER_H_
/* The following is adapted from CPython3.7.
The exact commit is:
- https://github.com/python/cpython/blob/44467e8ea4cea390b0718702291b4cfe8ddd67ed/Objects/dictobject.c
*/
/* Dictionary object implementation using a hash table */
/* The distribution includes a separate file, Objects/dictnotes.txt,
describing explorations into dictionary design and optimization.
It covers typical dictionary use patterns, the parameters for
tuning dictionaries, and several ideas for possible optimizations.
*/
/* PyDictKeysObject
This implements the dictionary's hashtable.
As of Python 3.6, this is compact and ordered. Basic idea is described here:
* https://mail.python.org/pipermail/python-dev/2012-December/123028.html
* https://morepypy.blogspot.com/2015/01/faster-more-memory-efficient-and-more.html
layout:
+---------------+
| dk_refcnt |
| dk_size |
| dk_lookup |
| dk_usable |
| dk_nentries |
+---------------+
| dk_indices |
| |
+---------------+
| dk_entries |
| |
+---------------+
dk_indices is the actual hashtable. It holds an index into the entries array,
or DKIX_EMPTY(-1) or DKIX_DUMMY(-2).
The size of the indices array is dk_size. The type of each index depends on dk_size:
* int8  for dk_size <= 128
* int16 for 256 <= dk_size <= 2**15
* int32 for 2**16 <= dk_size <= 2**31
* int64 for 2**32 <= dk_size
dk_entries is an array of PyDictKeyEntry. Its size is USABLE_FRACTION(dk_size).
DK_ENTRIES(dk) can be used to get a pointer to the entries.
NOTE: Since negative values are used for DKIX_EMPTY and DKIX_DUMMY, the type of
each dk_indices entry is a signed integer; hence int16 is used for tables with
dk_size == 256.
*/
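/* Illustrative sketch (added for exposition, not part of the original file;
   guarded out of the build): how the index width and usable entry count
   scale with dk_size for the layout described above. */
#if 0
#include <stdio.h>

static void layout_sketch(void) {
    long long dk_size = 8;                        /* a minimal table */
    int index_width = dk_size <= 128         ? 1  /* int8  */
                    : dk_size <= (1LL << 15) ? 2  /* int16 */
                    : dk_size <= (1LL << 31) ? 4  /* int32 */
                    :                          8; /* int64 */
    long long usable = (dk_size << 1) / 3;        /* USABLE_FRACTION(dk_size) */
    /* prints: indices: 8 x 1 byte(s); entries: 5 */
    printf("indices: %lld x %d byte(s); entries: %lld\n",
           dk_size, index_width, usable);
}
#endif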
/*
The DictObject can be in one of two forms.
Either:
A combined table:
ma_values == NULL, dk_refcnt == 1.
Values are stored in the me_value field of the PyDictKeysObject.
Or:
(Numba dev notes: split table logic is removed)
A split table:
ma_values != NULL, dk_refcnt >= 1
Values are stored in the ma_values array.
Only string (unicode) keys are allowed.
All dicts sharing the same keys must have the same insertion order.
There are four kinds of slots in the table (a slot is an index, plus
DK_ENTRIES(keys)[index] when index >= 0):
1. Unused. index == DKIX_EMPTY
Does not hold an active (key, value) pair now and never did. Unused can
transition to Active upon key insertion. This is each slot's initial state.
2. Active. index >= 0, me_key != NULL and me_value != NULL
Holds an active (key, value) pair. Active can transition to Dummy or
Pending upon key deletion (for combined and split tables respectively).
This is the only case in which me_value != NULL.
3. Dummy. index == DKIX_DUMMY (combined only)
Previously held an active (key, value) pair, but that was deleted and an
active pair has not yet overwritten the slot. Dummy can transition to
Active upon key insertion. Dummy slots cannot be made Unused again
else the probe sequence in case of collision would have no way to know
they were once active.
4. Pending. index >= 0, key != NULL, and value == NULL (split only)
Not yet inserted in split-table.
*/
/*
Preserving insertion order
It's simple for the combined table. Since dk_entries is mostly append-only, we
can recover the insertion order by just iterating over dk_entries.
One exception is .popitem(). It removes the last item in dk_entries and
decrements dk_nentries to achieve amortized O(1). Since DKIX_DUMMY entries
remain in dk_indices, we can't increment dk_usable even though dk_nentries is
decremented.
In a split table, inserting into a pending entry is allowed only for
dk_entries[ix] where ix == mp->ma_used. Inserting at another index or deleting
an item causes the dict to be converted to a combined table.
*/
/* D_MINSIZE (adapted from PyDict_MINSIZE)
* is the starting size for any new dict.
* 8 allows dicts with no more than 5 active entries; experiments suggested
* this suffices for the majority of dicts (consisting mostly of usually-small
* dicts created to pass keyword arguments).
* Making this 8, rather than 4, reduces the number of resizes for most
* dictionaries, without any significant extra memory use.
*/
#define D_MINSIZE 8
#include "dictobject.h"
#if defined(_MSC_VER)
# if _MSC_VER <= 1900 /* Visual Studio 2014 */
typedef __int8 int8_t;
typedef __int16 int16_t;
typedef __int32 int32_t;
typedef __int64 int64_t;
# endif
/* Use _alloca() to dynamically allocate on the stack on MSVC */
#define STACK_ALLOC(Type, Name, Size) Type * const Name = _alloca(Size);
#else
#define STACK_ALLOC(Type, Name, Size) Type Name[Size];
#endif
/*[clinic input]
class dict "PyDictObject *" "&PyDict_Type"
[clinic start generated code]*/
/*[clinic end generated code: output=da39a3ee5e6b4b0d input=f157a5a0ce9589d6]*/
/*
To ensure the lookup algorithm terminates, there must be at least one Unused
slot (NULL key) in the table.
To avoid slowing down lookups on a near-full table, we resize the table when
it's USABLE_FRACTION (currently two-thirds) full.
*/
#define PERTURB_SHIFT 5
/*
Major subtleties ahead: Most hash schemes depend on having a "good" hash
function, in the sense of simulating randomness. Python doesn't: its most
important hash functions (for ints) are very regular in common
cases:
>>> [hash(i) for i in range(4)]
[0, 1, 2, 3]
This isn't necessarily bad! To the contrary, in a table of size 2**i, taking
the low-order i bits as the initial table index is extremely fast, and there
are no collisions at all for dicts indexed by a contiguous range of ints. So
this gives better-than-random behavior in common cases, and that's very
desirable.
OTOH, when collisions occur, the tendency to fill contiguous slices of the
hash table makes a good collision resolution strategy crucial. Taking only
the last i bits of the hash code is also vulnerable: for example, consider
the list [i << 16 for i in range(20000)] as a set of keys. Since ints are
their own hash codes, and this fits in a dict of size 2**15, the last 15 bits
of every hash code are all 0: they *all* map to the same table index.
But catering to unusual cases should not slow the usual ones, so we just take
the last i bits anyway. It's up to collision resolution to do the rest. If
we *usually* find the key we're looking for on the first try (and, it turns
out, we usually do -- the table load factor is kept under 2/3, so the odds
are solidly in our favor), then it makes best sense to keep the initial index
computation dirt cheap.
The first half of collision resolution is to visit table indices via this
recurrence:
j = ((5*j) + 1) mod 2**i
For any initial j in range(2**i), repeating that 2**i times generates each
int in range(2**i) exactly once (see any text on random-number generation for
proof). By itself, this doesn't help much: like linear probing (setting
j += 1, or j -= 1, on each loop trip), it scans the table entries in a fixed
order. This would be bad, except that's not the only thing we do, and it's
actually *good* in the common cases where hash keys are consecutive. In an
example that's really too small to make this entirely clear, for a table of
size 2**3 the order of indices is:
0 -> 1 -> 6 -> 7 -> 4 -> 5 -> 2 -> 3 -> 0 [and here it's repeating]
If two things come in at index 5, the first place we look after is index 2,
not 6, so if another comes in at index 6 the collision at 5 didn't hurt it.
Linear probing is deadly in this case because there the fixed probe order
is the *same* as the order consecutive keys are likely to arrive. But it's
extremely unlikely hash codes will follow a 5*j+1 recurrence by accident,
and certain that consecutive hash codes do not.
The other half of the strategy is to get the other bits of the hash code
into play. This is done by initializing an (unsigned) variable "perturb" to the
full hash code, and changing the recurrence to:
perturb >>= PERTURB_SHIFT;
j = (5*j) + 1 + perturb;
use j % 2**i as the next table index;
Now the probe sequence depends (eventually) on every bit in the hash code,
and the pseudo-scrambling property of recurring on 5*j+1 is more valuable,
because it quickly magnifies small differences in the bits that didn't affect
the initial index. Note that because perturb is unsigned, if the recurrence
is executed often enough perturb eventually becomes and remains 0. At that
point (very rarely reached) the recurrence is on (just) 5*j+1 again, and
that's certain to find an empty slot eventually (since it generates every int
in range(2**i), and we make sure there's always at least one empty slot).
Selecting a good value for PERTURB_SHIFT is a balancing act. You want it
small so that the high bits of the hash code continue to affect the probe
sequence across iterations; but you want it large so that in really bad cases
the high-order hash bits have an effect on early iterations. 5 was "the
best" in minimizing total collisions across experiments Tim Peters ran (on
both normal and pathological cases), but 4 and 6 weren't significantly worse.
Historical: Reimer Behrends contributed the idea of using a polynomial-based
approach, using repeated multiplication by x in GF(2**n) where an irreducible
polynomial for each table size was chosen such that x was a primitive root.
Christian Tismer later extended that to use division by x instead, as an
efficient way to get the high bits of the hash code into play. This scheme
also gave excellent collision statistics, but was more expensive: two
if-tests were required inside the loop; computing "the next" index took about
the same number of operations but without as much potential parallelism
(e.g., computing 5*j can go on at the same time as computing 1+perturb in the
above, and then shifting perturb can be done while the table index is being
masked); and the PyDictObject struct required a member to hold the table's
polynomial. In Tim's experiments the current scheme ran faster, produced
equally good collision statistics, needed less code & used less memory.
*/
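/* Illustrative sketch (added for exposition, guarded out of the build): the
   bare j = 5*j + 1 recurrence on a size-8 table visits every slot exactly
   once, reproducing the 0 -> 1 -> 6 -> 7 -> 4 -> 5 -> 2 -> 3 cycle quoted
   above; the perturb term then folds the remaining hash bits into the walk. */
#if 0
#include <stdio.h>

static void probe_sketch(void) {
    size_t mask = 7;              /* table of size 2**3 */
    size_t j = 0, n;
    for (n = 0; n < 8; n++) {
        printf("%zu -> ", j);
        j = (5*j + 1) & mask;     /* bare recurrence, no perturb */
    }
    printf("%zu [repeats]\n", j); /* 0 -> 1 -> 6 -> 7 -> 4 -> 5 -> 2 -> 3 -> 0 */
}
#endif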
#define DKIX_EMPTY (-1)
#define DKIX_DUMMY (-2) /* Used internally */
#define DKIX_ERROR (-3)
typedef enum {
OK = 0,
OK_REPLACED = 1,
ERR_NO_MEMORY = -1,
ERR_DICT_MUTATED = -2,
ERR_ITER_EXHAUSTED = -3,
ERR_DICT_EMPTY = -4,
ERR_CMP_FAILED = -5,
} Status;
#ifndef NDEBUG
static
int mem_cmp_zeros(void *obj, size_t n){
int diff = 0;
char *mem = obj;
char *it;
for (it = mem; it < mem + n; ++it) {
if (*it != 0) diff += 1;
}
return diff;
}
#endif
#define D_MASK(dk) ((dk)->size-1)
#define D_GROWTH_RATE(d) ((d)->used*3)
static int
ix_size(Py_ssize_t size) {
if ( size < 0xff ) return 1;
if ( size < 0xffff ) return 2;
if ( size < 0xffffffff ) return 4;
return sizeof(int64_t);
}
#ifndef NDEBUG
/* NOTE: This function is only used in assert()s */
/* Align pointer *ptr* to pointer size */
static void*
aligned_pointer(void *ptr) {
return (void*)aligned_size((size_t)ptr);
}
#endif
/* lookup indices. returns DKIX_EMPTY, DKIX_DUMMY, or ix >=0 */
static Py_ssize_t
get_index(NB_DictKeys *dk, Py_ssize_t i)
{
Py_ssize_t s = dk->size;
Py_ssize_t ix;
if (s <= 0xff) {
int8_t *indices = (int8_t*)(dk->indices);
assert (i < dk->size);
ix = indices[i];
}
else if (s <= 0xffff) {
int16_t *indices = (int16_t*)(dk->indices);
ix = indices[i];
}
#if SIZEOF_VOID_P > 4
else if (s > 0xffffffff) {
int64_t *indices = (int64_t*)(dk->indices);
ix = indices[i];
}
#endif
else {
int32_t *indices = (int32_t*)(dk->indices);
ix = indices[i];
}
assert(ix >= DKIX_DUMMY);
return ix;
}
/* write to indices. */
static void
set_index(NB_DictKeys *dk, Py_ssize_t i, Py_ssize_t ix)
{
Py_ssize_t s = dk->size;
assert(ix >= DKIX_DUMMY);
if (s <= 0xff) {
int8_t *indices = (int8_t*)(dk->indices);
assert(ix <= 0x7f);
indices[i] = (char)ix;
}
else if (s <= 0xffff) {
int16_t *indices = (int16_t*)(dk->indices);
assert(ix <= 0x7fff);
indices[i] = (int16_t)ix;
}
#if SIZEOF_VOID_P > 4
else if (s > 0xffffffff) {
int64_t *indices = (int64_t*)(dk->indices);
indices[i] = ix;
}
#endif
else {
int32_t *indices = (int32_t*)(dk->indices);
assert(ix <= 0x7fffffff);
indices[i] = (int32_t)ix;
}
}
/* USABLE_FRACTION is the maximum dictionary load.
* Increasing this ratio makes dictionaries more dense resulting in more
* collisions. Decreasing it improves sparseness at the expense of spreading
* indices over more cache lines and at the cost of total memory consumed.
*
* USABLE_FRACTION must obey the following:
* (0 < USABLE_FRACTION(n) < n) for all n >= 2
*
* USABLE_FRACTION should be quick to calculate.
* Fractions around 1/2 to 2/3 seem to work well in practice.
*/
#define USABLE_FRACTION(n) (((n) << 1)/3) // ratio: 2/3
/* Alternative fraction that is otherwise close enough to 2n/3 to make
* little difference. 8 * 2/3 == 8 * 5/8 == 5. 16 * 2/3 == 16 * 5/8 == 10.
* 32 * 2/3 = 21, 32 * 5/8 = 20.
* Its advantage is that it is faster to compute on machines with slow division.
* #define USABLE_FRACTION(n) (((n) >> 1) + ((n) >> 2) - ((n) >> 3)) // ratio: 5/8
*/
/* INV_USABLE_FRACTION gives the inverse of USABLE_FRACTION.
* Used for sizing a new dictionary to a specified number of keys.
*
* NOTE: If the denominator of the USABLE_FRACTION ratio is not a power
* of 2, 1 must be added to the result of the inverse for correct sizing.
*
* For example, when USABLE_FRACTION ratio = 5/8 (8 is a power of 2):
* #define INV_USABLE_FRACTION(n) (((n) << 3)/5) // inv_ratio: 8/5
*
* When USABLE_FRACTION ratio = 5/7 (7 is not a power of 2):
* #define INV_USABLE_FRACTION(n) ((7*(n))/5 + 1) // inv_ratio: 7/5
*/
#define INV_USABLE_FRACTION(n) ((n) + ((n) >> 1) + 1) // inv_ratio: 3/2
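/* Worked example (added for clarity): with the 2/3 ratio above,
 * USABLE_FRACTION(8) == 5 and USABLE_FRACTION(16) == 10, while the 3/2
 * inverse gives INV_USABLE_FRACTION(5) == 8 and INV_USABLE_FRACTION(10) == 16,
 * so the two macros round-trip for power-of-two table sizes (the +1
 * compensates for the truncating integer division). */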
/* GROWTH_RATE. Growth rate upon hitting maximum load.
* Currently set to used*3.
* This means that dicts double in size when growing without deletions,
* but have more head room when the number of deletions is on a par with the
* number of insertions. See also bpo-17563 and bpo-33205.
*
* GROWTH_RATE was set to used*4 up to version 3.2.
* GROWTH_RATE was set to used*2 in version 3.3.0
* GROWTH_RATE was set to used*2 + capacity/2 in 3.4.0-3.6.0.
*/
#define GROWTH_RATE(d) ((d)->used*3)
static NB_DictEntry*
get_entry(NB_DictKeys *dk, Py_ssize_t idx) {
Py_ssize_t offset;
char *ptr;
assert (idx < dk->size);
offset = idx * dk->entry_size;
ptr = dk->indices + dk->entry_offset + offset;
return (NB_DictEntry*)ptr;
}
static void
zero_key(NB_DictKeys *dk, char *data){
memset(data, 0, dk->key_size);
}
static void
zero_val(NB_DictKeys *dk, char *data){
memset(data, 0, dk->val_size);
}
static void
copy_key(NB_DictKeys *dk, char *dst, const char *src){
memcpy(dst, src, dk->key_size);
}
static void
copy_val(NB_DictKeys *dk, char *dst, const char *src){
memcpy(dst, src, dk->val_size);
}
/* Returns -1 for error; 0 for not equal; 1 for equal */
static int
key_equal(NB_DictKeys *dk, const char *lhs, const char *rhs) {
if ( dk->methods.key_equal ) {
return dk->methods.key_equal(lhs, rhs);
} else {
return memcmp(lhs, rhs, dk->key_size) == 0;
}
}
static char *
entry_get_key(NB_DictKeys *dk, NB_DictEntry* entry) {
char * out = entry->keyvalue;
assert (out == aligned_pointer(out));
return out;
}
static char *
entry_get_val(NB_DictKeys *dk, NB_DictEntry* entry) {
char * out = entry_get_key(dk, entry) + aligned_size(dk->key_size);
assert (out == aligned_pointer(out));
return out;
}
static void
dk_incref_key(NB_DictKeys *dk, const char *key) {
if ( dk->methods.key_incref ) {
dk->methods.key_incref(key);
}
}
static void
dk_decref_key(NB_DictKeys *dk, const char *key) {
if ( dk->methods.key_decref ) {
dk->methods.key_decref(key);
}
}
static void
dk_incref_val(NB_DictKeys *dk, const char *val) {
if ( dk->methods.value_incref ) {
dk->methods.value_incref(val);
}
}
static void
dk_decref_val(NB_DictKeys *dk, const char *val) {
if ( dk->methods.value_decref ) {
dk->methods.value_decref(val);
}
}
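/* Illustrative sketch (hypothetical, guarded out of the build): what a
   type_based_methods_table might look like for 8-byte keys and values that
   need reference counting. All names below are made up for illustration;
   numba_dict_set_method_table() further down installs such a table. */
#if 0
#include <string.h>

static int demo_key_equal(const char *lhs, const char *rhs) {
    return memcmp(lhs, rhs, 8) == 0;   /* compare two 8-byte keys by value */
}
static void demo_incref(const void *obj) { (void)obj; /* e.g. call into NRT */ }
static void demo_decref(const void *obj) { (void)obj; /* e.g. call into NRT */ }

static type_based_methods_table demo_methods = {
    demo_key_equal,   /* key_equal    */
    demo_incref,      /* key_incref   */
    demo_decref,      /* key_decref   */
    demo_incref,      /* value_incref */
    demo_decref,      /* value_decref */
};
#endif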
void
numba_dictkeys_free(NB_DictKeys *dk) {
/* Clear all references from the entries */
Py_ssize_t i;
NB_DictEntry *ep;
for (i = 0; i < dk->nentries; i++) {
ep = get_entry(dk, i);
if (ep->hash != DKIX_EMPTY) {
dk_decref_key(dk, entry_get_key(dk, ep));
dk_decref_val(dk, entry_get_val(dk, ep));
}
}
/* Deallocate */
free(dk);
}
void
numba_dict_free(NB_Dict *d) {
numba_dictkeys_free(d->keys);
free(d);
}
Py_ssize_t
numba_dict_length(NB_Dict *d) {
return d->used;
}
/* Allocate new dictionary keys
Adapted from CPython's new_keys_object().
*/
int
numba_dictkeys_new(NB_DictKeys **out, Py_ssize_t size, Py_ssize_t key_size, Py_ssize_t val_size) {
Py_ssize_t usable = USABLE_FRACTION(size);
Py_ssize_t index_size = ix_size(size);
Py_ssize_t entry_size = aligned_size(sizeof(NB_DictEntry) + aligned_size(key_size) + aligned_size(val_size));
Py_ssize_t entry_offset = aligned_size(index_size * size);
Py_ssize_t alloc_size = sizeof(NB_DictKeys) + entry_offset + entry_size * usable;
NB_DictKeys *dk = malloc(aligned_size(alloc_size));
if (!dk) return ERR_NO_MEMORY;
assert ( size >= D_MINSIZE );
dk->size = size;
dk->usable = usable;
dk->nentries = 0;
dk->key_size = key_size;
dk->val_size = val_size;
dk->entry_offset = entry_offset;
dk->entry_size = entry_size;
assert (aligned_pointer(dk->indices) == dk->indices );
/* Ensure that the method table is all nulls */
memset(&dk->methods, 0x00, sizeof(type_based_methods_table));
/* Ensure hash is (-1) for empty entry */
memset(dk->indices, 0xff, entry_offset + entry_size * usable);
*out = dk;
return OK;
}
/* Allocate new dictionary */
int
numba_dict_new(NB_Dict **out, Py_ssize_t size, Py_ssize_t key_size, Py_ssize_t val_size) {
NB_DictKeys *dk;
NB_Dict *d;
int status = numba_dictkeys_new(&dk, size, key_size, val_size);
if (status != OK) return status;
d = malloc(sizeof(NB_Dict));
if (!d) {
numba_dictkeys_free(dk);
return ERR_NO_MEMORY;
}
d->used = 0;
d->keys = dk;
*out = d;
return OK;
}
/*
Adapted from CPython lookdict_index().
Search the hash table for the slot that holds a given entry index.
*/
static Py_ssize_t
lookdict_index(NB_DictKeys *dk, Py_hash_t hash, Py_ssize_t index)
{
size_t mask = D_MASK(dk);
size_t perturb = (size_t)hash;
size_t i = (size_t)hash & mask;
for (;;) {
Py_ssize_t ix = get_index(dk, i);
if (ix == index) {
return i;
}
if (ix == DKIX_EMPTY) {
return DKIX_EMPTY;
}
perturb >>= PERTURB_SHIFT;
i = mask & (i*5 + perturb + 1);
}
assert(0 && "unreachable");
}
/*
Adapted from the CPython3.7 lookdict().
The basic lookup function used by all operations.
This is based on Algorithm D from Knuth Vol. 3, Sec. 6.4.
Open addressing is preferred over chaining since the link overhead for
chaining would be substantial (100% with typical malloc overhead).
The initial probe index is computed as hash mod the table size. Subsequent
probe indices are computed as explained earlier.
All arithmetic on hash should ignore overflow.
The details in this version are due to Tim Peters, building on many past
contributions by Reimer Behrends, Jyrki Alakuijala, Vladimir Marangozov and
Christian Tismer.
lookdict() is general-purpose, and may return DKIX_ERROR if (and only if) a
comparison raises an exception.
lookdict_unicode() below is specialized to string keys, comparison of which can
never raise an exception; that function can never return DKIX_ERROR when key
is string. Otherwise, it falls back to lookdict().
lookdict_unicode_nodummy is further specialized for string keys that cannot be
the <dummy> value.
For both, when the key isn't found a DKIX_EMPTY is returned.
*/
Py_ssize_t
numba_dict_lookup(NB_Dict *d, const char *key_bytes, Py_hash_t hash, char *oldval_bytes)
{
NB_DictKeys *dk = d->keys;
size_t mask = D_MASK(dk);
size_t perturb = hash;
size_t i = (size_t)hash & mask;
for (;;) {
Py_ssize_t ix = get_index(dk, i);
if (ix == DKIX_EMPTY) {
zero_val(dk, oldval_bytes);
return ix;
}
if (ix >= 0) {
NB_DictEntry *ep = get_entry(dk, ix);
const char *startkey = NULL;
if (ep->hash == hash) {
int cmp;
startkey = entry_get_key(dk, ep);
cmp = key_equal(dk, startkey, key_bytes);
if (cmp < 0) {
// error'ed in comparison
memset(oldval_bytes, 0, dk->val_size);
return DKIX_ERROR;
}
if (cmp > 0) {
// key is equal; retrieve the value.
copy_val(dk, oldval_bytes, entry_get_val(dk, ep));
return ix;
}
}
}
perturb >>= PERTURB_SHIFT;
i = (i*5 + perturb + 1) & mask;
}
assert(0 && "unreachable");
}
/* Internal function to find slot for an item from its hash
when it is known that the key is not present in the dict.
The dict must be combined. */
static Py_ssize_t
find_empty_slot(NB_DictKeys *dk, Py_hash_t hash){
size_t mask;
size_t i;
Py_ssize_t ix;
size_t perturb;
assert(dk != NULL);
mask = D_MASK(dk);
i = hash & mask;
ix = get_index(dk, i);
for (perturb = hash; ix >= 0;) {
perturb >>= PERTURB_SHIFT;
i = (i*5 + perturb + 1) & mask;
ix = get_index(dk, i);
}
return i;
}
static int
insertion_resize(NB_Dict *d)
{
return numba_dict_resize(d, D_GROWTH_RATE(d));
}
int
numba_dict_insert(
NB_Dict *d,
const char *key_bytes,
Py_hash_t hash,
const char *val_bytes,
char *oldval_bytes
)
{
NB_DictKeys *dk = d->keys;
Py_ssize_t ix = numba_dict_lookup(d, key_bytes, hash, oldval_bytes);
if (ix == DKIX_ERROR) {
// exception in key comparison in lookup.
return ERR_CMP_FAILED;
}
if (ix == DKIX_EMPTY) {
/* Insert into new slot */
Py_ssize_t hashpos;
NB_DictEntry *ep;
if (dk->usable <= 0) {
/* Need to resize */
if (insertion_resize(d) != OK)
return ERR_NO_MEMORY;
else
dk = d->keys; // reload
}
hashpos = find_empty_slot(dk, hash);
ep = get_entry(dk, dk->nentries);
set_index(dk, hashpos, dk->nentries);
copy_key(dk, entry_get_key(dk, ep), key_bytes);
assert ( hash != -1 );
ep->hash = hash;
copy_val(dk, entry_get_val(dk, ep), val_bytes);
/* incref */
dk_incref_key(dk, key_bytes);
dk_incref_val(dk, val_bytes);
d->used += 1;
dk->usable -= 1;
dk->nentries += 1;
assert (dk->usable >= 0);
return OK;
} else {
/* Replace existing value in the slot at ix */
/* decref old value */
dk_decref_val(dk, oldval_bytes);
// Replace the previous value
copy_val(dk, entry_get_val(dk, get_entry(dk, ix)), val_bytes);
/* incref */
dk_incref_val(dk, val_bytes);
return OK_REPLACED;
}
}
/*
Adapted from build_indices().
Internal routine used by dictresize() to build a hashtable of entries.
*/
void
build_indices(NB_DictKeys *keys, Py_ssize_t n) {
size_t mask = (size_t)D_MASK(keys);
Py_ssize_t ix;
for (ix = 0; ix != n; ix++) {
size_t perturb;
Py_hash_t hash = get_entry(keys, ix)->hash;
size_t i = hash & mask;
for (perturb = hash; get_index(keys, i) != DKIX_EMPTY;) {
perturb >>= PERTURB_SHIFT;
i = mask & (i*5 + perturb + 1);
}
set_index(keys, i, ix);
}
}
/*
Adapted from CPython dictresize().
Restructure the table by allocating a new table and reinserting all
items again. When entries have been deleted, the new table may
actually be smaller than the old one.
If a table is split (its keys and hashes are shared, its values are not),
then the values are temporarily copied into the table, it is resized as
a combined table, then the me_value slots in the old table are NULLed out.
After resizing, a table is always combined,
but can be resplit by make_keys_shared().
*/
int
numba_dict_resize(NB_Dict *d, Py_ssize_t minsize) {
Py_ssize_t newsize, numentries;
NB_DictKeys *oldkeys;
int status;
/* Find the smallest table size >= minsize. */
for (newsize = D_MINSIZE;
newsize < minsize && newsize > 0;
newsize <<= 1)
;
if (newsize <= 0) {
return ERR_NO_MEMORY;
}
oldkeys = d->keys;
/* NOTE: Current odict checks mp->ma_keys to detect that a resize happened.
 * So we can't reuse oldkeys even if oldkeys->dk_size == newsize.
 * TODO: Try reusing oldkeys when reimplementing odict.
 */
/* Allocate a new table. */
status = numba_dictkeys_new(
&d->keys, newsize, oldkeys->key_size, oldkeys->val_size
);
if (status != OK) {
d->keys = oldkeys;
return status;
}
// New table must be large enough.
assert(d->keys->usable >= d->used);
// Copy method table
memcpy(&d->keys->methods, &oldkeys->methods, sizeof(type_based_methods_table));
numentries = d->used;
if (oldkeys->nentries == numentries) {
NB_DictEntry *oldentries, *newentries;
oldentries = get_entry(oldkeys, 0);
newentries = get_entry(d->keys, 0);
memcpy(newentries, oldentries, numentries * oldkeys->entry_size);
// to avoid decref
memset(oldentries, 0xff, numentries * oldkeys->entry_size);
}
else {
Py_ssize_t i;
size_t epi = 0;
for (i=0; i<numentries; ++i) {
/*
An entry with hash == (-1), i.e. DKIX_EMPTY, is empty.
Here, we skip until a non-empty entry is encountered.
*/
while( get_entry(oldkeys, epi)->hash == DKIX_EMPTY ) {
assert( mem_cmp_zeros(entry_get_val(oldkeys, get_entry(oldkeys, epi)), oldkeys->val_size) == 0 );
epi += 1;
}
memcpy(
get_entry(d->keys, i),
get_entry(oldkeys, epi),
oldkeys->entry_size
);
get_entry(oldkeys, epi)->hash = DKIX_EMPTY; // to avoid decref
epi += 1;
}
}
numba_dictkeys_free(oldkeys);
build_indices(d->keys, numentries);
d->keys->usable -= numentries;
d->keys->nentries = numentries;
return OK;
}
/*
Adapted from CPython delitem_common
*/
int
numba_dict_delitem(NB_Dict *d, Py_hash_t hash, Py_ssize_t ix)
{
Py_ssize_t hashpos;
NB_DictEntry *ep;
NB_DictKeys *dk = d->keys;
hashpos = lookdict_index(dk, hash, ix);
assert(hashpos >= 0);
d->used -= 1;
ep = get_entry(dk, ix);
set_index(dk, hashpos, DKIX_DUMMY);
/* decref */
dk_decref_key(dk, entry_get_key(dk, ep));
dk_decref_val(dk, entry_get_val(dk, ep));
/* zero the entries */
zero_key(dk, entry_get_key(dk, ep));
zero_val(dk, entry_get_val(dk, ep));
ep->hash = DKIX_EMPTY; // to mark it as empty;
return OK;
}
/* Adapted from CPython dict_popitem */
int
numba_dict_popitem(NB_Dict *d, char *key_bytes, char *val_bytes)
{
Py_ssize_t i, j;
char *key_ptr, *val_ptr;
NB_DictEntry *ep = NULL;
if (d->used == 0) {
return ERR_DICT_EMPTY;
}
/* Pop last item */
i = d->keys->nentries - 1;
while (i >= 0 && (ep = get_entry(d->keys, i))->hash == DKIX_EMPTY ) {
i--;
}
assert(i >= 0);
j = lookdict_index(d->keys, ep->hash, i);
assert(j >= 0);
assert(get_index(d->keys, j) == i);
set_index(d->keys, j, DKIX_DUMMY);
key_ptr = entry_get_key(d->keys, ep);
val_ptr = entry_get_val(d->keys, ep);
copy_key(d->keys, key_bytes, key_ptr);
copy_val(d->keys, val_bytes, val_ptr);
zero_key(d->keys, key_ptr);
zero_val(d->keys, val_ptr);
/* We can't dk_usable++ since there is DKIX_DUMMY in indices */
d->keys->nentries = i;
d->used--;
return OK;
}
void
numba_dict_dump(NB_Dict *d) {
long long i, j, k;
long long size, n;
char *cp;
NB_DictEntry *ep;
NB_DictKeys *dk = d->keys;
n = d->used;
size = dk->nentries;
printf("Dict dump\n");
printf(" key_size = %lld\n", (long long)d->keys->key_size);
printf(" val_size = %lld\n", (long long)d->keys->val_size);
for (i = 0, j = 0; i < size; i++) {
ep = get_entry(dk, i);
if (ep->hash != DKIX_EMPTY) {
long long hash = ep->hash;
printf(" key=");
for (cp=entry_get_key(dk, ep), k=0; k < d->keys->key_size; ++k, ++cp){
printf("%02x ", ((int)*cp) & 0xff);
}
printf(" hash=%llu value=", hash);
for (cp=entry_get_val(dk, ep), k=0; k < d->keys->val_size; ++k, ++cp){
printf("%02x ", ((int)*cp) & 0xff);
}
printf("\n");
j++;
}
}
printf("j = %lld; n = %lld\n", j, n);
assert(j == n);
}
size_t
numba_dict_iter_sizeof() {
return sizeof(NB_DictIter);
}
void
numba_dict_iter(NB_DictIter *it, NB_Dict *d) {
it->parent = d;
it->parent_keys = d->keys;
it->size = d->used;
it->pos = 0;
}
int
numba_dict_iter_next(NB_DictIter *it, const char **key_ptr, const char **val_ptr) {
/* Detect dictionary mutation during iteration */
NB_DictKeys *dk;
if (it->parent->keys != it->parent_keys ||
it->parent->used != it->size) {
return ERR_DICT_MUTATED;
}
dk = it->parent_keys;
while ( it->pos < dk->nentries ) {
NB_DictEntry *ep = get_entry(dk, it->pos++);
if ( ep->hash != DKIX_EMPTY ) {
*key_ptr = entry_get_key(dk, ep);
*val_ptr = entry_get_val(dk, ep);
return OK;
}
}
return ERR_ITER_EXHAUSTED;
}
int
numba_dict_insert_ez(
NB_Dict *d,
const char *key_bytes,
Py_hash_t hash,
const char *val_bytes
)
{
STACK_ALLOC(char, old, d->keys->val_size);
return numba_dict_insert(d, key_bytes, hash, val_bytes, old);
}
/* Allocate a new dictionary with enough space to hold n_keys without resizes */
int
numba_dict_new_sized(NB_Dict **out, Py_ssize_t n_keys, Py_ssize_t key_size, Py_ssize_t val_size) {
/* Respect D_MINSIZE */
if (n_keys <= USABLE_FRACTION(D_MINSIZE)) {
return numba_dict_new(out, D_MINSIZE, key_size, val_size);
}
/* Adjust for load factor */
Py_ssize_t size = INV_USABLE_FRACTION(n_keys) - 1;
/* Round up size to the nearest power of 2. */
for (unsigned int shift = 1; shift < sizeof(Py_ssize_t) * CHAR_BIT; shift <<= 1) {
size |= (size >> shift);
}
size++;
/* Handle overflows */
if (size <= 0) {
return ERR_NO_MEMORY;
}
return numba_dict_new(out, size, key_size, val_size);
}
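/* Worked example (added for clarity): for n_keys == 11,
 * INV_USABLE_FRACTION(11) == 17, so size starts at 16 (0b10000); the
 * bit-smearing loop turns 16 into 31 (0b11111) and the final increment
 * yields 32. USABLE_FRACTION(32) == 21 >= 11, so 11 keys fit without a
 * resize, whereas a size-16 table would only hold 10. */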
void
numba_dict_set_method_table(NB_Dict *d, type_based_methods_table *methods)
{
memcpy(&d->keys->methods, methods, sizeof(type_based_methods_table));
}
#define CHECK(CASE) { \
if ( !(CASE) ) { \
printf("'%s' failed file %s:%d\n", #CASE, __FILE__, __LINE__); \
return 1; \
} \
}
int
numba_test_dict(void) {
NB_Dict *d;
int status;
Py_ssize_t ix;
Py_ssize_t usable;
Py_ssize_t it_count;
const char *it_key, *it_val;
NB_DictIter iter;
#if defined(_MSC_VER)
/* So that VS2008 compiler is happy */
char *got_key, *got_value;
got_key = _alloca(4);
got_value = _alloca(8);
#else
char got_key[4];
char got_value[8];
#endif
puts("test_dict");
status = numba_dict_new(&d, D_MINSIZE, 4, 8);
CHECK(status == OK);
CHECK(d->keys->size == D_MINSIZE);
CHECK(d->keys->key_size == 4);
CHECK(d->keys->val_size == 8);
CHECK(ix_size(d->keys->size) == 1);
printf("aligned_size(index_size * size) = %d\n", (int)(aligned_size(ix_size(d->keys->size) * d->keys->size)));
printf("d %p\n", d);
printf("d->usable = %u\n", (int)d->keys->usable);
usable = d->keys->usable;
printf("d[0] %d\n", (int)((char*)get_entry(d->keys, 0) - (char*)d->keys));
CHECK ((char*)get_entry(d->keys, 0) - (char*)d->keys->indices == d->keys->entry_offset);
printf("d[1] %d\n", (int)((char*)get_entry(d->keys, 1) - (char*)d->keys));
CHECK ((char*)get_entry(d->keys, 1) - (char*)d->keys->indices == d->keys->entry_offset + d->keys->entry_size);
ix = numba_dict_lookup(d, "bef", 0xbeef, got_value);
printf("ix = %d\n", (int)ix);
CHECK (ix == DKIX_EMPTY);
// insert 1st key
status = numba_dict_insert(d, "bef", 0xbeef, "1234567", got_value);
CHECK (status == OK);
CHECK (d->used == 1);
CHECK (d->keys->usable == usable - d->used);
// insert same key
status = numba_dict_insert(d, "bef", 0xbeef, "1234567", got_value);
CHECK (status == OK_REPLACED);
printf("got_value %s\n", got_value);
CHECK (d->used == 1);
CHECK (d->keys->usable == usable - d->used);
// insert 2nd key
status = numba_dict_insert(d, "beg", 0xbeef, "1234568", got_value);
CHECK (status == OK);
CHECK (d->used == 2);
CHECK (d->keys->usable == usable - d->used);
// insert 3rd key
status = numba_dict_insert(d, "beh", 0xcafe, "1234569", got_value);
CHECK (status == OK);
CHECK (d->used == 3);
CHECK (d->keys->usable == usable - d->used);
// replace key "bef"'s value
status = numba_dict_insert(d, "bef", 0xbeef, "7654321", got_value);
CHECK (status == OK_REPLACED);
CHECK (d->used == 3);
CHECK (d->keys->usable == usable - d->used);
// insert 4th key
status = numba_dict_insert(d, "bei", 0xcafe, "0_0_0_1", got_value);
CHECK (status == OK);
CHECK (d->used == 4);
CHECK (d->keys->usable == usable - d->used);
// insert 5th key
status = numba_dict_insert(d, "bej", 0xcafe, "0_0_0_2", got_value);
CHECK (status == OK);
CHECK (d->used == 5);
CHECK (d->keys->usable == usable - d->used);
// insert 6th key & triggers resize
status = numba_dict_insert(d, "bek", 0xcafe, "0_0_0_3", got_value);
CHECK (status == OK);
CHECK (d->used == 6);
CHECK (d->keys->usable == USABLE_FRACTION(d->keys->size) - d->used);
// Dump
numba_dict_dump(d);
// Make sure everything is still in there
ix = numba_dict_lookup(d, "bef", 0xbeef, got_value);
CHECK (ix >= 0);
CHECK (memcpy(got_value, "7654321", d->keys->val_size));
ix = numba_dict_lookup(d, "beg", 0xbeef, got_value);
CHECK (ix >= 0);
CHECK (memcpy(got_value, "1234567", d->keys->val_size));
ix = numba_dict_lookup(d, "beh", 0xcafe, got_value);
printf("ix = %d\n", (int)ix);
CHECK (ix >= 0);
CHECK (memcpy(got_value, "1234569", d->keys->val_size));
ix = numba_dict_lookup(d, "bei", 0xcafe, got_value);
CHECK (ix >= 0);
CHECK (memcpy(got_value, "0_0_0_1", d->keys->val_size));
ix = numba_dict_lookup(d, "bej", 0xcafe, got_value);
CHECK (ix >= 0);
CHECK (memcpy(got_value, "0_0_0_2", d->keys->val_size));
ix = numba_dict_lookup(d, "bek", 0xcafe, got_value);
CHECK (ix >= 0);
CHECK (memcpy(got_value, "0_0_0_3", d->keys->val_size));
// Test delete
ix = numba_dict_lookup(d, "beg", 0xbeef, got_value);
status = numba_dict_delitem(d, 0xbeef, ix);
CHECK (status == OK);
ix = numba_dict_lookup(d, "beg", 0xbeef, got_value);
CHECK (ix == DKIX_EMPTY); // not found
ix = numba_dict_lookup(d, "bef", 0xbeef, got_value);
CHECK (ix >= 0);
ix = numba_dict_lookup(d, "beh", 0xcafe, got_value);
CHECK (ix >= 0);
// Test popitem
// Popitem always removes the last item
status = numba_dict_popitem(d, got_key, got_value);
CHECK(status == OK);
CHECK(memcmp("bek", got_key, d->keys->key_size) == 0);
CHECK(memcmp("0_0_0_3", got_value, d->keys->val_size) == 0);
status = numba_dict_popitem(d, got_key, got_value);
CHECK(status == OK);
CHECK(memcmp("bej", got_key, d->keys->key_size) == 0);
CHECK(memcmp("0_0_0_2", got_value, d->keys->val_size) == 0);
// Test iterator
CHECK( d->used > 0 );
numba_dict_iter(&iter, d);
it_count = 0;
while ( (status = numba_dict_iter_next(&iter, &it_key, &it_val)) == OK) {
it_count += 1; // valid items
CHECK(it_key != NULL);
CHECK(it_val != NULL);
}
CHECK(status == ERR_ITER_EXHAUSTED);
CHECK(d->used == it_count);
numba_dict_free(d);
/* numba_dict_new_sized() */
Py_ssize_t target_size;
Py_ssize_t n_keys;
// Test that a minsize dict is returned when n_keys=0
target_size = D_MINSIZE;
n_keys = 0;
numba_dict_new_sized(&d, n_keys, 1, 1);
CHECK(d->keys->size == target_size);
CHECK(d->keys->usable == USABLE_FRACTION(target_size));
numba_dict_free(d);
// Test sizing at power of 2 boundary
target_size = D_MINSIZE * 2;
n_keys = USABLE_FRACTION(target_size);
numba_dict_new_sized(&d, n_keys, 1, 1);
CHECK(d->keys->size == target_size);
CHECK(d->keys->usable == n_keys);
numba_dict_free(d);
target_size *= 2;
n_keys++;
numba_dict_new_sized(&d, n_keys, 1, 1);
CHECK(d->keys->size == target_size);
CHECK(d->keys->usable > n_keys);
CHECK(d->keys->usable == USABLE_FRACTION(target_size));
numba_dict_free(d);
return 0;
}
#undef CHECK
/* Adapted from CPython3.7 Objects/dict-common.h */
#include "cext.h"
#ifndef NUMBA_DICT_COMMON_H
#define NUMBA_DICT_COMMON_H
typedef struct {
/* Uses Py_ssize_t instead of Py_hash_t to guarantee word size alignment */
Py_ssize_t hash;
char keyvalue[];
} NB_DictEntry;
typedef int (*dict_key_comparator_t)(const char *lhs, const char *rhs);
typedef void (*dict_refcount_op_t)(const void*);
typedef struct {
dict_key_comparator_t key_equal;
dict_refcount_op_t key_incref;
dict_refcount_op_t key_decref;
dict_refcount_op_t value_incref;
dict_refcount_op_t value_decref;
} type_based_methods_table;
typedef struct {
/* hash table size */
Py_ssize_t size;
/* Usable size of the hash table;
   also the capacity of the entries array. */
Py_ssize_t usable;
/* hash table used entries */
Py_ssize_t nentries;
/* Entry info
- key_size is the sizeof key type
- val_size is the sizeof value type
- entry_size is key_size + val_size + alignment
*/
Py_ssize_t key_size, val_size, entry_size;
/* Byte offset from indices to the first entry. */
Py_ssize_t entry_offset;
/* Method table for type-dependent operations. */
type_based_methods_table methods;
/* hash table */
char indices[];
} NB_DictKeys;
typedef struct {
/* num of elements in the hashtable */
Py_ssize_t used;
NB_DictKeys *keys;
} NB_Dict;
typedef struct {
/* parent dictionary */
NB_Dict *parent;
/* parent keys object */
NB_DictKeys *parent_keys;
/* dict size */
Py_ssize_t size;
/* iterator position; indicates the next position to read */
Py_ssize_t pos;
} NB_DictIter;
/* A test function for the dict
Returns 0 for OK; 1 for failure.
*/
NUMBA_EXPORT_FUNC(int)
numba_test_dict(void);
/* Allocate a new dict
Parameters
- NB_Dict **out
Output for the new dictionary.
- Py_ssize_t size
Hashtable size. Must be power of two.
- Py_ssize_t key_size
Size of a key entry.
- Py_ssize_t val_size
Size of a value entry.
*/
NUMBA_EXPORT_FUNC(int)
numba_dict_new(NB_Dict **out, Py_ssize_t size, Py_ssize_t key_size, Py_ssize_t val_size);
/* Allocate a new dict with enough space to hold n_keys without resizing.
Parameters
- NB_Dict **out
Output for the new dictionary.
- Py_ssize_t n_keys
The number of keys to fit without needing resize.
- Py_ssize_t key_size
Size of a key entry.
- Py_ssize_t val_size
Size of a value entry.
*/
NUMBA_EXPORT_FUNC(int)
numba_dict_new_sized(NB_Dict** out, Py_ssize_t n_keys, Py_ssize_t key_size, Py_ssize_t val_size);
/* Free a dict */
NUMBA_EXPORT_FUNC(void)
numba_dict_free(NB_Dict *d);
/* Returns length of a dict */
NUMBA_EXPORT_FUNC(Py_ssize_t)
numba_dict_length(NB_Dict *d);
/* Set the method table for type specific operations
*/
NUMBA_EXPORT_FUNC(void)
numba_dict_set_method_table(NB_Dict *d, type_based_methods_table *methods);
/* Lookup a key
Parameters
- NB_Dict *d
The dictionary object.
- const char *key_bytes
The key as a byte buffer.
- Py_hash_t hash
The precomputed hash of the key.
- char *oldval_bytes
An output parameter to store the associated value if the key is found.
Must point to memory of sufficient size to store the value.
*/
NUMBA_EXPORT_FUNC(Py_ssize_t)
numba_dict_lookup(NB_Dict *d, const char *key_bytes, Py_hash_t hash, char *oldval_bytes);
/* Resize the dict to at least *minsize*.
*/
NUMBA_EXPORT_FUNC(int)
numba_dict_resize(NB_Dict *d, Py_ssize_t minsize);
/* Insert to the dict
Parameters
- NB_Dict *d
The dictionary object.
- const char *key_bytes
The key as a byte buffer.
- Py_hash_t hash
The precomputed hash of key.
- const char *val_bytes
The value as a byte buffer.
- char *oldval_bytes
An output buffer to store the replaced value.
Must point to memory of sufficient size to store the value.
Returns
- < 0 for error
- 0 for ok
- 1 for ok and oldval_bytes has a copy of the replaced value.
*/
NUMBA_EXPORT_FUNC(int)
numba_dict_insert(NB_Dict *d, const char *key_bytes, Py_hash_t hash, const char *val_bytes, char *oldval_bytes);
/* Same as numba_dict_insert() but oldval_bytes is not needed */
NUMBA_EXPORT_FUNC(int)
numba_dict_insert_ez(NB_Dict *d, const char *key_bytes, Py_hash_t hash, const char *val_bytes);
/* Delete an entry from the dict
Parameters
- NB_Dict *d
The dictionary
- Py_hash_t hash
Precomputed hash of the key to be deleted
- Py_ssize_t ix
Precomputed entry index of the key to be deleted.
Usually results of numba_dict_lookup().
*/
NUMBA_EXPORT_FUNC(int)
numba_dict_delitem(NB_Dict *d, Py_hash_t hash, Py_ssize_t ix);
/* Remove an item from the dict
Parameters
- NB_Dict *d
The dictionary
- char *key_bytes
Output. The key as a byte buffer
- char *val_bytes
Output. The value as a byte buffer
*/
NUMBA_EXPORT_FUNC(int)
numba_dict_popitem(NB_Dict *d, char *key_bytes, char *val_bytes);
/* Returns the sizeof a dictionary iterator
*/
NUMBA_EXPORT_FUNC(size_t)
numba_dict_iter_sizeof(void);
/* Fill a NB_DictIter for a dictionary to begin iteration
Parameters
- NB_DictIter *it
Output. Must point to memory of size at least `numba_dict_iter_sizeof()`.
- NB_Dict *d
The dictionary to be iterated.
*/
NUMBA_EXPORT_FUNC(void)
numba_dict_iter(NB_DictIter *it, NB_Dict *d);
/* Advance the iterator
Parameters
- NB_DictIter *it
The iterator
- const char **key_ptr
Output pointer for the key. Points to data in the dictionary.
- const char **val_ptr
Output pointer for the value. Points to data in the dictionary.
Returns
- 0 for success; valid key_ptr and val_ptr
- ERR_ITER_EXHAUSTED for end of iterator.
- ERR_DICT_MUTATED for detected dictionary mutation.
*/
NUMBA_EXPORT_FUNC(int)
numba_dict_iter_next(NB_DictIter *it, const char **key_ptr, const char **val_ptr);
NUMBA_EXPORT_FUNC(void)
numba_dict_dump(NB_Dict *);
#endif
#include "listobject.h"
/* This implements the C component of the Numba typed list. It is loosely
 * inspired by the CPython list implementation, with some parts taken from the
 * CPython slice implementation. The exact commit-id of the relevant files
 * is:
*
* https://github.com/python/cpython/blob/51ddab8dae056867f3595ab3400bffc93f67c8d4/Objects/listobject.c
* https://github.com/python/cpython/blob/51ddab8dae056867f3595ab3400bffc93f67c8d4/Objects/sliceobject.c
*
* Algorithmically, this list is very similar to the cpython implementation so
* it should have the same performance (Big-O) characteristics for accessing,
* adding and removing elements/items. Specifically, it implements the same
* algorithms for list overallocation and growth. However, it never deals with
* PyObject types and instead must be typed with a type-size. As a result, the
* typed-list is type homogeneous and, in contrast to the cpython version, can
* not store a mixture of arbitrarily typed objects. Reference counting via the
* Numba Runtime (NRT) is supported; incrementing and decrementing functions
* are stored as part of the struct and can be set up from the compiler level.
*
* Importantly, only a very limited subset of the cpython c functions have been
* ported over and the rest have been implemented (in Python) at the compiler
* level using the c functions provided. Additionally, initialization of, and
* iteration over, a ListIter is provided.
*
* The following functions are implemented for the list:
*
* - Check valid index valid_index
* - Creation numba_list_new
* - Deletion numba_list_free
* - Accessing the length numba_list_length
* - Appending to the list numba_list_append
* - Getting an item numba_list_getitem
* - Setting an item numba_list_setitem
* - Resizing the list numba_list_resize
* - Deleting an item numba_list_delitem
* - Deleting a slice numba_list_delete_slice
*
* As you can see, only a single function for slices is implemented. The rest
* is all done entirely at the compiler level which then calls the c functions
* to mutate the list accordingly. Since slicing allows for replace, insert and
* delete operations over multiple items, we can simply implement those using
* the basic functions above.
*
* The following additional functions are implemented for the list, these are
* needed to make the list work within Numba.
*
* - Accessing the allocation numba_list_allocated
* - Copying an item copy_item
* - Calling incref on item list_incref_item
* - Calling decref on item list_decref_item
* - Set method table numba_list_set_method_table
*
* The following functions are implemented for the iterator:
*
* - Size of the iterator numba_list_iter_size
* - Initialization of iter numba_list_iter
* - Get next item from iter numba_list_iter_next
*
* Two methods are provided to query and set the 'is_mutable':
*
* - Query numba_list_is_mutable
* - Set numba_list_set_is_mutable
*
* Lastly a set of pure C level tests are provided which come in handy when
* needing to use valgrind and friends.
*
*/
/* Return status for the list functions.
*/
typedef enum {
LIST_OK = 0,
LIST_ERR_INDEX = -1,
LIST_ERR_NO_MEMORY = -2,
LIST_ERR_MUTATED = -3,
LIST_ERR_ITER_EXHAUSTED = -4,
LIST_ERR_IMMUTABLE = -5,
} ListStatus;
/* Copy an item from a list.
*
* lp: a list
* dst: destination pointer
* src: source pointer
*/
static void
copy_item(NB_List *lp, char *dst, const char *src){
memcpy(dst, src, lp->item_size);
}
/* Increment a reference to an item in a list.
*
* lp: a list
* item: the item to increment the reference for
*/
static void
list_incref_item(NB_List *lp, const char *item){
if (lp->methods.item_incref) {
lp->methods.item_incref(item);
}
}
/* Decrement a reference to an item in a list.
*
* lp: a list
* item: the item to decrement the reference for
*/
static void
list_decref_item(NB_List *lp, const char *item){
if (lp->methods.item_decref) {
lp->methods.item_decref(item);
}
}
/* Setup the method table for a list.
*
* This function is used from the compiler level to initialize the internal
* method table.
*
* lp: a list
* methods: the methods table to set up
*/
void
numba_list_set_method_table(NB_List *lp, list_type_based_methods_table *methods)
{
memcpy(&lp->methods, methods, sizeof(list_type_based_methods_table));
}
/* Check if a list index is valid.
*
* i: the index to check
* limit: the size of a list
*
* Adapted from CPython's valid_index().
*
* FIXME: need to find a way to inline this, even for Python 2.7 on Windows
*/
static int
valid_index(Py_ssize_t i, Py_ssize_t limit){
/* The cast to size_t lets us use just a single comparison
to check whether i is in the range: 0 <= i < limit.
See: Section 14.2 "Bounds Checking" in the Agner Fog
optimization manual found at:
https://www.agner.org/optimize/optimizing_cpp.pdf
*/
return (size_t) i < (size_t) limit;
}
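/* Illustrative note (added for clarity): the cast makes a negative index fail
 * the check for free, e.g. (size_t)-1 becomes SIZE_MAX, which is never less
 * than any valid limit. */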
/* Initialize a new list.
*
* out: pointer to hold an initialized list
* item_size: the size in bytes of the items in the list
* allocated: preallocation of the list in items
*
* This will allocate sufficient memory to hold the list structure and any
* items if requested (allocated != 0). See _listobject.h for more information
* on the NB_List struct.
*/
int
numba_list_new(NB_List **out, Py_ssize_t item_size, Py_ssize_t allocated){
NB_List *lp;
char *items;
// allocate memory to hold the struct
lp = malloc(aligned_size(sizeof(NB_List)));
if (lp == NULL) {
return LIST_ERR_NO_MEMORY;
}
// set up members
lp->size = 0;
lp->item_size = item_size;
lp->allocated = allocated;
lp->is_mutable = 1;
// set method table to zero
memset(&lp->methods, 0x00, sizeof(list_type_based_methods_table));
// allocate memory to hold items, if requested
if (allocated != 0) {
items = malloc(aligned_size(lp->item_size * allocated));
// allocated was definitely not zero, if malloc returns NULL
// this is definitely an error
if (items == NULL) {
// free previously allocated struct to avoid leaking memory
free(lp);
return LIST_ERR_NO_MEMORY;
}
lp->items = items;
}
else {
// be explicit
lp->items = NULL;
}
*out = lp;
return LIST_OK;
}
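/* Usage sketch (hypothetical, guarded out of the build): create a list of
   4-byte items, append one, read it back, and free the list. Error handling
   is elided for brevity. */
#if 0
static void list_usage_sketch(void) {
    NB_List *lp;
    char item[4] = "abc";            /* 3 chars + NUL = 4 bytes */
    char out[4];
    numba_list_new(&lp, 4, 0);       /* item_size = 4, no preallocation */
    numba_list_append(lp, item);     /* grows the allocation to hold 1 item */
    numba_list_getitem(lp, 0, out);  /* copies the item into out */
    numba_list_free(lp);
}
#endif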
/* Free the memory associated with a list.
*
* lp: a list
*/
void
numba_list_free(NB_List *lp) {
// decref all items, if needed
Py_ssize_t i;
if (lp->methods.item_decref) {
for (i = 0; i < lp->size; i++) {
char *item = lp->items + lp->item_size * i;
list_decref_item(lp, item);
}
}
// free items and list
if (lp->items != NULL) {
free(lp->items);
}
free(lp);
}
/* Return the base pointer of the list items.
*/
char *
numba_list_base_ptr(NB_List *lp)
{
return lp->items;
}
/* Return the address of the list size.
*/
Py_ssize_t
numba_list_size_address(NB_List *lp)
{
return (Py_ssize_t)&lp->size;
}
/* Return the length of a list.
*
* lp: a list
*/
Py_ssize_t
numba_list_length(NB_List *lp) {
return lp->size;
}
/* Return the current allocation of a list.
*
* lp: a list
*/
Py_ssize_t
numba_list_allocated(NB_List *lp) {
return lp->allocated;
}
/* Return the mutability status of the list
*
* lp: a list
*
*/
int
numba_list_is_mutable(NB_List *lp){
return lp->is_mutable;
}
/* Set the is_mutable attribute
*
* lp: a list
* is_mutable: an int, 0(False) or 1(True)
*
*/
void
numba_list_set_is_mutable(NB_List *lp, int is_mutable){
lp->is_mutable = is_mutable;
}
/* Set an item in a list.
*
* lp: a list
* index: the index of the item to set (must be in range 0 <= index < len(list))
* item: the item to set
*
* This assumes there is already an element at the given index that will be
* overwritten and thereby have its reference decremented. DO NOT use this to
* write to an unassigned location.
*/
int
numba_list_setitem(NB_List *lp, Py_ssize_t index, const char *item) {
char *loc;
// check for mutability
if (!lp->is_mutable) {
return LIST_ERR_IMMUTABLE;
}
// check index is valid
// FIXME: this can be (and probably is) checked at the compiler level
if (!valid_index(index, lp->size)) {
return LIST_ERR_INDEX;
}
// set item at desired location
loc = lp->items + lp->item_size * index;
list_decref_item(lp, loc);
copy_item(lp, loc, item);
list_incref_item(lp, loc);
return LIST_OK;
}
/* Get an item from a list.
*
* lp: a list
* index: the index of the item to get (must be in range 0 <= index < len(list))
* out: a pointer to hold the item
*/
int
numba_list_getitem(NB_List *lp, Py_ssize_t index, char *out) {
char *loc;
// check index is valid
// FIXME: this can be (and probably is) checked at the compiler level
if (!valid_index(index, lp->size)) {
return LIST_ERR_INDEX;
}
// get item at desired location
loc = lp->items + lp->item_size * index;
copy_item(lp, out, loc);
return LIST_OK;
}
/* Append an item to the end of a list.
*
* lp: a list
* item: the item to append.
*/
int
numba_list_append(NB_List *lp, const char *item) {
char *loc;
// check for mutability
if (!lp->is_mutable) {
return LIST_ERR_IMMUTABLE;
}
// resize by one, will change list size
int result = numba_list_resize(lp, lp->size + 1);
if(result < LIST_OK) {
return result;
}
// insert item at index: original size before resize
loc = lp->items + lp->item_size * (lp->size - 1);
copy_item(lp, loc, item);
list_incref_item(lp, loc);
return LIST_OK;
}
/* Resize a list.
*
* lp: a list
* newsize: the desired new size of the list.
*
* This will increase or decrease the size of the list, including reallocating
* the required memory and increasing the total allocation (additional free
* space to hold new items).
*
*
* Adapted from CPython's list_resize().
*
* Ensure lp->items has room for at least newsize elements, and set
* lp->size to newsize. If newsize > lp->size on entry, the content
* of the new slots at exit is undefined heap trash; it's the caller's
* responsibility to overwrite them with sane values.
* The number of allocated elements may grow, shrink, or stay the same.
* Failure is impossible if newsize <= lp->allocated on entry, although
* that partly relies on an assumption that the system realloc() never
* fails when passed a number of bytes <= the number of bytes last
* allocated (the C standard doesn't guarantee this, but it's hard to
* imagine a realloc implementation where it wouldn't be true).
* Note that lp->items may change, even if newsize is less
* than lp->size on entry.
*/
int
numba_list_resize(NB_List *lp, Py_ssize_t newsize) {
    char *items;
    size_t new_allocated, num_allocated_bytes;
    // check for mutability
    if (!lp->is_mutable) {
        return LIST_ERR_IMMUTABLE;
    }
/* Bypass realloc() when a previous overallocation is large enough
to accommodate the newsize. If the newsize falls lower than half
the allocated size, then proceed with the realloc() to shrink the list.
*/
if (lp->allocated >= newsize && newsize >= (lp->allocated >> 1)) {
assert(lp->items != NULL || newsize == 0);
lp->size = newsize;
return LIST_OK;
}
/* This over-allocates proportional to the list size, making room
* for additional growth. The over-allocation is mild, but is
* enough to give linear-time amortized behavior over a long
* sequence of appends() in the presence of a poorly-performing
* system realloc().
* The growth pattern is: 0, 4, 8, 16, 25, 35, 46, 58, 72, 88, ...
* Note: new_allocated won't overflow because the largest possible value
* is PY_SSIZE_T_MAX * (9 / 8) + 6 which always fits in a size_t.
*/
new_allocated = (size_t)newsize + (newsize >> 3) + (newsize < 9 ? 3 : 6);
if (new_allocated > (size_t)PY_SSIZE_T_MAX / lp->item_size) {
return LIST_ERR_NO_MEMORY;
}
if (newsize == 0)
new_allocated = 0;
num_allocated_bytes = new_allocated * lp->item_size;
items = realloc(lp->items, aligned_size(num_allocated_bytes));
// realloc may return NULL if requested size is 0
if (num_allocated_bytes != 0 && items == NULL) {
return LIST_ERR_NO_MEMORY;
}
lp->items = items;
lp->size = newsize;
lp->allocated = (Py_ssize_t)new_allocated;
return LIST_OK;
}
/* Delete a single item.
*
* lp: a list
* index: the index of the item to delete
* (must be in range 0 <= index < len(list))
*
* */
int
numba_list_delitem(NB_List *lp, Py_ssize_t index) {
int result;
char *loc, *new_loc;
Py_ssize_t leftover_bytes;
// check for mutability
if (!lp->is_mutable) {
return LIST_ERR_IMMUTABLE;
}
// check index is valid
// FIXME: this can be (and probably is) checked at the compiler level
if (!valid_index(index, lp->size)) {
return LIST_ERR_INDEX;
}
// obtain item and decref if needed
loc = lp->items + lp->item_size * index;
list_decref_item(lp, loc);
if (index != lp->size - 1) {
// delitem from somewhere other than the end, incur the memory copy
leftover_bytes = (lp->size - 1 - index) * lp->item_size;
new_loc = lp->items + (lp->item_size * (index + 1));
// use memmove instead of memcpy since we may be dealing with
// overlapping regions of memory and the behaviour of memcpy is
// undefined in such situation (C99).
memmove(loc, new_loc, leftover_bytes);
}
// finally, shrink list by one
result = numba_list_resize(lp, lp->size - 1);
if(result < LIST_OK) {
// Since we are decreasing the size, this should never happen
return result;
}
return LIST_OK;
}
/* Delete a slice
*
* start: the start index of the slice
* stop: the stop index of the slice (not included)
* step: the step to take
*
* This function assumes that start and stop were clipped appropriately:
* if step > 0, then 0 <= start and stop <= len(l);
* if step < 0, then start <= length and stop >= -1.
* step != 0, and no Python-style negative indexing is allowed.
*
* This code was copied and edited from the relevant section in
* list_ass_subscript from the cpython implementation, see the top of this file
* for the exact source
*/
int
numba_list_delete_slice(NB_List *lp,
Py_ssize_t start, Py_ssize_t stop, Py_ssize_t step) {
int result, i, slicelength, new_length;
char *loc, *new_loc;
Py_ssize_t leftover_bytes, cur, lim;
// check for mutability
if (!lp->is_mutable) {
return LIST_ERR_IMMUTABLE;
}
// calculate the slicelength, taken from PySlice_AdjustIndices, see the top
// of this file for the exact source
if (step > 0) {
slicelength = start < stop ? (stop - start - 1) / step + 1 : 0;
} else {
slicelength = stop < start ? (start - stop - 1) / -step + 1 : 0;
}
if (slicelength <= 0){
return LIST_OK;
}
new_length = lp->size - slicelength;
// reverse step and indices
if (step < 0) {
stop = start + 1;
start = stop + step * (slicelength - 1) - 1;
step = -step;
}
if (step == 1) {
// decref if needed
if (lp->methods.item_decref) {
for (i = start ; i < stop ; i++){
loc = lp->items + lp->item_size * i;
lp->methods.item_decref(loc);
}
}
// memmove items into place
leftover_bytes = (lp->size - stop) * lp->item_size;
loc = lp->items + lp->item_size * start;
new_loc = lp->items + lp->item_size * stop;
memmove(loc, new_loc, leftover_bytes);
}
else { // step != 1
        /* Drawing pictures might help to understand these for-loops.
         * Basically, we memmove the parts of the list that are *not*
         * part of the slice: step-1 items for each item that is part
         * of the slice, and then the tail end of the list that was not
         * covered by the slice.
         */
for (cur = start, // index of item to be deleted
i = 0; // counter of total items deleted so far
cur < stop;
cur += step,
i++) {
lim = step - 1; // number of leftover items after deletion of item
// clip limit, in case we are at the end of the slice, and there
// are now less than step-1 items to be moved
if (cur + step >= lp->size) {
lim = lp->size - cur - 1;
}
// decref item being removed
loc = lp->items + lp->item_size * cur;
list_decref_item(lp, loc);
/* memmove the aforementioned step-1 (or less) items
             * dst : index of deleted item minus total deleted so far
* src : index of deleted item plus one (next item)
*/
memmove(lp->items + lp->item_size * (cur - i),
lp->items + lp->item_size * (cur + 1),
lim * lp->item_size);
}
// memmove tail of the list
cur = start + slicelength * step;
if (cur < lp->size) {
memmove(lp->items + lp->item_size * (cur - slicelength),
lp->items + lp->item_size * cur,
(lp->size - cur) * lp->item_size);
}
}
// resize to correct size
result = numba_list_resize(lp, new_length);
if(result < LIST_OK) {
// Since we are decreasing the size, this should never happen
return result;
}
return LIST_OK;
}
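/* Worked example for the step != 1 branch above (illustrative only):
 * deleting slice (start=0, stop=4, step=2) from [a, b, c, d, e, f] removes
 * a and c (slicelength == 2). In the loop, b is moved to index 0 and d to
 * index 1; the tail [e, f] is then moved left by slicelength to indices
 * 2..3, and the list is resized to [b, d, e, f].
 */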
/* Return the size of the list iterator (NB_ListIter) struct.
*/
size_t
numba_list_iter_sizeof() {
return sizeof(NB_ListIter);
}
/* Initialize a list iterator (NB_ListIter).
*
* it: an iterator
* lp: a list to iterate over
*/
void
numba_list_iter(NB_ListIter *it, NB_List *lp) {
// set members of iterator
it->parent = lp;
it->size = lp->size;
it->pos = 0;
}
/* Obtain the next item from a list iterator.
*
* it: an iterator
* item_ptr: pointer to hold the next item
*/
int
numba_list_iter_next(NB_ListIter *it, const char **item_ptr) {
NB_List *lp;
lp = it->parent;
/* FIXME: Detect list mutation during iteration */
if (lp->size != it->size) {
return LIST_ERR_MUTATED;
}
// get next element
if (it->pos < lp->size) {
*item_ptr = lp->items + lp->item_size * it->pos++;
return LIST_OK;
}else{
return LIST_ERR_ITER_EXHAUSTED;
}
}
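/* Typical iteration pattern (a sketch; the test below exercises the same
 * protocol):
 *
 *     NB_ListIter it;
 *     const char *item = NULL;
 *     numba_list_iter(&it, lp);
 *     while (numba_list_iter_next(&it, &item) == LIST_OK) {
 *         // consume lp->item_size bytes at item
 *     }
 */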
#define CHECK(CASE) { \
if ( !(CASE) ) { \
printf("'%s' failed file %s:%d\n", #CASE, __FILE__, __LINE__); \
return -1; \
} \
}
/* Basic C based tests for the list.
*/
int
numba_test_list(void) {
NB_List *lp = NULL;
int status, i;
Py_ssize_t it_count;
const char *it_item = NULL;
NB_ListIter iter;
char got_item[4] = "\x00\x00\x00\x00";
const char *test_items_1 = NULL, *test_items_2 = NULL;
char *test_items_3 = NULL;
puts("test_list");
status = numba_list_new(&lp, 4, 0);
CHECK(status == LIST_OK);
CHECK(lp->item_size == 4);
CHECK(lp->size == 0);
CHECK(lp->allocated == 0);
CHECK(lp->is_mutable == 1);
// flip and check the is_mutable bit
CHECK(numba_list_is_mutable(lp) == 1);
numba_list_set_is_mutable(lp, 0);
CHECK(numba_list_is_mutable(lp) == 0);
numba_list_set_is_mutable(lp, 1);
CHECK(numba_list_is_mutable(lp) == 1);
// append 1st item, this will cause a realloc
status = numba_list_append(lp, "abc");
CHECK(status == LIST_OK);
CHECK(lp->size == 1);
CHECK(lp->allocated == 4);
status = numba_list_getitem(lp, 0, got_item);
CHECK(status == LIST_OK);
CHECK(memcmp(got_item, "abc", 4) == 0);
// append 2nd item
status = numba_list_append(lp, "def");
CHECK(status == LIST_OK);
CHECK(lp->size == 2);
CHECK(lp->allocated == 4);
status = numba_list_getitem(lp, 1, got_item);
CHECK(status == LIST_OK);
CHECK(memcmp(got_item, "def", 4) == 0);
// append 3rd item
status = numba_list_append(lp, "ghi");
CHECK(status == LIST_OK);
CHECK(lp->size == 3);
CHECK(lp->allocated == 4);
status = numba_list_getitem(lp, 2, got_item);
CHECK(status == LIST_OK);
CHECK(memcmp(got_item, "ghi", 4) == 0);
// append 4th item
status = numba_list_append(lp, "jkl");
CHECK(status == LIST_OK);
CHECK(lp->size == 4);
CHECK(lp->allocated == 4);
status = numba_list_getitem(lp, 3, got_item);
CHECK(status == LIST_OK);
CHECK(memcmp(got_item, "jkl", 4) == 0);
// append 5th item, this will cause another realloc
status = numba_list_append(lp, "mno");
CHECK(status == LIST_OK);
CHECK(lp->size == 5);
CHECK(lp->allocated == 8);
status = numba_list_getitem(lp, 4, got_item);
CHECK(status == LIST_OK);
CHECK(memcmp(got_item, "mno", 4) == 0);
// overwrite 1st item
status = numba_list_setitem(lp, 0, "pqr");
CHECK(status == LIST_OK);
CHECK(lp->size == 5);
CHECK(lp->allocated == 8);
status = numba_list_getitem(lp, 0, got_item);
CHECK(status == LIST_OK);
CHECK(memcmp(got_item, "pqr", 4) == 0);
// get and del 1st item, check item shift
status = numba_list_getitem(lp, 0, got_item);
status = numba_list_delitem(lp, 0);
CHECK(status == LIST_OK);
CHECK(lp->size == 4);
CHECK(lp->allocated == 8);
CHECK(memcmp(got_item, "pqr", 4) == 0);
CHECK(memcmp(lp->items, "def\x00ghi\x00jkl\x00mno\x00", 16) == 0);
// get and del last (4th) item, no shift since only last item affected
status = numba_list_getitem(lp, 3, got_item);
status = numba_list_delitem(lp, 3);
CHECK(status == LIST_OK);
CHECK(lp->size == 3);
CHECK(lp->allocated == 6); // this also shrinks the allocation
CHECK(memcmp(got_item, "mno", 4) == 0);
CHECK(memcmp(lp->items, "def\x00ghi\x00jkl\x00", 12) == 0);
// flip and check the is_mutable member
CHECK(numba_list_is_mutable(lp) == 1);
numba_list_set_is_mutable(lp, 0);
CHECK(numba_list_is_mutable(lp) == 0);
// ensure that any attempts to mutate an immutable list fail
CHECK(numba_list_setitem(lp, 0, "zzz") == LIST_ERR_IMMUTABLE);
CHECK(numba_list_append(lp, "zzz") == LIST_ERR_IMMUTABLE);
CHECK(numba_list_delitem(lp, 0) == LIST_ERR_IMMUTABLE);
CHECK(numba_list_resize(lp, 23) == LIST_ERR_IMMUTABLE);
CHECK(numba_list_delete_slice(lp, 0, 3, 1) == LIST_ERR_IMMUTABLE);
    // ensure that all attempts to query/read from an immutable list succeed
CHECK(numba_list_length(lp) == 3);
status = numba_list_getitem(lp, 0, got_item);
CHECK(status == LIST_OK);
CHECK(memcmp(got_item, "def", 4) == 0);
// flip the is_mutable member back and check
numba_list_set_is_mutable(lp, 1);
CHECK(numba_list_is_mutable(lp) == 1);
// test iterator
CHECK(lp->size > 0);
numba_list_iter(&iter, lp);
it_count = 0;
CHECK(iter.parent == lp);
CHECK(iter.pos == it_count);
// current contents of list
test_items_1 = "def\x00ghi\x00jkl\x00";
while ( (status = numba_list_iter_next(&iter, &it_item)) == LIST_OK) {
it_count += 1;
CHECK(iter.pos == it_count); // check iterator position
CHECK(it_item != NULL); // quick check item is non-null
// go fishing in test_items_1
CHECK(memcmp((const char *)test_items_1 + ((it_count - 1) * 4), it_item, 4) == 0);
}
CHECK(status == LIST_ERR_ITER_EXHAUSTED);
CHECK(lp->size == it_count);
// free existing list
numba_list_free(lp);
// test growth upon append and shrink during delitem
status = numba_list_new(&lp, 1, 0);
CHECK(status == LIST_OK);
CHECK(lp->item_size == 1);
CHECK(lp->size == 0);
CHECK(lp->allocated == 0);
// first, grow the list
// Use exactly 17 elements, should go through the allocation pattern:
// 0, 4, 8, 16, 25
for (i = 0; i < 17 ; i++) {
switch(i) {
// Check the allocation before
case 0: CHECK(lp->allocated == 0); break;
case 4: CHECK(lp->allocated == 4); break;
case 8: CHECK(lp->allocated == 8); break;
case 16: CHECK(lp->allocated == 16); break;
}
status = numba_list_append(lp, (const char*)&i);
CHECK(status == LIST_OK);
switch(i) {
// Check that the growth happened accordingly
case 0: CHECK(lp->allocated == 4); break;
case 4: CHECK(lp->allocated == 8); break;
case 8: CHECK(lp->allocated == 16); break;
case 16: CHECK(lp->allocated == 25); break;
}
}
CHECK(lp->size == 17);
// Check current contents of list
test_items_2 = "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10";
CHECK(memcmp(lp->items, test_items_2, 17) == 0);
// Now, delete them again and check that list shrinks
for (i = 17; i > 0 ; i--) {
switch(i) {
// Check the allocation before delitem
case 17: CHECK(lp->allocated == 25); break;
case 12: CHECK(lp->allocated == 25); break;
case 9: CHECK(lp->allocated == 18); break;
case 6: CHECK(lp->allocated == 12); break;
case 4: CHECK(lp->allocated == 8); break;
case 3: CHECK(lp->allocated == 6); break;
case 2: CHECK(lp->allocated == 5); break;
case 1: CHECK(lp->allocated == 4); break;
}
status = numba_list_getitem(lp, i-1, got_item);
status = numba_list_delitem(lp, i-1);
CHECK(status == LIST_OK);
switch(i) {
// Check that the shrink happened accordingly
case 17: CHECK(lp->allocated == 25); break;
case 12: CHECK(lp->allocated == 18); break;
case 9: CHECK(lp->allocated == 12); break;
case 6: CHECK(lp->allocated == 8); break;
case 4: CHECK(lp->allocated == 6); break;
case 3: CHECK(lp->allocated == 5); break;
case 2: CHECK(lp->allocated == 4); break;
case 1: CHECK(lp->allocated == 0); break;
}
}
// free existing list
numba_list_free(lp);
// Setup list for testing delete_slice
status = numba_list_new(&lp, 1, 0);
CHECK(status == LIST_OK);
CHECK(lp->item_size == 1);
CHECK(lp->size == 0);
CHECK(lp->allocated == 0);
for (i = 0; i < 17 ; i++) {
status = numba_list_append(lp, (const char*)&i);
CHECK(status == LIST_OK);
}
CHECK(lp->size == 17);
test_items_3 = "\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10";
CHECK(memcmp(lp->items, test_items_3, 17) == 0);
// delete multiple elements from the middle
status = numba_list_delete_slice(lp, 2, 5, 1);
CHECK(status == LIST_OK);
CHECK(lp->size == 14);
test_items_3 = "\x00\x01\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10";
CHECK(memcmp(lp->items, test_items_3, 14) == 0);
// delete single element from start
status = numba_list_delete_slice(lp, 0, 1, 1);
CHECK(status == LIST_OK);
CHECK(lp->size == 13);
test_items_3 = "\x01\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f\x10";
CHECK(memcmp(lp->items, test_items_3, 13) == 0);
// delete single element from end
status = numba_list_delete_slice(lp, 12, 13, 1);
CHECK(status == LIST_OK);
CHECK(lp->size == 12);
test_items_3 = "\x01\x05\x06\x07\x08\x09\x0a\x0b\x0c\x0d\x0e\x0f";
CHECK(memcmp(lp->items, test_items_3, 12) == 0);
// delete single element from middle
status = numba_list_delete_slice(lp, 4, 5, 1);
CHECK(status == LIST_OK);
CHECK(lp->size == 11);
test_items_3 = "\x01\x05\x06\x07\x09\x0a\x0b\x0c\x0d\x0e\x0f";
CHECK(memcmp(lp->items, test_items_3, 11) == 0);
// delete all elements except first and last
status = numba_list_delete_slice(lp, 1, 10, 1);
CHECK(status == LIST_OK);
CHECK(lp->size == 2);
test_items_3 = "\x01\x0f";
CHECK(memcmp(lp->items, test_items_3, 2) == 0);
// delete all remaining elements
status = numba_list_delete_slice(lp, 0, lp->size, 1);
CHECK(status == LIST_OK);
CHECK(lp->size == 0);
test_items_3 = "";
CHECK(memcmp(lp->items, test_items_3, 0) == 0);
// free existing list
numba_list_free(lp);
// Setup list for testing delete_slice with non unary step
status = numba_list_new(&lp, 1, 0);
CHECK(status == LIST_OK);
CHECK(lp->item_size == 1);
CHECK(lp->size == 0);
CHECK(lp->allocated == 0);
for (i = 0; i < 17 ; i++) {
status = numba_list_append(lp, (const char*)&i);
CHECK(status == LIST_OK);
}
CHECK(lp->size == 17);
// delete all items with odd index
status = numba_list_delete_slice(lp, 0, 17, 2);
CHECK(status == LIST_OK);
CHECK(lp->size == 8);
test_items_3 = "\x01\x03\x05\x07\x09\x0b\x0d\x0f";
CHECK(memcmp(lp->items, test_items_3, 8) == 0);
// delete with a step of 4, starting at index 1
status = numba_list_delete_slice(lp, 1, 8, 4);
CHECK(status == LIST_OK);
CHECK(lp->size == 6);
test_items_3 = "\x01\x05\x07\x09\x0d\x0f";
CHECK(memcmp(lp->items, test_items_3, 6) == 0);
// delete with a step of 2, but finish before end of list
status = numba_list_delete_slice(lp, 0, 4, 2);
CHECK(status == LIST_OK);
CHECK(lp->size == 4);
test_items_3 = "\x05\x09\x0d\x0f";
CHECK(memcmp(lp->items, test_items_3, 4) == 0);
// no-op on empty slice
status = numba_list_delete_slice(lp, 0, 0, 1);
CHECK(status == LIST_OK);
CHECK(lp->size == 4);
test_items_3 = "\x05\x09\x0d\x0f";
CHECK(memcmp(lp->items, test_items_3, 4) == 0);
// no-op on empty slice, non-zero index
status = numba_list_delete_slice(lp, 2, 2, 1);
CHECK(status == LIST_OK);
CHECK(lp->size == 4);
test_items_3 = "\x05\x09\x0d\x0f";
CHECK(memcmp(lp->items, test_items_3, 4) == 0);
    // free existing list
numba_list_free(lp);
// Setup list for testing delete_slice with negative step
status = numba_list_new(&lp, 1, 0);
CHECK(status == LIST_OK);
CHECK(lp->item_size == 1);
CHECK(lp->size == 0);
CHECK(lp->allocated == 0);
for (i = 0; i < 17 ; i++) {
status = numba_list_append(lp, (const char*)&i);
CHECK(status == LIST_OK);
}
CHECK(lp->size == 17);
    // delete all items using a negative step of -1
status = numba_list_delete_slice(lp, 16, -1, -1);
CHECK(status == LIST_OK);
CHECK(lp->size == 0);
// refill list
for (i = 0; i < 17 ; i++) {
status = numba_list_append(lp, (const char*)&i);
CHECK(status == LIST_OK);
}
    // delete every other item using a negative step of -2;
    // need to start at the index of the last item (16) and
    // go beyond the first item, i.e. -1 in C
status = numba_list_delete_slice(lp, 16, -1, -2);
CHECK(status == LIST_OK);
CHECK(lp->size == 8);
test_items_3 = "\x01\x03\x05\x07\x09\x0b\x0d\x0f";
CHECK(memcmp(lp->items, test_items_3, 8) == 0);
// free list and return 0
numba_list_free(lp);
return 0;
}
#undef CHECK
/* Adapted from CPython3.7 Include/listobject.h
*
* The exact commit-id of the relevant file is:
*
* https://github.com/python/cpython/blob/51ddab8dae056867f3595ab3400bffc93f67c8d4/Include/listobject.h
*
* WARNING:
* Most interfaces listed here are exported (global), but they are not
* supported, stable, or part of Numba's public API. These interfaces and their
* underlying implementations may be changed or removed in future without
* notice.
* */
#ifndef NUMBA_LIST_H
#define NUMBA_LIST_H
#include "cext.h"
typedef void (*list_refcount_op_t)(const void*);
typedef struct {
list_refcount_op_t item_incref;
list_refcount_op_t item_decref;
} list_type_based_methods_table;
/* This is the struct for the Numba typed list. It is largely inspired by the
* CPython list struct in listobject.h. In essence the list is a homogeneously
* typed container that can grow and shrink upon insertion and deletion. This
* means that appending an item to, or removing an item from, the end of the
 * list has an amortized O(1) runtime. This matches the
* behaviour of the CPython list type and it will grow with the same
* increments.
*
* 'items' contains space for 'allocated' elements. The number
* currently in use is 'size'. The size in bytes of the items stored in the
* list is given by 'item_size'.
*
* Invariants:
* 0 <= size <= allocated
* len(list) == size
 *     items == NULL implies size == allocated == 0
*
* FIXME: list.sort() temporarily sets allocated to -1 to detect mutations.
*
* Items must normally not be NULL, except during construction when
* the list is not yet visible outside the function that builds it.
*
 * Additionally, this list has a boolean member 'is_mutable' that can be used to
* set a list as immutable. Two functions to query and set this member are
* provided. Any attempt to mutate an immutable list will result in a status
* of LIST_ERR_IMMUTABLE.
*
*/
typedef struct {
/* size of the list in items */
Py_ssize_t size;
/* size of the list items in bytes */
Py_ssize_t item_size;
/* total allocated slots in items */
Py_ssize_t allocated;
/* is the list mutable */
int is_mutable;
/* method table for type-dependent operations */
list_type_based_methods_table methods;
/* array/pointer for items. Interpretation is governed by item_size */
char * items;
} NB_List;
typedef struct {
/* parent list */
NB_List *parent;
/* list size */
Py_ssize_t size;
/* iterator position; indicates the next position to read */
Py_ssize_t pos;
} NB_ListIter;
NUMBA_GLOBAL_FUNC(void)
numba_list_set_method_table(NB_List *lp, list_type_based_methods_table *methods);
NUMBA_GLOBAL_FUNC(int)
numba_list_new(NB_List **out, Py_ssize_t item_size, Py_ssize_t allocated);
NUMBA_GLOBAL_FUNC(void)
numba_list_free(NB_List *lp);
NUMBA_GLOBAL_FUNC(char *)
numba_list_base_ptr(NB_List *lp);
NUMBA_GLOBAL_FUNC(Py_ssize_t)
numba_list_size_address(NB_List *lp);
NUMBA_GLOBAL_FUNC(Py_ssize_t)
numba_list_length(NB_List *lp);
NUMBA_GLOBAL_FUNC(Py_ssize_t)
numba_list_allocated(NB_List *lp);
NUMBA_GLOBAL_FUNC(int)
numba_list_is_mutable(NB_List *lp);
NUMBA_GLOBAL_FUNC(void)
numba_list_set_is_mutable(NB_List *lp, int is_mutable);
NUMBA_GLOBAL_FUNC(int)
numba_list_setitem(NB_List *lp, Py_ssize_t index, const char *item);
NUMBA_GLOBAL_FUNC(int)
numba_list_getitem(NB_List *lp, Py_ssize_t index, char *out);
NUMBA_GLOBAL_FUNC(int)
numba_list_append(NB_List *lp, const char *item);
NUMBA_GLOBAL_FUNC(int)
numba_list_resize(NB_List *lp, Py_ssize_t newsize);
NUMBA_GLOBAL_FUNC(int)
numba_list_delitem(NB_List *lp, Py_ssize_t index);
NUMBA_GLOBAL_FUNC(int)
numba_list_delete_slice(NB_List *lp,
Py_ssize_t start, Py_ssize_t stop, Py_ssize_t step);
NUMBA_GLOBAL_FUNC(size_t)
numba_list_iter_sizeof(void);
NUMBA_GLOBAL_FUNC(void)
numba_list_iter(NB_ListIter *it, NB_List *l);
NUMBA_GLOBAL_FUNC(int)
numba_list_iter_next(NB_ListIter *it, const char **item_ptr);
NUMBA_EXPORT_FUNC(int)
numba_test_list(void);
#endif
#include "cext.h"
/* Align size *sz* to pointer width */
Py_ssize_t
aligned_size(Py_ssize_t sz) {
Py_ssize_t alignment = sizeof(void*);
return sz + (alignment - sz % alignment) % alignment;
}
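/* Worked example (illustrative, assuming 8-byte pointers): aligned_size(13)
 * returns 13 + (8 - 13 % 8) % 8 == 16, while an already-aligned size such
 * as 16 is returned unchanged.
 */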
from __future__ import absolute_import
# NOTE: The following imports are adapted to use as a vendored subpackage.
# from https://github.com/cloudpipe/cloudpickle/commit/f31859b1dd83fa691f4f7f797166b262c9acb8e7
from .cloudpickle import * # noqa
from .cloudpickle_fast import CloudPickler, dumps, dump # noqa
# Conform to the convention used by python serialization libraries, which
# expose their Pickler subclass at top-level under the "Pickler" name.
Pickler = CloudPickler
__version__ = '2.2.0'
"""
This class is defined to override standard pickle functionality
The goals of it follow:
-Serialize lambdas and nested functions to compiled byte code
-Deal with main module correctly
-Deal with other non-serializable objects
It does not include an unpickler, as standard python unpickling suffices.
This module was extracted from the `cloud` package, developed by `PiCloud, Inc.
<https://web.archive.org/web/20140626004012/http://www.picloud.com/>`_.
Copyright (c) 2012, Regents of the University of California.
Copyright (c) 2009 `PiCloud, Inc. <https://web.archive.org/web/20140626004012/http://www.picloud.com/>`_.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the University of California, Berkeley nor the
names of its contributors may be used to endorse or promote
products derived from this software without specific prior written
permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""
import builtins
import dis
import opcode
import platform
import sys
import types
import weakref
import uuid
import threading
import typing
import warnings
from .compat import pickle
from collections import OrderedDict
from typing import ClassVar, Generic, Union, Tuple, Callable
from pickle import _getattribute
from importlib._bootstrap import _find_spec
try: # pragma: no branch
import typing_extensions as _typing_extensions
from typing_extensions import Literal, Final
except ImportError:
_typing_extensions = Literal = Final = None
if sys.version_info >= (3, 8):
from types import CellType
else:
def f():
a = 1
def g():
return a
return g
CellType = type(f().__closure__[0])
# cloudpickle is meant for inter process communication: we expect all
# communicating processes to run the same Python version hence we favor
# communication speed over compatibility:
DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL
# Names of modules whose resources should be treated as dynamic.
_PICKLE_BY_VALUE_MODULES = set()
# Track the provenance of reconstructed dynamic classes to make it possible to
# reconstruct instances from the matching singleton class definition when
# appropriate and preserve the usual "isinstance" semantics of Python objects.
_DYNAMIC_CLASS_TRACKER_BY_CLASS = weakref.WeakKeyDictionary()
_DYNAMIC_CLASS_TRACKER_BY_ID = weakref.WeakValueDictionary()
_DYNAMIC_CLASS_TRACKER_LOCK = threading.Lock()
_DYNAMIC_CLASS_TRACKER_REUSING = weakref.WeakSet()
PYPY = platform.python_implementation() == "PyPy"
builtin_code_type = None
if PYPY:
# builtin-code objects only exist in pypy
builtin_code_type = type(float.__new__.__code__)
_extract_code_globals_cache = weakref.WeakKeyDictionary()
def _get_or_create_tracker_id(class_def):
with _DYNAMIC_CLASS_TRACKER_LOCK:
class_tracker_id = _DYNAMIC_CLASS_TRACKER_BY_CLASS.get(class_def)
if class_tracker_id is None:
class_tracker_id = uuid.uuid4().hex
_DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
_DYNAMIC_CLASS_TRACKER_BY_ID[class_tracker_id] = class_def
return class_tracker_id
def _lookup_class_or_track(class_tracker_id, class_def):
if class_tracker_id is not None:
with _DYNAMIC_CLASS_TRACKER_LOCK:
orig_class_def = class_def
class_def = _DYNAMIC_CLASS_TRACKER_BY_ID.setdefault(
class_tracker_id, class_def)
_DYNAMIC_CLASS_TRACKER_BY_CLASS[class_def] = class_tracker_id
# Check if we are reusing a previous class_def
if orig_class_def is not class_def:
# Remember the class_def is being reused
_DYNAMIC_CLASS_TRACKER_REUSING.add(class_def)
return class_def
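# Illustrative sketch: if two payloads carry the same tracker id, the second
# reconstruction resolves to the first class object, preserving isinstance
# semantics across unpickling (class names here are hypothetical):
#
#     tid = _get_or_create_tracker_id(FirstDef)
#     assert _lookup_class_or_track(tid, SecondDef) is FirstDef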
def register_pickle_by_value(module):
"""Register a module to make it functions and classes picklable by value.
By default, functions and classes that are attributes of an importable
module are to be pickled by reference, that is relying on re-importing
the attribute from the module at load time.
If `register_pickle_by_value(module)` is called, all its functions and
classes are subsequently to be pickled by value, meaning that they can
be loaded in Python processes where the module is not importable.
This is especially useful when developing a module in a distributed
    execution environment: restarting the client Python process with the new
    source code is enough; there is no need to re-install the new version
    of the module on all the worker nodes nor to restart the workers.
Note: this feature is considered experimental. See the cloudpickle
README.md file for more details and limitations.
"""
if not isinstance(module, types.ModuleType):
raise ValueError(
f"Input should be a module object, got {str(module)} instead"
)
# In the future, cloudpickle may need a way to access any module registered
# for pickling by value in order to introspect relative imports inside
# functions pickled by value. (see
# https://github.com/cloudpipe/cloudpickle/pull/417#issuecomment-873684633).
# This access can be ensured by checking that module is present in
# sys.modules at registering time and assuming that it will still be in
# there when accessed during pickling. Another alternative would be to
# store a weakref to the module. Even though cloudpickle does not implement
# this introspection yet, in order to avoid a possible breaking change
# later, we still enforce the presence of module inside sys.modules.
if module.__name__ not in sys.modules:
raise ValueError(
f"{module} was not imported correctly, have you used an "
f"`import` statement to access it?"
)
_PICKLE_BY_VALUE_MODULES.add(module.__name__)
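# Illustrative usage sketch (``mypkg`` is a hypothetical module that is
# importable locally but not installed on the worker processes):
#
#     import cloudpickle
#     import mypkg
#     cloudpickle.register_pickle_by_value(mypkg)
#     payload = cloudpickle.dumps(mypkg.some_function)  # pickled by value
#     cloudpickle.unregister_pickle_by_value(mypkg)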
def unregister_pickle_by_value(module):
"""Unregister that the input module should be pickled by value."""
if not isinstance(module, types.ModuleType):
raise ValueError(
f"Input should be a module object, got {str(module)} instead"
)
if module.__name__ not in _PICKLE_BY_VALUE_MODULES:
raise ValueError(f"{module} is not registered for pickle by value")
else:
_PICKLE_BY_VALUE_MODULES.remove(module.__name__)
def list_registry_pickle_by_value():
return _PICKLE_BY_VALUE_MODULES.copy()
def _is_registered_pickle_by_value(module):
module_name = module.__name__
if module_name in _PICKLE_BY_VALUE_MODULES:
return True
while True:
parent_name = module_name.rsplit(".", 1)[0]
if parent_name == module_name:
break
if parent_name in _PICKLE_BY_VALUE_MODULES:
return True
module_name = parent_name
return False
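# Note that registration is effectively transitive over the package
# hierarchy: after ``register_pickle_by_value(mypkg)`` (hypothetical
# package), the parent-name walk above makes this predicate return True
# for ``mypkg.sub.module`` as well.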
def _whichmodule(obj, name):
"""Find the module an object belongs to.
This function differs from ``pickle.whichmodule`` in two ways:
- it does not mangle the cases where obj's module is __main__ and obj was
not found in any module.
- Errors arising during module introspection are ignored, as those errors
are considered unwanted side effects.
"""
if sys.version_info[:2] < (3, 7) and isinstance(obj, typing.TypeVar): # pragma: no branch # noqa
# Workaround bug in old Python versions: prior to Python 3.7,
# T.__module__ would always be set to "typing" even when the TypeVar T
# would be defined in a different module.
if name is not None and getattr(typing, name, None) is obj:
# Built-in TypeVar defined in typing such as AnyStr
return 'typing'
else:
# User defined or third-party TypeVar: __module__ attribute is
            # irrelevant, thus trigger an exhaustive search for obj in all
# modules.
module_name = None
else:
module_name = getattr(obj, '__module__', None)
if module_name is not None:
return module_name
# Protect the iteration by using a copy of sys.modules against dynamic
# modules that trigger imports of other modules upon calls to getattr or
# other threads importing at the same time.
for module_name, module in sys.modules.copy().items():
# Some modules such as coverage can inject non-module objects inside
# sys.modules
if (
module_name == '__main__' or
module is None or
not isinstance(module, types.ModuleType)
):
continue
try:
if _getattribute(module, name)[0] is obj:
return module_name
except Exception:
pass
return None
def _should_pickle_by_reference(obj, name=None):
"""Test whether an function or a class should be pickled by reference
Pickling by reference means by that the object (typically a function or a
class) is an attribute of a module that is assumed to be importable in the
target Python environment. Loading will therefore rely on importing the
module and then calling `getattr` on it to access the function or class.
Pickling by reference is the only option to pickle functions and classes
in the standard library. In cloudpickle the alternative option is to
pickle by value (for instance for interactively or locally defined
functions and classes or for attributes of modules that have been
explicitly registered to be pickled by value.
"""
if isinstance(obj, types.FunctionType) or issubclass(type(obj), type):
module_and_name = _lookup_module_and_qualname(obj, name=name)
if module_and_name is None:
return False
module, name = module_and_name
return not _is_registered_pickle_by_value(module)
elif isinstance(obj, types.ModuleType):
# We assume that sys.modules is primarily used as a cache mechanism for
        # the Python import machinery. Checking if a module has been added to
        # sys.modules is therefore a cheap and simple heuristic to tell us
# whether we can assume that a given module could be imported by name
# in another Python process.
if _is_registered_pickle_by_value(obj):
return False
return obj.__name__ in sys.modules
else:
raise TypeError(
"cannot check importability of {} instances".format(
type(obj).__name__)
)
def _lookup_module_and_qualname(obj, name=None):
if name is None:
name = getattr(obj, '__qualname__', None)
if name is None: # pragma: no cover
# This used to be needed for Python 2.7 support but is probably not
# needed anymore. However we keep the __name__ introspection in case
# users of cloudpickle rely on this old behavior for unknown reasons.
name = getattr(obj, '__name__', None)
module_name = _whichmodule(obj, name)
if module_name is None:
# In this case, obj.__module__ is None AND obj was not found in any
# imported module. obj is thus treated as dynamic.
return None
if module_name == "__main__":
return None
# Note: if module_name is in sys.modules, the corresponding module is
# assumed importable at unpickling time. See #357
module = sys.modules.get(module_name, None)
if module is None:
# The main reason why obj's module would not be imported is that this
# module has been dynamically created, using for example
# types.ModuleType. The other possibility is that module was removed
# from sys.modules after obj was created/imported. But this case is not
# supported, as the standard pickle does not support it either.
return None
try:
obj2, parent = _getattribute(module, name)
except AttributeError:
# obj was not found inside the module it points to
return None
if obj2 is not obj:
return None
return module, name
def _extract_code_globals(co):
"""
    Find all global names read or written to by code block co
"""
out_names = _extract_code_globals_cache.get(co)
if out_names is None:
        # We use a dict with None values instead of a set to get a
        # deterministic order (assuming Python 3.6+) and avoid introducing
        # non-deterministic pickle bytes as a result.
        out_names = {name: None for name in _walk_global_ops(co)}
        # Declaring a function inside another one using the "def ..."
        # syntax generates a constant code object corresponding to the one
        # of the nested function. As the nested function may itself need
        # global variables, we need to introspect its code, extract its
        # globals (look for code objects in its co_consts attribute) and
        # add the result to out_names.
if co.co_consts:
for const in co.co_consts:
if isinstance(const, types.CodeType):
out_names.update(_extract_code_globals(const))
_extract_code_globals_cache[co] = out_names
return out_names
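# Illustrative sketch of what gets extracted (global names are hypothetical):
#
#     def outer():
#         def inner():
#             return SOME_GLOBAL      # read inside the nested code object
#         return OTHER_GLOBAL
#
# _extract_code_globals(outer.__code__) contains both OTHER_GLOBAL and
# SOME_GLOBAL, because code objects found in co_consts are visited too.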
def _find_imported_submodules(code, top_level_dependencies):
"""
Find currently imported submodules used by a function.
Submodules used by a function need to be detected and referenced for the
function to work correctly at depickling time. Because submodules can be
referenced as attribute of their parent package (``package.submodule``), we
need a special introspection technique that does not rely on GLOBAL-related
opcodes to find references of them in a code object.
Example:
```
import concurrent.futures
import cloudpickle
def func():
x = concurrent.futures.ThreadPoolExecutor
if __name__ == '__main__':
cloudpickle.dumps(func)
```
The globals extracted by cloudpickle in the function's state include the
concurrent package, but not its submodule (here, concurrent.futures), which
    is the module used by func. _find_imported_submodules will detect the
    usage of concurrent.futures. Saving this module alongside func will
    ensure that calling func once depickled does not fail due to
    concurrent.futures not being imported.
"""
subimports = []
# check if any known dependency is an imported package
for x in top_level_dependencies:
if (isinstance(x, types.ModuleType) and
hasattr(x, '__package__') and x.__package__):
# check if the package has any currently loaded sub-imports
prefix = x.__name__ + '.'
# A concurrent thread could mutate sys.modules,
# make sure we iterate over a copy to avoid exceptions
for name in list(sys.modules):
# Older versions of pytest will add a "None" module to
# sys.modules.
if name is not None and name.startswith(prefix):
# check whether the function can address the sub-module
tokens = set(name[len(prefix):].split('.'))
if not tokens - set(code.co_names):
subimports.append(sys.modules[name])
return subimports
def cell_set(cell, value):
"""Set the value of a closure cell.
The point of this function is to set the cell_contents attribute of a cell
after its creation. This operation is necessary in case the cell contains a
reference to the function the cell belongs to, as when calling the
function's constructor
``f = types.FunctionType(code, globals, name, argdefs, closure)``,
closure will not be able to contain the yet-to-be-created f.
In Python3.7, cell_contents is writeable, so setting the contents of a cell
can be done simply using
>>> cell.cell_contents = value
In earlier Python3 versions, the cell_contents attribute of a cell is read
only, but this limitation can be worked around by leveraging the Python 3
``nonlocal`` keyword.
In Python2 however, this attribute is read only, and there is no
``nonlocal`` keyword. For this reason, we need to come up with more
complicated hacks to set this attribute.
The chosen approach is to create a function with a STORE_DEREF opcode,
which sets the content of a closure variable. Typically:
>>> def inner(value):
... lambda: cell # the lambda makes cell a closure
... cell = value # cell is a closure, so this triggers a STORE_DEREF
    (Note that in Python2, a STORE_DEREF can never be triggered from an inner
    function. For example, the function g here
>>> def f(var):
... def g():
... var += 1
... return g
    will not modify the closure variable ``var`` in place, but instead try to
    load a local variable var and increment it. As g does not assign the local
    variable ``var`` any initial value, calling f(1)() will fail at runtime.)
    Our objective is to set the value of a given cell ``cell``. So we need to
    somehow reference our ``cell`` object in the ``inner`` function so that
this object (and not the smoke cell of the lambda function) gets affected
by the STORE_DEREF operation.
In inner, ``cell`` is referenced as a cell variable (an enclosing variable
that is referenced by the inner function). If we create a new function
cell_set with the exact same code as ``inner``, but with ``cell`` marked as
a free variable instead, the STORE_DEREF will be applied on its closure -
``cell``, which we can specify explicitly during construction! The new
cell_set variable thus actually sets the contents of a specified cell!
Note: we do not make use of the ``nonlocal`` keyword to set the contents of
a cell in early python3 versions to limit possible syntax errors in case
test and checker libraries decide to parse the whole file.
"""
if sys.version_info[:2] >= (3, 7): # pragma: no branch
cell.cell_contents = value
else:
_cell_set = types.FunctionType(
_cell_set_template_code, {}, '_cell_set', (), (cell,),)
_cell_set(value)
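# Minimal usage sketch (``_make_empty_cell`` is defined later in this
# module):
#
#     c = _make_empty_cell()
#     cell_set(c, 42)
#     assert c.cell_contents == 42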
def _make_cell_set_template_code():
def _cell_set_factory(value):
lambda: cell
cell = value
co = _cell_set_factory.__code__
_cell_set_template_code = types.CodeType(
co.co_argcount,
co.co_kwonlyargcount, # Python 3 only argument
co.co_nlocals,
co.co_stacksize,
co.co_flags,
co.co_code,
co.co_consts,
co.co_names,
co.co_varnames,
co.co_filename,
co.co_name,
co.co_firstlineno,
co.co_lnotab,
co.co_cellvars, # co_freevars is initialized with co_cellvars
(), # co_cellvars is made empty
)
return _cell_set_template_code
if sys.version_info[:2] < (3, 7):
_cell_set_template_code = _make_cell_set_template_code()
# relevant opcodes
STORE_GLOBAL = opcode.opmap['STORE_GLOBAL']
DELETE_GLOBAL = opcode.opmap['DELETE_GLOBAL']
LOAD_GLOBAL = opcode.opmap['LOAD_GLOBAL']
GLOBAL_OPS = (STORE_GLOBAL, DELETE_GLOBAL, LOAD_GLOBAL)
HAVE_ARGUMENT = dis.HAVE_ARGUMENT
EXTENDED_ARG = dis.EXTENDED_ARG
_BUILTIN_TYPE_NAMES = {}
for k, v in types.__dict__.items():
if type(v) is type:
_BUILTIN_TYPE_NAMES[v] = k
def _builtin_type(name):
if name == "ClassType": # pragma: no cover
# Backward compat to load pickle files generated with cloudpickle
# < 1.3 even if loading pickle files from older versions is not
# officially supported.
return type
return getattr(types, name)
def _walk_global_ops(code):
"""
    Yield the referenced name for each global-referencing instruction in *code*.
"""
for instr in dis.get_instructions(code):
op = instr.opcode
if op in GLOBAL_OPS:
yield instr.argval
def _extract_class_dict(cls):
"""Retrieve a copy of the dict of a class without the inherited methods"""
clsdict = dict(cls.__dict__) # copy dict proxy to a dict
if len(cls.__bases__) == 1:
inherited_dict = cls.__bases__[0].__dict__
else:
inherited_dict = {}
for base in reversed(cls.__bases__):
inherited_dict.update(base.__dict__)
to_remove = []
for name, value in clsdict.items():
try:
base_value = inherited_dict[name]
if value is base_value:
to_remove.append(name)
except KeyError:
pass
for name in to_remove:
clsdict.pop(name)
return clsdict
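# Illustrative sketch: with
#
#     class Base:
#         def method(self):
#             pass
#
#     class Derived(Base):
#         pass
#
# _extract_class_dict(Derived) omits 'method', since Derived inherits it
# unchanged from Base.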
if sys.version_info[:2] < (3, 7): # pragma: no branch
def _is_parametrized_type_hint(obj):
        # This is very cheap but might generate false positives. So try to
        # narrow it down as much as possible.
type_module = getattr(type(obj), '__module__', None)
from_typing_extensions = type_module == 'typing_extensions'
from_typing = type_module == 'typing'
# general typing Constructs
is_typing = getattr(obj, '__origin__', None) is not None
# typing_extensions.Literal
is_literal = (
(getattr(obj, '__values__', None) is not None)
and from_typing_extensions
)
# typing_extensions.Final
is_final = (
(getattr(obj, '__type__', None) is not None)
and from_typing_extensions
)
# typing.ClassVar
is_classvar = (
(getattr(obj, '__type__', None) is not None) and from_typing
)
# typing.Union/Tuple for old Python 3.5
is_union = getattr(obj, '__union_params__', None) is not None
is_tuple = getattr(obj, '__tuple_params__', None) is not None
is_callable = (
getattr(obj, '__result__', None) is not None and
getattr(obj, '__args__', None) is not None
)
return any((is_typing, is_literal, is_final, is_classvar, is_union,
is_tuple, is_callable))
def _create_parametrized_type_hint(origin, args):
return origin[args]
else:
_is_parametrized_type_hint = None
_create_parametrized_type_hint = None
def parametrized_type_hint_getinitargs(obj):
    # The distorted type check semantic for typing constructs becomes:
    # ``type(obj) is type(TypeHint)``, which means "obj is a
    # parametrized TypeHint"
if type(obj) is type(Literal): # pragma: no branch
initargs = (Literal, obj.__values__)
elif type(obj) is type(Final): # pragma: no branch
initargs = (Final, obj.__type__)
elif type(obj) is type(ClassVar):
initargs = (ClassVar, obj.__type__)
elif type(obj) is type(Generic):
initargs = (obj.__origin__, obj.__args__)
elif type(obj) is type(Union):
initargs = (Union, obj.__args__)
elif type(obj) is type(Tuple):
initargs = (Tuple, obj.__args__)
elif type(obj) is type(Callable):
(*args, result) = obj.__args__
if len(args) == 1 and args[0] is Ellipsis:
args = Ellipsis
else:
args = list(args)
initargs = (Callable, (args, result))
else: # pragma: no cover
raise pickle.PicklingError(
f"Cloudpickle Error: Unknown type {type(obj)}"
)
return initargs
# Tornado support
def is_tornado_coroutine(func):
"""
Return whether *func* is a Tornado coroutine function.
Running coroutines are not supported.
"""
if 'tornado.gen' not in sys.modules:
return False
gen = sys.modules['tornado.gen']
if not hasattr(gen, "is_coroutine_function"):
# Tornado version is too old
return False
return gen.is_coroutine_function(func)
def _rebuild_tornado_coroutine(func):
from tornado import gen
return gen.coroutine(func)
# including pickles unloading functions in this namespace
load = pickle.load
loads = pickle.loads
def subimport(name):
    # We cannot simply do `return __import__(name)`: indeed, if ``name`` is
    # the name of a submodule, __import__ will return the top-level root module
    # of this submodule. For instance, __import__('os.path') returns the `os`
    # module.
__import__(name)
return sys.modules[name]
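# Illustrative contrast with plain __import__ (standard-library example):
#
#     __import__('os.path')    # returns the top-level `os` module
#     subimport('os.path')     # returns the `os.path` submodule itself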
def dynamic_subimport(name, vars):
mod = types.ModuleType(name)
mod.__dict__.update(vars)
mod.__dict__['__builtins__'] = builtins.__dict__
return mod
def _gen_ellipsis():
return Ellipsis
def _gen_not_implemented():
return NotImplemented
def _get_cell_contents(cell):
try:
return cell.cell_contents
except ValueError:
# sentinel used by ``_fill_function`` which will leave the cell empty
return _empty_cell_value
def instance(cls):
"""Create a new instance of a class.
Parameters
----------
cls : type
The class to create an instance of.
Returns
-------
instance : cls
A new instance of ``cls``.
"""
return cls()
@instance
class _empty_cell_value:
"""sentinel for empty closures
"""
@classmethod
def __reduce__(cls):
return cls.__name__
def _fill_function(*args):
"""Fills in the rest of function data into the skeleton function object
    The skeleton itself is created by _make_skel_func().
"""
if len(args) == 2:
func = args[0]
state = args[1]
elif len(args) == 5:
# Backwards compat for cloudpickle v0.4.0, after which the `module`
# argument was introduced
func = args[0]
keys = ['globals', 'defaults', 'dict', 'closure_values']
state = dict(zip(keys, args[1:]))
elif len(args) == 6:
# Backwards compat for cloudpickle v0.4.1, after which the function
        # state was passed as a dict to _fill_function itself.
func = args[0]
keys = ['globals', 'defaults', 'dict', 'module', 'closure_values']
state = dict(zip(keys, args[1:]))
else:
        raise ValueError(f'Unexpected _fill_function arguments: {args!r}')
# - At pickling time, any dynamic global variable used by func is
# serialized by value (in state['globals']).
# - At unpickling time, func's __globals__ attribute is initialized by
# first retrieving an empty isolated namespace that will be shared
# with other functions pickled from the same original module
# by the same CloudPickler instance and then updated with the
# content of state['globals'] to populate the shared isolated
# namespace with all the global variables that are specifically
# referenced for this function.
func.__globals__.update(state['globals'])
func.__defaults__ = state['defaults']
func.__dict__ = state['dict']
if 'annotations' in state:
func.__annotations__ = state['annotations']
if 'doc' in state:
func.__doc__ = state['doc']
if 'name' in state:
func.__name__ = state['name']
if 'module' in state:
func.__module__ = state['module']
if 'qualname' in state:
func.__qualname__ = state['qualname']
if 'kwdefaults' in state:
func.__kwdefaults__ = state['kwdefaults']
    # _cloudpickle_submodules is a set of submodules that must be loaded for
# the pickled function to work correctly at unpickling time. Now that these
# submodules are depickled (hence imported), they can be removed from the
# object's state (the object state only served as a reference holder to
# these submodules)
if '_cloudpickle_submodules' in state:
state.pop('_cloudpickle_submodules')
cells = func.__closure__
if cells is not None:
for cell, value in zip(cells, state['closure_values']):
if value is not _empty_cell_value:
cell_set(cell, value)
return func
def _make_function(code, globals, name, argdefs, closure):
# Setting __builtins__ in globals is needed for nogil CPython.
globals["__builtins__"] = __builtins__
return types.FunctionType(code, globals, name, argdefs, closure)
def _make_empty_cell():
if False:
# trick the compiler into creating an empty cell in our lambda
cell = None
raise AssertionError('this route should not be executed')
return (lambda: cell).__closure__[0]
def _make_cell(value=_empty_cell_value):
cell = _make_empty_cell()
if value is not _empty_cell_value:
cell_set(cell, value)
return cell
def _make_skel_func(code, cell_count, base_globals=None):
""" Creates a skeleton function object that contains just the provided
code and the correct number of cells in func_closure. All other
func attributes (e.g. func_globals) are empty.
"""
# This function is deprecated and should be removed in cloudpickle 1.7
warnings.warn(
"A pickle file created using an old (<=1.4.1) version of cloudpickle "
"is currently being loaded. This is not supported by cloudpickle and "
"will break in cloudpickle 1.7", category=UserWarning
)
# This is backward-compatibility code: for cloudpickle versions between
# 0.5.4 and 0.7, base_globals could be a string or None. base_globals
# should now always be a dictionary.
if base_globals is None or isinstance(base_globals, str):
base_globals = {}
base_globals['__builtins__'] = __builtins__
closure = (
tuple(_make_empty_cell() for _ in range(cell_count))
if cell_count >= 0 else
None
)
return types.FunctionType(code, base_globals, None, None, closure)
def _make_skeleton_class(type_constructor, name, bases, type_kwargs,
class_tracker_id, extra):
"""Build dynamic class with an empty __dict__ to be filled once memoized
    If class_tracker_id is not None, try to look up an existing class definition
matching that id. If none is found, track a newly reconstructed class
definition under that id so that other instances stemming from the same
class id will also reuse this class definition.
The "extra" variable is meant to be a dict (or None) that can be used for
    forward compatibility should the need arise.
"""
skeleton_class = types.new_class(
name, bases, {'metaclass': type_constructor},
lambda ns: ns.update(type_kwargs)
)
return _lookup_class_or_track(class_tracker_id, skeleton_class)
def _rehydrate_skeleton_class(skeleton_class, class_dict):
"""Put attributes from `class_dict` back on `skeleton_class`.
See CloudPickler.save_dynamic_class for more info.
"""
registry = None
for attrname, attr in class_dict.items():
if attrname == "_abc_impl":
registry = attr
else:
setattr(skeleton_class, attrname, attr)
if registry is not None:
for subclass in registry:
skeleton_class.register(subclass)
return skeleton_class
def _make_skeleton_enum(bases, name, qualname, members, module,
class_tracker_id, extra):
"""Build dynamic enum with an empty __dict__ to be filled once memoized
The creation of the enum class is inspired by the code of
EnumMeta._create_.
    If class_tracker_id is not None, try to look up an existing enum definition
matching that id. If none is found, track a newly reconstructed enum
definition under that id so that other instances stemming from the same
class id will also reuse this enum definition.
The "extra" variable is meant to be a dict (or None) that can be used for
    forward compatibility should the need arise.
"""
# enums always inherit from their base Enum class at the last position in
# the list of base classes:
enum_base = bases[-1]
metacls = enum_base.__class__
classdict = metacls.__prepare__(name, bases)
for member_name, member_value in members.items():
classdict[member_name] = member_value
enum_class = metacls.__new__(metacls, name, bases, classdict)
enum_class.__module__ = module
enum_class.__qualname__ = qualname
return _lookup_class_or_track(class_tracker_id, enum_class)
def _make_typevar(name, bound, constraints, covariant, contravariant,
class_tracker_id):
tv = typing.TypeVar(
name, *constraints, bound=bound,
covariant=covariant, contravariant=contravariant
)
if class_tracker_id is not None:
return _lookup_class_or_track(class_tracker_id, tv)
else: # pragma: nocover
# Only for Python 3.5.3 compat.
return tv
def _decompose_typevar(obj):
return (
obj.__name__, obj.__bound__, obj.__constraints__,
obj.__covariant__, obj.__contravariant__,
_get_or_create_tracker_id(obj),
)
def _typevar_reduce(obj):
    # TypeVar instances require the module information, hence we do not use
    # _should_pickle_by_reference directly
module_and_name = _lookup_module_and_qualname(obj, name=obj.__name__)
if module_and_name is None:
return (_make_typevar, _decompose_typevar(obj))
elif _is_registered_pickle_by_value(module_and_name[0]):
return (_make_typevar, _decompose_typevar(obj))
return (getattr, module_and_name)
def _get_bases(typ):
if '__orig_bases__' in getattr(typ, '__dict__', {}):
# For generic types (see PEP 560)
# Note that simply checking `hasattr(typ, '__orig_bases__')` is not
        # correct. Subclasses of a fully-parameterized generic class do not
        # have `__orig_bases__` defined, but `hasattr(typ, '__orig_bases__')`
        # will return True because it is defined in the base class.
bases_attr = '__orig_bases__'
else:
# For regular class objects
bases_attr = '__bases__'
return getattr(typ, bases_attr)
def _make_dict_keys(obj, is_ordered=False):
if is_ordered:
return OrderedDict.fromkeys(obj).keys()
else:
return dict.fromkeys(obj).keys()
def _make_dict_values(obj, is_ordered=False):
if is_ordered:
return OrderedDict((i, _) for i, _ in enumerate(obj)).values()
else:
return {i: _ for i, _ in enumerate(obj)}.values()
def _make_dict_items(obj, is_ordered=False):
if is_ordered:
return OrderedDict(obj).items()
else:
return obj.items()
"""
New, fast version of the CloudPickler.
This new CloudPickler class can now extend the fast C Pickler instead of the
previous Python implementation of the Pickler class. Because this functionality
is only available for Python versions 3.8+, a lot of backward-compatibility
code is also removed.
Note that the C Pickler subclassing API is CPython-specific. Therefore, some
guards present in cloudpickle.py that were written to handle PyPy specificities
are not present in cloudpickle_fast.py
"""
import _collections_abc
import abc
import copyreg
import io
import itertools
import logging
import sys
import struct
import types
import weakref
import typing
from enum import Enum
from collections import ChainMap, OrderedDict
from .compat import pickle, Pickler
from .cloudpickle import (
_extract_code_globals, _BUILTIN_TYPE_NAMES, DEFAULT_PROTOCOL,
_find_imported_submodules, _get_cell_contents, _should_pickle_by_reference,
_builtin_type, _get_or_create_tracker_id, _make_skeleton_class,
_make_skeleton_enum, _extract_class_dict, dynamic_subimport, subimport,
_typevar_reduce, _get_bases, _make_cell, _make_empty_cell, CellType,
_is_parametrized_type_hint, PYPY, cell_set,
parametrized_type_hint_getinitargs, _create_parametrized_type_hint,
builtin_code_type,
_make_dict_keys, _make_dict_values, _make_dict_items, _make_function,
_DYNAMIC_CLASS_TRACKER_REUSING
)
if pickle.HIGHEST_PROTOCOL >= 5:
# Shorthands similar to pickle.dump/pickle.dumps
def dump(obj, file, protocol=None, buffer_callback=None):
"""Serialize obj as bytes streamed into file
protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
pickle.HIGHEST_PROTOCOL. This setting favors maximum communication
speed between processes running the same Python version.
Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure
compatibility with older versions of Python.
"""
CloudPickler(
file, protocol=protocol, buffer_callback=buffer_callback
).dump(obj)
def dumps(obj, protocol=None, buffer_callback=None):
"""Serialize obj as a string of bytes allocated in memory
protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
pickle.HIGHEST_PROTOCOL. This setting favors maximum communication
speed between processes running the same Python version.
Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure
compatibility with older versions of Python.
"""
with io.BytesIO() as file:
cp = CloudPickler(
file, protocol=protocol, buffer_callback=buffer_callback
)
cp.dump(obj)
return file.getvalue()
else:
# Shorthands similar to pickle.dump/pickle.dumps
def dump(obj, file, protocol=None):
"""Serialize obj as bytes streamed into file
protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
pickle.HIGHEST_PROTOCOL. This setting favors maximum communication
speed between processes running the same Python version.
Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure
compatibility with older versions of Python.
"""
CloudPickler(file, protocol=protocol).dump(obj)
def dumps(obj, protocol=None):
"""Serialize obj as a string of bytes allocated in memory
protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
pickle.HIGHEST_PROTOCOL. This setting favors maximum communication
speed between processes running the same Python version.
Set protocol=pickle.DEFAULT_PROTOCOL instead if you need to ensure
compatibility with older versions of Python.
"""
with io.BytesIO() as file:
cp = CloudPickler(file, protocol=protocol)
cp.dump(obj)
return file.getvalue()
load, loads = pickle.load, pickle.loads
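# Typical round trip (illustrative): both branches above expose the same
# dump/dumps API, so e.g.
#
#     payload = dumps(lambda x: x + 1)
#     assert loads(payload)(1) == 2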
# COLLECTION OF OBJECTS __getnewargs__-LIKE METHODS
# -------------------------------------------------
def _class_getnewargs(obj):
type_kwargs = {}
if "__slots__" in obj.__dict__:
type_kwargs["__slots__"] = obj.__slots__
__dict__ = obj.__dict__.get('__dict__', None)
if isinstance(__dict__, property):
type_kwargs['__dict__'] = __dict__
return (type(obj), obj.__name__, _get_bases(obj), type_kwargs,
_get_or_create_tracker_id(obj), None)
def _enum_getnewargs(obj):
members = {e.name: e.value for e in obj}
return (obj.__bases__, obj.__name__, obj.__qualname__, members,
obj.__module__, _get_or_create_tracker_id(obj), None)
# COLLECTION OF OBJECTS RECONSTRUCTORS
# ------------------------------------
def _file_reconstructor(retval):
return retval
# COLLECTION OF OBJECTS STATE GETTERS
# -----------------------------------
def _function_getstate(func):
# - Put func's dynamic attributes (stored in func.__dict__) in state. These
# attributes will be restored at unpickling time using
# f.__dict__.update(state)
# - Put func's members into slotstate. Such attributes will be restored at
# unpickling time by iterating over slotstate and calling setattr(func,
# slotname, slotvalue)
slotstate = {
"__name__": func.__name__,
"__qualname__": func.__qualname__,
"__annotations__": func.__annotations__,
"__kwdefaults__": func.__kwdefaults__,
"__defaults__": func.__defaults__,
"__module__": func.__module__,
"__doc__": func.__doc__,
"__closure__": func.__closure__,
}
f_globals_ref = _extract_code_globals(func.__code__)
f_globals = {k: func.__globals__[k] for k in f_globals_ref if k in
func.__globals__}
closure_values = (
list(map(_get_cell_contents, func.__closure__))
if func.__closure__ is not None else ()
)
    # Extract currently-imported submodules used by func. Storing these modules
    # in a smoke _cloudpickle_submodules attribute of the object's state will
    # trigger the side effect of importing these modules at unpickling time
    # (which is necessary for func to work correctly once depickled)
slotstate["_cloudpickle_submodules"] = _find_imported_submodules(
func.__code__, itertools.chain(f_globals.values(), closure_values))
slotstate["__globals__"] = f_globals
state = func.__dict__
return state, slotstate
def _class_getstate(obj):
clsdict = _extract_class_dict(obj)
clsdict.pop('__weakref__', None)
if issubclass(type(obj), abc.ABCMeta):
# If obj is an instance of an ABCMeta subclass, don't pickle the
# cache/negative caches populated during isinstance/issubclass
# checks, but pickle the list of registered subclasses of obj.
clsdict.pop('_abc_cache', None)
clsdict.pop('_abc_negative_cache', None)
clsdict.pop('_abc_negative_cache_version', None)
registry = clsdict.pop('_abc_registry', None)
if registry is None:
# in Python3.7+, the abc caches and registered subclasses of a
# class are bundled into the single _abc_impl attribute
clsdict.pop('_abc_impl', None)
(registry, _, _, _) = abc._get_dump(obj)
clsdict["_abc_impl"] = [subclass_weakref()
for subclass_weakref in registry]
else:
# In the above if clause, registry is a set of weakrefs -- in
# this case, registry is a WeakSet
clsdict["_abc_impl"] = [type_ for type_ in registry]
if "__slots__" in clsdict:
# pickle string length optimization: member descriptors of obj are
# created automatically from obj's __slots__ attribute, no need to
# save them in obj's state
if isinstance(obj.__slots__, str):
clsdict.pop(obj.__slots__)
else:
for k in obj.__slots__:
clsdict.pop(k, None)
clsdict.pop('__dict__', None) # unpicklable property object
return (clsdict, {})
def _enum_getstate(obj):
clsdict, slotstate = _class_getstate(obj)
members = {e.name: e.value for e in obj}
# Cleanup the clsdict that will be passed to _rehydrate_skeleton_class:
# Those attributes are already handled by the metaclass.
for attrname in ["_generate_next_value_", "_member_names_",
"_member_map_", "_member_type_",
"_value2member_map_"]:
clsdict.pop(attrname, None)
for member in members:
clsdict.pop(member)
# Special handling of Enum subclasses
return clsdict, slotstate
# COLLECTIONS OF OBJECTS REDUCERS
# -------------------------------
# A reducer is a function that takes a single argument (obj) and returns a
# tuple with all the data necessary to re-construct obj. Apart from a few
# exceptions (list, dict, bytes, int, etc.), a reducer is necessary to
# correctly pickle an object.
# While many built-in objects (Exception objects, instances of the "object"
# class, etc.) are shipped with their own built-in reducer (invoked via
# obj.__reduce__), some are not. The following methods were created to fill
# these holes.
def _code_reduce(obj):
"""codeobject reducer"""
# If you are not sure about the order of arguments, take a look at help
# of the specific type from types, for example:
# >>> from types import CodeType
# >>> help(CodeType)
if hasattr(obj, "co_exceptiontable"): # pragma: no branch
# Python 3.11 and later: there are some new attributes
# related to the enhanced exceptions.
args = (
obj.co_argcount, obj.co_posonlyargcount,
obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize,
obj.co_flags, obj.co_code, obj.co_consts, obj.co_names,
obj.co_varnames, obj.co_filename, obj.co_name, obj.co_qualname,
obj.co_firstlineno, obj.co_linetable, obj.co_exceptiontable,
obj.co_freevars, obj.co_cellvars,
)
elif hasattr(obj, "co_linetable"): # pragma: no branch
# Python 3.10 and later: obj.co_lnotab is deprecated and constructor
# expects obj.co_linetable instead.
args = (
obj.co_argcount, obj.co_posonlyargcount,
obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize,
obj.co_flags, obj.co_code, obj.co_consts, obj.co_names,
obj.co_varnames, obj.co_filename, obj.co_name,
obj.co_firstlineno, obj.co_linetable, obj.co_freevars,
obj.co_cellvars
)
elif hasattr(obj, "co_nmeta"): # pragma: no cover
# "nogil" Python: modified attributes from 3.9
args = (
obj.co_argcount, obj.co_posonlyargcount,
obj.co_kwonlyargcount, obj.co_nlocals, obj.co_framesize,
obj.co_ndefaultargs, obj.co_nmeta,
obj.co_flags, obj.co_code, obj.co_consts,
obj.co_varnames, obj.co_filename, obj.co_name,
obj.co_firstlineno, obj.co_lnotab, obj.co_exc_handlers,
obj.co_jump_table, obj.co_freevars, obj.co_cellvars,
obj.co_free2reg, obj.co_cell2reg
)
elif hasattr(obj, "co_posonlyargcount"):
# Backward compat for 3.9 and older
args = (
obj.co_argcount, obj.co_posonlyargcount,
obj.co_kwonlyargcount, obj.co_nlocals, obj.co_stacksize,
obj.co_flags, obj.co_code, obj.co_consts, obj.co_names,
obj.co_varnames, obj.co_filename, obj.co_name,
obj.co_firstlineno, obj.co_lnotab, obj.co_freevars,
obj.co_cellvars
)
else:
# Backward compat for even older versions of Python
args = (
obj.co_argcount, obj.co_kwonlyargcount, obj.co_nlocals,
obj.co_stacksize, obj.co_flags, obj.co_code, obj.co_consts,
obj.co_names, obj.co_varnames, obj.co_filename,
obj.co_name, obj.co_firstlineno, obj.co_lnotab,
obj.co_freevars, obj.co_cellvars
)
return types.CodeType, args
def _cell_reduce(obj):
"""Cell (containing values of a function's free variables) reducer"""
try:
obj.cell_contents
except ValueError: # cell is empty
return _make_empty_cell, ()
else:
return _make_cell, (obj.cell_contents, )
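# Editor's illustration (not part of the original module): _cell_reduce
# distinguishes empty from filled cells by probing cell_contents, which
# raises ValueError for a cell that was never assigned.
def _example_cell_probe():
    def outer():
        v = 42
        def inner():
            return v
        return inner
    cell = outer().__closure__[0]
    assert cell.cell_contents == 42     # a filled cell
    reconstructor, args = _cell_reduce(cell)
    assert args == (42,)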
def _classmethod_reduce(obj):
orig_func = obj.__func__
return type(obj), (orig_func,)
def _file_reduce(obj):
"""Save a file"""
import io
if not hasattr(obj, "name") or not hasattr(obj, "mode"):
raise pickle.PicklingError(
"Cannot pickle files that do not map to an actual file"
)
if obj is sys.stdout:
return getattr, (sys, "stdout")
if obj is sys.stderr:
return getattr, (sys, "stderr")
if obj is sys.stdin:
raise pickle.PicklingError("Cannot pickle standard input")
if obj.closed:
raise pickle.PicklingError("Cannot pickle closed files")
if hasattr(obj, "isatty") and obj.isatty():
raise pickle.PicklingError(
"Cannot pickle files that map to tty objects"
)
if "r" not in obj.mode and "+" not in obj.mode:
raise pickle.PicklingError(
"Cannot pickle files that are not opened for reading: %s"
% obj.mode
)
name = obj.name
retval = io.StringIO()
try:
# Read the whole file
curloc = obj.tell()
obj.seek(0)
contents = obj.read()
obj.seek(curloc)
except IOError as e:
raise pickle.PicklingError(
"Cannot pickle file %s as it cannot be read" % name
) from e
retval.write(contents)
retval.seek(curloc)
retval.name = name
return _file_reconstructor, (retval,)
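# Editor's illustration (not part of the original module; uses only the
# standard library): _file_reduce snapshots a readable file's contents and
# position into an io.StringIO that _file_reconstructor returns verbatim.
def _example_file_reduce():
    import os
    import tempfile
    fh = tempfile.NamedTemporaryFile("w+", delete=False)
    try:
        fh.write("hello")
        fh.seek(2)
        reconstructor, (buf,) = _file_reduce(fh)
        assert reconstructor is _file_reconstructor
        assert buf.read() == "llo"      # the read position is preserved
        assert buf.name == fh.name
    finally:
        fh.close()
        os.unlink(fh.name)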
def _getset_descriptor_reduce(obj):
return getattr, (obj.__objclass__, obj.__name__)
def _mappingproxy_reduce(obj):
return types.MappingProxyType, (dict(obj),)
def _memoryview_reduce(obj):
return bytes, (obj.tobytes(),)
def _module_reduce(obj):
if _should_pickle_by_reference(obj):
return subimport, (obj.__name__,)
else:
# Some external libraries can populate the "__builtins__" entry of a
# module's `__dict__` with unpicklable objects (see #316). For that
# reason, we do not attempt to pickle the "__builtins__" entry, and
# restore a default value for it at unpickling time.
state = obj.__dict__.copy()
state.pop('__builtins__', None)
return dynamic_subimport, (obj.__name__, state)
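# Editor's illustration (not part of the original module): file-backed
# modules are reduced to a re-import by name, while dynamic modules carry
# their (sanitized) __dict__ along.
def _example_module_reduce():
    import os
    reconstructor, args = _module_reduce(os)
    assert reconstructor is subimport and args == ("os",)
    dyn = types.ModuleType("_editor_dynamic_module")    # hypothetical name
    dyn.answer = 42
    reconstructor, (name, state) = _module_reduce(dyn)
    assert reconstructor is dynamic_subimport and state["answer"] == 42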
def _method_reduce(obj):
return (types.MethodType, (obj.__func__, obj.__self__))
def _logger_reduce(obj):
return logging.getLogger, (obj.name,)
def _root_logger_reduce(obj):
return logging.getLogger, ()
def _property_reduce(obj):
return property, (obj.fget, obj.fset, obj.fdel, obj.__doc__)
def _weakset_reduce(obj):
return weakref.WeakSet, (list(obj),)
def _dynamic_class_reduce(obj):
"""
Save a class that can't be stored as module global.
This method is used to serialize classes that are defined inside
functions, or that otherwise can't be serialized as attribute lookups
from global modules.
"""
if Enum is not None and issubclass(obj, Enum):
return (
_make_skeleton_enum, _enum_getnewargs(obj), _enum_getstate(obj),
None, None, _class_setstate
)
else:
return (
_make_skeleton_class, _class_getnewargs(obj), _class_getstate(obj),
None, None, _class_setstate
)
def _class_reduce(obj):
"""Select the reducer depending on the dynamic nature of the class obj"""
if obj is type(None): # noqa
return type, (None,)
elif obj is type(Ellipsis):
return type, (Ellipsis,)
elif obj is type(NotImplemented):
return type, (NotImplemented,)
elif obj in _BUILTIN_TYPE_NAMES:
return _builtin_type, (_BUILTIN_TYPE_NAMES[obj],)
elif not _should_pickle_by_reference(obj):
return _dynamic_class_reduce(obj)
return NotImplemented
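# Editor's illustration (not part of the original module): _class_reduce
# special-cases singleton types and falls back to the skeleton-class path
# for classes defined at runtime.
def _example_class_reduce():
    assert _class_reduce(type(None)) == (type, (None,))
    Dyn = type("_EditorDyn", (), {})    # dynamic, not importable by name
    assert _class_reduce(Dyn)[0] is _make_skeleton_class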
def _dict_keys_reduce(obj):
# Safer not to ship the full dict as sending the rest might
# be unintended and could potentially cause leaking of
# sensitive information
return _make_dict_keys, (list(obj), )
def _dict_values_reduce(obj):
# Safer not to ship the full dict as sending the rest might
# be unintended and could potentially cause leaking of
# sensitive information
return _make_dict_values, (list(obj), )
def _dict_items_reduce(obj):
return _make_dict_items, (dict(obj), )
def _odict_keys_reduce(obj):
# Safer not to ship the full dict as sending the rest might
# be unintended and could potentially cause leaking of
# sensitive information
return _make_dict_keys, (list(obj), True)
def _odict_values_reduce(obj):
# Safer not to ship the full dict as sending the rest might
# be unintended and could potentially cause leaking of
# sensitive information
return _make_dict_values, (list(obj), True)
def _odict_items_reduce(obj):
return _make_dict_items, (dict(obj), True)
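# Editor's illustration (not part of the original module): only the keys
# (resp. values) are shipped for dict views, never the whole mapping.
def _example_dict_view_reduce():
    reconstructor, args = _dict_keys_reduce({"a": 1, "b": 2}.keys())
    assert reconstructor is _make_dict_keys and args == (["a", "b"],)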
# COLLECTIONS OF OBJECTS STATE SETTERS
# ------------------------------------
# State setters are called at unpickling time, once the object has been
# created; they update the object so that it matches its state at pickling
# time.
def _function_setstate(obj, state):
"""Update the state of a dynamic function.
    As __closure__ and __globals__ are readonly attributes of a function, we
    cannot rely on the native setstate routine of pickle.load_build, which
    calls setattr on items of the slotstate. Instead, we have to modify them
    in place.
"""
state, slotstate = state
obj.__dict__.update(state)
obj_globals = slotstate.pop("__globals__")
obj_closure = slotstate.pop("__closure__")
    # _cloudpickle_submodules is a set of submodules that must be loaded for
    # the pickled function to work correctly at unpickling time. Now that these
    # submodules are depickled (hence imported), they can be removed from the
    # object's state (the object state only served as a reference holder to
    # these submodules)
slotstate.pop("_cloudpickle_submodules")
obj.__globals__.update(obj_globals)
obj.__globals__["__builtins__"] = __builtins__
if obj_closure is not None:
for i, cell in enumerate(obj_closure):
try:
value = cell.cell_contents
except ValueError: # cell is empty
continue
cell_set(obj.__closure__[i], value)
for k, v in slotstate.items():
setattr(obj, k, v)
def _class_setstate(obj, state):
    # Check if the class is being reused and needs to bypass the setstate
    # logic.
if obj in _DYNAMIC_CLASS_TRACKER_REUSING:
return obj
state, slotstate = state
registry = None
for attrname, attr in state.items():
if attrname == "_abc_impl":
registry = attr
else:
setattr(obj, attrname, attr)
if registry is not None:
for subclass in registry:
obj.register(subclass)
return obj
class CloudPickler(Pickler):
# set of reducers defined and used by cloudpickle (private)
_dispatch_table = {}
_dispatch_table[classmethod] = _classmethod_reduce
_dispatch_table[io.TextIOWrapper] = _file_reduce
_dispatch_table[logging.Logger] = _logger_reduce
_dispatch_table[logging.RootLogger] = _root_logger_reduce
_dispatch_table[memoryview] = _memoryview_reduce
_dispatch_table[property] = _property_reduce
_dispatch_table[staticmethod] = _classmethod_reduce
_dispatch_table[CellType] = _cell_reduce
_dispatch_table[types.CodeType] = _code_reduce
_dispatch_table[types.GetSetDescriptorType] = _getset_descriptor_reduce
_dispatch_table[types.ModuleType] = _module_reduce
_dispatch_table[types.MethodType] = _method_reduce
_dispatch_table[types.MappingProxyType] = _mappingproxy_reduce
_dispatch_table[weakref.WeakSet] = _weakset_reduce
_dispatch_table[typing.TypeVar] = _typevar_reduce
_dispatch_table[_collections_abc.dict_keys] = _dict_keys_reduce
_dispatch_table[_collections_abc.dict_values] = _dict_values_reduce
_dispatch_table[_collections_abc.dict_items] = _dict_items_reduce
_dispatch_table[type(OrderedDict().keys())] = _odict_keys_reduce
_dispatch_table[type(OrderedDict().values())] = _odict_values_reduce
_dispatch_table[type(OrderedDict().items())] = _odict_items_reduce
_dispatch_table[abc.abstractmethod] = _classmethod_reduce
_dispatch_table[abc.abstractclassmethod] = _classmethod_reduce
_dispatch_table[abc.abstractstaticmethod] = _classmethod_reduce
_dispatch_table[abc.abstractproperty] = _property_reduce
dispatch_table = ChainMap(_dispatch_table, copyreg.dispatch_table)
# function reducers are defined as instance methods of CloudPickler
# objects, as they rely on a CloudPickler attribute (globals_ref)
def _dynamic_function_reduce(self, func):
"""Reduce a function that is not pickleable via attribute lookup."""
newargs = self._function_getnewargs(func)
state = _function_getstate(func)
return (_make_function, newargs, state, None, None,
_function_setstate)
def _function_reduce(self, obj):
"""Reducer for function objects.
If obj is a top-level attribute of a file-backed module, this
reducer returns NotImplemented, making the CloudPickler fallback to
traditional _pickle.Pickler routines to save obj. Otherwise, it reduces
obj using a custom cloudpickle reducer designed specifically to handle
dynamic functions.
        As opposed to cloudpickle.py, there is no special handling for builtin
        pypy functions because cloudpickle_fast is CPython-specific.
"""
if _should_pickle_by_reference(obj):
return NotImplemented
else:
return self._dynamic_function_reduce(obj)
def _function_getnewargs(self, func):
code = func.__code__
        # base_globals represents the future global namespace of func at
        # unpickling time. Looking it up and storing it in
        # CloudPickler.globals_ref allows functions sharing the same globals
        # at pickling time to also share them once unpickled, on one condition:
        # since globals_ref is an attribute of a CloudPickler instance, and
        # a new CloudPickler is created each time pickle.dump or
        # pickle.dumps is called, functions also need to be saved within the
        # same invocation of cloudpickle.dump/cloudpickle.dumps (for example:
        # cloudpickle.dumps([f1, f2])). There is no such limitation when using
        # CloudPickler.dump, as long as the multiple invocations are bound to
        # the same CloudPickler.
base_globals = self.globals_ref.setdefault(id(func.__globals__), {})
if base_globals == {}:
# Add module attributes used to resolve relative imports
# instructions inside func.
for k in ["__package__", "__name__", "__path__", "__file__"]:
if k in func.__globals__:
base_globals[k] = func.__globals__[k]
# Do not bind the free variables before the function is created to
# avoid infinite recursion.
if func.__closure__ is None:
closure = None
else:
closure = tuple(
_make_empty_cell() for _ in range(len(code.co_freevars)))
return code, base_globals, None, None, closure
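    # Editor's note (illustrative, not in the original source): because
    # base_globals is keyed on id(func.__globals__) in this instance's
    # globals_ref, two functions pickled in a single call, e.g.
    #     cloudpickle.dumps([f1, f2])   # f1 and f2 are hypothetical
    # share one globals dict after unpickling, whereas two separate dumps()
    # calls create two CloudPickler instances and hence two independent
    # namespaces.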
def dump(self, obj):
try:
return Pickler.dump(self, obj)
except RuntimeError as e:
if "recursion" in e.args[0]:
msg = (
"Could not pickle object as excessively deep recursion "
"required."
)
raise pickle.PicklingError(msg) from e
else:
raise
if pickle.HIGHEST_PROTOCOL >= 5:
def __init__(self, file, protocol=None, buffer_callback=None):
if protocol is None:
protocol = DEFAULT_PROTOCOL
Pickler.__init__(
self, file, protocol=protocol, buffer_callback=buffer_callback
)
# map functions __globals__ attribute ids, to ensure that functions
# sharing the same global namespace at pickling time also share
# their global namespace at unpickling time.
self.globals_ref = {}
self.proto = int(protocol)
else:
def __init__(self, file, protocol=None):
if protocol is None:
protocol = DEFAULT_PROTOCOL
Pickler.__init__(self, file, protocol=protocol)
# map functions __globals__ attribute ids, to ensure that functions
# sharing the same global namespace at pickling time also share
# their global namespace at unpickling time.
self.globals_ref = {}
assert hasattr(self, 'proto')
if pickle.HIGHEST_PROTOCOL >= 5 and not PYPY:
        # Pickler is the C implementation of the CPython pickler and therefore
        # we rely on the reducer_override method to customize pickler behavior.
# `CloudPickler.dispatch` is only left for backward compatibility - note
# that when using protocol 5, `CloudPickler.dispatch` is not an
# extension of `Pickler.dispatch` dictionary, because CloudPickler
# subclasses the C-implemented Pickler, which does not expose a
# `dispatch` attribute. Earlier versions of the protocol 5 CloudPickler
# used `CloudPickler.dispatch` as a class-level attribute storing all
# reducers implemented by cloudpickle, but the attribute name was not a
# great choice given the meaning of `CloudPickler.dispatch` when
# `CloudPickler` extends the pure-python pickler.
dispatch = dispatch_table
# Implementation of the reducer_override callback, in order to
# efficiently serialize dynamic functions and classes by subclassing
# the C-implemented Pickler.
        # TODO: decorrelate reducer_override (which is tied to CPython's
        # implementation - would it make sense to backport it to pypy?) from
        # pickle's protocol 5, which is implementation agnostic. Currently, the
        # availability of both notions coincides on CPython's pickle and the
        # pickle5 backport, but that may no longer be the case when pypy
        # implements protocol 5.
def reducer_override(self, obj):
"""Type-agnostic reducing callback for function and classes.
For performance reasons, subclasses of the C _pickle.Pickler class
cannot register custom reducers for functions and classes in the
dispatch_table. Reducer for such types must instead implemented in
the special reducer_override method.
Note that method will be called for any object except a few
builtin-types (int, lists, dicts etc.), which differs from reducers
in the Pickler's dispatch_table, each of them being invoked for
objects of a specific type only.
This property comes in handy for classes: although most classes are
instances of the ``type`` metaclass, some of them can be instances
of other custom metaclasses (such as enum.EnumMeta for example). In
particular, the metaclass will likely not be known in advance, and
thus cannot be special-cased using an entry in the dispatch_table.
reducer_override, among other things, allows us to register a
reducer that will be called for any class, independently of its
type.
Notes:
            * reducer_override has priority over dispatch_table-registered
            reducers.
* reducer_override can be used to fix other limitations of
cloudpickle for other types that suffered from type-specific
reducers, such as Exceptions. See
https://github.com/cloudpipe/cloudpickle/issues/248
"""
if sys.version_info[:2] < (3, 7) and _is_parametrized_type_hint(obj): # noqa # pragma: no branch
return (
_create_parametrized_type_hint,
parametrized_type_hint_getinitargs(obj)
)
t = type(obj)
try:
is_anyclass = issubclass(t, type)
except TypeError: # t is not a class (old Boost; see SF #502085)
is_anyclass = False
if is_anyclass:
return _class_reduce(obj)
elif isinstance(obj, types.FunctionType):
return self._function_reduce(obj)
else:
# fallback to save_global, including the Pickler's
# dispatch_table
return NotImplemented
else:
# When reducer_override is not available, hack the pure-Python
# Pickler's types.FunctionType and type savers. Note: the type saver
# must override Pickler.save_global, because pickle.py contains a
# hard-coded call to save_global when pickling meta-classes.
dispatch = Pickler.dispatch.copy()
def _save_reduce_pickle5(self, func, args, state=None, listitems=None,
dictitems=None, state_setter=None, obj=None):
save = self.save
write = self.write
self.save_reduce(
func, args, state=None, listitems=listitems,
dictitems=dictitems, obj=obj
)
# backport of the Python 3.8 state_setter pickle operations
save(state_setter)
save(obj) # simple BINGET opcode as obj is already memoized.
save(state)
write(pickle.TUPLE2)
# Trigger a state_setter(obj, state) function call.
write(pickle.REDUCE)
# The purpose of state_setter is to carry-out an
# inplace modification of obj. We do not care about what the
# method might return, so its output is eventually removed from
# the stack.
write(pickle.POP)
def save_global(self, obj, name=None, pack=struct.pack):
"""
Save a "global".
The name of this method is somewhat misleading: all types get
dispatched here.
"""
if obj is type(None): # noqa
return self.save_reduce(type, (None,), obj=obj)
elif obj is type(Ellipsis):
return self.save_reduce(type, (Ellipsis,), obj=obj)
elif obj is type(NotImplemented):
return self.save_reduce(type, (NotImplemented,), obj=obj)
elif obj in _BUILTIN_TYPE_NAMES:
return self.save_reduce(
_builtin_type, (_BUILTIN_TYPE_NAMES[obj],), obj=obj)
if sys.version_info[:2] < (3, 7) and _is_parametrized_type_hint(obj): # noqa # pragma: no branch
# Parametrized typing constructs in Python < 3.7 are not
# compatible with type checks and ``isinstance`` semantics. For
# this reason, it is easier to detect them using a
# duck-typing-based check (``_is_parametrized_type_hint``) than
# to populate the Pickler's dispatch with type-specific savers.
self.save_reduce(
_create_parametrized_type_hint,
parametrized_type_hint_getinitargs(obj),
obj=obj
)
elif name is not None:
Pickler.save_global(self, obj, name=name)
elif not _should_pickle_by_reference(obj, name=name):
self._save_reduce_pickle5(*_dynamic_class_reduce(obj), obj=obj)
else:
Pickler.save_global(self, obj, name=name)
dispatch[type] = save_global
def save_function(self, obj, name=None):
""" Registered with the dispatch to handle all function types.
Determines what kind of function obj is (e.g. lambda, defined at
interactive prompt, etc) and handles the pickling appropriately.
"""
if _should_pickle_by_reference(obj, name=name):
return Pickler.save_global(self, obj, name=name)
elif PYPY and isinstance(obj.__code__, builtin_code_type):
return self.save_pypy_builtin_func(obj)
else:
return self._save_reduce_pickle5(
*self._dynamic_function_reduce(obj), obj=obj
)
def save_pypy_builtin_func(self, obj):
"""Save pypy equivalent of builtin functions.
PyPy does not have the concept of builtin-functions. Instead,
builtin-functions are simple function instances, but with a
builtin-code attribute.
Most of the time, builtin functions should be pickled by attribute.
But PyPy has flaky support for __qualname__, so some builtin
functions such as float.__new__ will be classified as dynamic. For
this reason only, we created this special routine. Because
            builtin-functions are not expected to have closures or globals,
            there is no additional hack (compared to the one already
            implemented in pickle) to protect ourselves from reference cycles.
            A simple (reconstructor, newargs, obj.__dict__) tuple is passed to
            save_reduce. Note also that PyPy improved their support for
            __qualname__ in v3.6, so this routine should be removed when
            cloudpickle supports only PyPy 3.6 and later.
"""
rv = (types.FunctionType, (obj.__code__, {}, obj.__name__,
obj.__defaults__, obj.__closure__),
obj.__dict__)
self.save_reduce(*rv, obj=obj)
dispatch[types.FunctionType] = save_function
import sys
if sys.version_info < (3, 8):
try:
import pickle5 as pickle # noqa: F401
from pickle5 import Pickler # noqa: F401
except ImportError:
import pickle # noqa: F401
# Use the Python pickler for old CPython versions
from pickle import _Pickler as Pickler # noqa: F401
else:
import pickle # noqa: F401
    # Pickler will be the C implementation in CPython and the Python
    # implementation in PyPy
from pickle import Pickler # noqa: F401
"""
Utils for IR analysis
"""
import operator
from functools import reduce
from collections import namedtuple, defaultdict
from .controlflow import CFGraph
from numba.core import types, errors, ir, consts
from numba.misc import special
#
# Analysis related to variable lifetime
#
_use_defs_result = namedtuple('use_defs_result', 'usemap,defmap')
# other packages that define new nodes add calls for finding defs
# format: {type:function}
ir_extension_usedefs = {}
def compute_use_defs(blocks):
"""
Find variable use/def per block.
"""
var_use_map = {} # { block offset -> set of vars }
var_def_map = {} # { block offset -> set of vars }
for offset, ir_block in blocks.items():
var_use_map[offset] = use_set = set()
var_def_map[offset] = def_set = set()
for stmt in ir_block.body:
if type(stmt) in ir_extension_usedefs:
func = ir_extension_usedefs[type(stmt)]
func(stmt, use_set, def_set)
continue
if isinstance(stmt, ir.Assign):
if isinstance(stmt.value, ir.Inst):
rhs_set = set(var.name for var in stmt.value.list_vars())
elif isinstance(stmt.value, ir.Var):
rhs_set = set([stmt.value.name])
elif isinstance(stmt.value, (ir.Arg, ir.Const, ir.Global,
ir.FreeVar)):
rhs_set = ()
else:
raise AssertionError('unreachable', type(stmt.value))
# If lhs not in rhs of the assignment
if stmt.target.name not in rhs_set:
def_set.add(stmt.target.name)
for var in stmt.list_vars():
# do not include locally defined vars to use-map
if var.name not in def_set:
use_set.add(var.name)
return _use_defs_result(usemap=var_use_map, defmap=var_def_map)
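# Editor's note (illustrative, not in the original source): for a block whose
# body is equivalent to
#     b = a + 1
#     c = b * b
# compute_use_defs yields usemap[offset] == {'a'} and
# defmap[offset] == {'b', 'c'}; 'b' is defined locally before its first use,
# so it never enters the use set.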
def compute_live_map(cfg, blocks, var_use_map, var_def_map):
"""
Find variables that must be alive at the ENTRY of each block.
We use a simple fix-point algorithm that iterates until the set of
live variables is unchanged for each block.
"""
def fix_point_progress(dct):
"""Helper function to determine if a fix-point has been reached.
"""
return tuple(len(v) for v in dct.values())
def fix_point(fn, dct):
"""Helper function to run fix-point algorithm.
"""
old_point = None
new_point = fix_point_progress(dct)
while old_point != new_point:
fn(dct)
old_point = new_point
new_point = fix_point_progress(dct)
def def_reach(dct):
"""Find all variable definition reachable at the entry of a block
"""
for offset in var_def_map:
used_or_defined = var_def_map[offset] | var_use_map[offset]
dct[offset] |= used_or_defined
# Propagate to outgoing nodes
for out_blk, _ in cfg.successors(offset):
dct[out_blk] |= dct[offset]
def liveness(dct):
"""Find live variables.
Push var usage backward.
"""
for offset in dct:
# Live vars here
live_vars = dct[offset]
for inc_blk, _data in cfg.predecessors(offset):
# Reachable at the predecessor
reachable = live_vars & def_reach_map[inc_blk]
# But not defined in the predecessor
dct[inc_blk] |= reachable - var_def_map[inc_blk]
live_map = {}
for offset in blocks.keys():
live_map[offset] = set(var_use_map[offset])
def_reach_map = defaultdict(set)
fix_point(def_reach, def_reach_map)
fix_point(liveness, live_map)
return live_map
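# Editor's note (illustrative, not in the original source): for a two-block
# CFG  A -> B  where A defines 'x' and B uses it, the fix point converges to
# live_map == {A: set(), B: {'x'}} -- 'x' must be alive at the ENTRY of B,
# while nothing is required at the entry of A.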
_dead_maps_result = namedtuple('dead_maps_result', 'internal,escaping,combined')
def compute_dead_maps(cfg, blocks, live_map, var_def_map):
"""
    Compute end-of-liveness information for variables.
`live_map` contains a mapping of block offset to all the living
variables at the ENTRY of the block.
"""
# The following three dictionaries will be
# { block offset -> set of variables to delete }
# all vars that should be deleted at the start of the successors
escaping_dead_map = defaultdict(set)
# all vars that should be deleted within this block
internal_dead_map = defaultdict(set)
# all vars that should be deleted after the function exit
exit_dead_map = defaultdict(set)
for offset, ir_block in blocks.items():
# live vars WITHIN the block will include all the locally
# defined variables
cur_live_set = live_map[offset] | var_def_map[offset]
# vars alive in the outgoing blocks
outgoing_live_map = dict((out_blk, live_map[out_blk])
for out_blk, _data in cfg.successors(offset))
# vars to keep alive for the terminator
terminator_liveset = set(v.name
for v in ir_block.terminator.list_vars())
# vars to keep alive in the successors
combined_liveset = reduce(operator.or_, outgoing_live_map.values(),
set())
# include variables used in terminator
combined_liveset |= terminator_liveset
# vars that are dead within the block because they are not
# propagated to any outgoing blocks
internal_set = cur_live_set - combined_liveset
internal_dead_map[offset] = internal_set
# vars that escape this block
escaping_live_set = cur_live_set - internal_set
for out_blk, new_live_set in outgoing_live_map.items():
# successor should delete the unused escaped vars
new_live_set = new_live_set | var_def_map[out_blk]
escaping_dead_map[out_blk] |= escaping_live_set - new_live_set
# if no outgoing blocks
if not outgoing_live_map:
# insert var used by terminator
exit_dead_map[offset] = terminator_liveset
# Verify that the dead maps cover all live variables
all_vars = reduce(operator.or_, live_map.values(), set())
internal_dead_vars = reduce(operator.or_, internal_dead_map.values(),
set())
escaping_dead_vars = reduce(operator.or_, escaping_dead_map.values(),
set())
exit_dead_vars = reduce(operator.or_, exit_dead_map.values(), set())
dead_vars = (internal_dead_vars | escaping_dead_vars | exit_dead_vars)
missing_vars = all_vars - dead_vars
if missing_vars:
# There are no exit points
if not cfg.exit_points():
# We won't be able to verify this
pass
else:
msg = 'liveness info missing for vars: {0}'.format(missing_vars)
raise RuntimeError(msg)
combined = dict((k, internal_dead_map[k] | escaping_dead_map[k])
for k in blocks)
return _dead_maps_result(internal=internal_dead_map,
escaping=escaping_dead_map,
combined=combined)
def compute_live_variables(cfg, blocks, var_def_map, var_dead_map):
"""
Compute the live variables at the beginning of each block
and at each yield point.
The ``var_def_map`` and ``var_dead_map`` indicates the variable defined
and deleted at each block, respectively.
"""
# live var at the entry per block
block_entry_vars = defaultdict(set)
def fix_point_progress():
return tuple(map(len, block_entry_vars.values()))
old_point = None
new_point = fix_point_progress()
    # Propagate variables that are defined and still live to the successors.
# (note the entry block automatically gets an empty set)
# Note: This is finding the actual available variables at the entry
# of each block. The algorithm in compute_live_map() is finding
# the variable that must be available at the entry of each block.
# This is top-down in the dataflow. The other one is bottom-up.
while old_point != new_point:
        # We iterate until the result stabilizes. This is necessary
        # because of loops in the graph itself.
for offset in blocks:
# vars available + variable defined
avail = block_entry_vars[offset] | var_def_map[offset]
# subtract variables deleted
avail -= var_dead_map[offset]
# add ``avail`` to each successors
for succ, _data in cfg.successors(offset):
block_entry_vars[succ] |= avail
old_point = new_point
new_point = fix_point_progress()
return block_entry_vars
#
# Analysis related to controlflow
#
def compute_cfg_from_blocks(blocks):
cfg = CFGraph()
for k in blocks:
cfg.add_node(k)
for k, b in blocks.items():
term = b.terminator
for target in term.get_targets():
cfg.add_edge(k, target)
cfg.set_entry_point(min(blocks))
cfg.process()
return cfg
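# Editor's sketch (not part of the original module): the same CFGraph
# construction pattern, driven by a hand-written edge list instead of IR
# blocks (node labels are arbitrary integers).
def _example_manual_cfg():
    cfg = CFGraph()
    for node in (0, 12, 24):
        cfg.add_node(node)
    cfg.add_edge(0, 12)     # entry branches to both blocks
    cfg.add_edge(0, 24)
    cfg.add_edge(12, 24)
    cfg.set_entry_point(0)
    cfg.process()           # computes dominators, loops, exit points, ...
    return cfg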
def find_top_level_loops(cfg):
"""
A generator that yields toplevel loops given a control-flow-graph
"""
blocks_in_loop = set()
# get loop bodies
for loop in cfg.loops().values():
insiders = set(loop.body) | set(loop.entries) | set(loop.exits)
insiders.discard(loop.header)
blocks_in_loop |= insiders
# find loop that is not part of other loops
for loop in cfg.loops().values():
if loop.header not in blocks_in_loop:
yield _fix_loop_exit(cfg, loop)
def _fix_loop_exit(cfg, loop):
"""
Fixes loop.exits for Py3.8 bytecode CFG changes.
This is to handle `break` inside loops.
"""
# Computes the common postdoms of exit nodes
postdoms = cfg.post_dominators()
exits = reduce(
operator.and_,
[postdoms[b] for b in loop.exits],
loop.exits,
)
if exits:
# Put the non-common-exits as body nodes
body = loop.body | loop.exits - exits
return loop._replace(exits=exits, body=body)
else:
return loop
# Used to describe a nullified condition in dead branch pruning
nullified = namedtuple('nullified', 'condition, taken_br, rewrite_stmt')
# Functions to manipulate IR
def dead_branch_prune(func_ir, called_args):
"""
Removes dead branches based on constant inference from function args.
This directly mutates the IR.
func_ir is the IR
called_args are the actual arguments with which the function is called
"""
from numba.core.ir_utils import (get_definition, guard, find_const,
GuardException)
DEBUG = 0
def find_branches(func_ir):
# find *all* branches
branches = []
for blk in func_ir.blocks.values():
branch_or_jump = blk.body[-1]
if isinstance(branch_or_jump, ir.Branch):
branch = branch_or_jump
pred = guard(get_definition, func_ir, branch.cond.name)
if pred is not None and getattr(pred, "op", None) == "call":
function = guard(get_definition, func_ir, pred.func)
if (function is not None and
isinstance(function, ir.Global) and
function.value is bool):
condition = guard(get_definition, func_ir, pred.args[0])
if condition is not None:
branches.append((branch, condition, blk))
return branches
def do_prune(take_truebr, blk):
keep = branch.truebr if take_truebr else branch.falsebr
# replace the branch with a direct jump
jmp = ir.Jump(keep, loc=branch.loc)
blk.body[-1] = jmp
return 1 if keep == branch.truebr else 0
def prune_by_type(branch, condition, blk, *conds):
# this prunes a given branch and fixes up the IR
# at least one needs to be a NoneType
lhs_cond, rhs_cond = conds
lhs_none = isinstance(lhs_cond, types.NoneType)
rhs_none = isinstance(rhs_cond, types.NoneType)
if lhs_none or rhs_none:
try:
take_truebr = condition.fn(lhs_cond, rhs_cond)
except Exception:
return False, None
if DEBUG > 0:
kill = branch.falsebr if take_truebr else branch.truebr
print("Pruning %s" % kill, branch, lhs_cond, rhs_cond,
condition.fn)
taken = do_prune(take_truebr, blk)
return True, taken
return False, None
def prune_by_value(branch, condition, blk, *conds):
lhs_cond, rhs_cond = conds
try:
take_truebr = condition.fn(lhs_cond, rhs_cond)
except Exception:
return False, None
if DEBUG > 0:
kill = branch.falsebr if take_truebr else branch.truebr
print("Pruning %s" % kill, branch, lhs_cond, rhs_cond, condition.fn)
taken = do_prune(take_truebr, blk)
return True, taken
def prune_by_predicate(branch, pred, blk):
try:
# Just to prevent accidents, whilst already guarded, ensure this
# is an ir.Const
if not isinstance(pred, (ir.Const, ir.FreeVar, ir.Global)):
raise TypeError('Expected constant Numba IR node')
take_truebr = bool(pred.value)
except TypeError:
return False, None
if DEBUG > 0:
kill = branch.falsebr if take_truebr else branch.truebr
print("Pruning %s" % kill, branch, pred)
taken = do_prune(take_truebr, blk)
return True, taken
class Unknown(object):
pass
def resolve_input_arg_const(input_arg_idx):
"""
Resolves an input arg to a constant (if possible)
"""
input_arg_ty = called_args[input_arg_idx]
# comparing to None?
if isinstance(input_arg_ty, types.NoneType):
return input_arg_ty
# is it a kwarg default
if isinstance(input_arg_ty, types.Omitted):
val = input_arg_ty.value
if isinstance(val, types.NoneType):
return val
elif val is None:
return types.NoneType('none')
# literal type, return the type itself so comparisons like `x == None`
# still work as e.g. x = types.int64 will never be None/NoneType so
# the branch can still be pruned
return getattr(input_arg_ty, 'literal_type', Unknown())
if DEBUG > 1:
print("before".center(80, '-'))
print(func_ir.dump())
phi2lbl = dict()
phi2asgn = dict()
for lbl, blk in func_ir.blocks.items():
for stmt in blk.body:
if isinstance(stmt, ir.Assign):
if isinstance(stmt.value, ir.Expr) and stmt.value.op == 'phi':
phi2lbl[stmt.value] = lbl
phi2asgn[stmt.value] = stmt
# This looks for branches where:
# at least one arg of the condition is in input args and const
    # at least one arg of the condition is a const
# if the condition is met it will replace the branch with a jump
branch_info = find_branches(func_ir)
# stores conditions that have no impact post prune
nullified_conditions = []
for branch, condition, blk in branch_info:
const_conds = []
if isinstance(condition, ir.Expr) and condition.op == 'binop':
prune = prune_by_value
for arg in [condition.lhs, condition.rhs]:
resolved_const = Unknown()
arg_def = guard(get_definition, func_ir, arg)
if isinstance(arg_def, ir.Arg):
# it's an e.g. literal argument to the function
resolved_const = resolve_input_arg_const(arg_def.index)
prune = prune_by_type
else:
# it's some const argument to the function, cannot use guard
# here as the const itself may be None
try:
resolved_const = find_const(func_ir, arg)
if resolved_const is None:
resolved_const = types.NoneType('none')
except GuardException:
pass
if not isinstance(resolved_const, Unknown):
const_conds.append(resolved_const)
# lhs/rhs are consts
if len(const_conds) == 2:
# prune the branch, switch the branch for an unconditional jump
prune_stat, taken = prune(branch, condition, blk, *const_conds)
if (prune_stat):
# add the condition to the list of nullified conditions
nullified_conditions.append(nullified(condition, taken,
True))
else:
# see if this is a branch on a constant value predicate
resolved_const = Unknown()
try:
pred_call = get_definition(func_ir, branch.cond)
resolved_const = find_const(func_ir, pred_call.args[0])
if resolved_const is None:
resolved_const = types.NoneType('none')
except GuardException:
pass
if not isinstance(resolved_const, Unknown):
prune_stat, taken = prune_by_predicate(branch, condition, blk)
if (prune_stat):
# add the condition to the list of nullified conditions
nullified_conditions.append(nullified(condition, taken,
False))
# 'ERE BE DRAGONS...
# It is the evaluation of the condition expression that often trips up type
# inference, so ideally it would be removed as it is effectively rendered
# dead by the unconditional jump if a branch was pruned. However, there may
# be references to the condition that exist in multiple places (e.g. dels)
# and we cannot run DCE here as typing has not taken place to give enough
# information to run DCE safely. Upshot of all this is the condition gets
# rewritten below into a benign const that typing will be happy with and DCE
# can remove it and its reference post typing when it is safe to do so
# (if desired). It is required that the const is assigned a value that
# indicates the branch taken as its mutated value would be read in the case
# of object mode fall back in place of the condition itself. For
# completeness the func_ir._definitions and ._consts are also updated to
# make the IR state self consistent.
deadcond = [x.condition for x in nullified_conditions]
for _, cond, blk in branch_info:
if cond in deadcond:
for x in blk.body:
if isinstance(x, ir.Assign) and x.value is cond:
# rewrite the condition as a true/false bit
nullified_info = nullified_conditions[deadcond.index(cond)]
# only do a rewrite of conditions, predicates need to retain
# their value as they may be used later.
if nullified_info.rewrite_stmt:
branch_bit = nullified_info.taken_br
x.value = ir.Const(branch_bit, loc=x.loc)
# update the specific definition to the new const
defns = func_ir._definitions[x.target.name]
repl_idx = defns.index(cond)
defns[repl_idx] = x.value
    # Check post dominators of dead nodes in the original CFG for uses of
    # vars that are being removed in the dead blocks and which might be
    # referred to by phi nodes.
#
# Multiple things to fix up:
#
# 1. Cases like:
#
# A A
# |\ |
# | B --> B
# |/ |
# C C
#
# i.e. the branch is dead but the block is still alive. In this case CFG
    # simplification will fuse A-B-C and any phi in C can be updated as a
# direct assignment from the last assigned version in the dominators of the
# fused block.
#
# 2. Cases like:
#
# A A
# / \ |
# B C --> B
# \ / |
# D D
#
# i.e. the block C is dead. In this case the phis in D need updating to
# reflect the collapse of the phi condition. This should result in a direct
# assignment of the surviving version in B to the LHS of the phi in D.
new_cfg = compute_cfg_from_blocks(func_ir.blocks)
dead_blocks = new_cfg.dead_nodes()
# for all phis that are still in live blocks.
for phi, lbl in phi2lbl.items():
if lbl in dead_blocks:
continue
new_incoming = [x[0] for x in new_cfg.predecessors(lbl)]
if set(new_incoming) != set(phi.incoming_blocks):
# Something has changed in the CFG...
if len(new_incoming) == 1:
# There's now just one incoming. Replace the PHI node by a
# direct assignment
idx = phi.incoming_blocks.index(new_incoming[0])
phi2asgn[phi].value = phi.incoming_values[idx]
else:
                # There's still more than one incoming block, so look through
                # the incoming blocks and drop the dead ones
ic_val_tmp = []
ic_blk_tmp = []
for ic_val, ic_blk in zip(phi.incoming_values,
phi.incoming_blocks):
if ic_blk in dead_blocks:
continue
else:
ic_val_tmp.append(ic_val)
ic_blk_tmp.append(ic_blk)
phi.incoming_values.clear()
phi.incoming_values.extend(ic_val_tmp)
phi.incoming_blocks.clear()
phi.incoming_blocks.extend(ic_blk_tmp)
# Remove dead blocks, this is safe as it relies on the CFG only.
for dead in dead_blocks:
del func_ir.blocks[dead]
# if conditions were nullified then consts were rewritten, update
if nullified_conditions:
func_ir._consts = consts.ConstantInference(func_ir)
if DEBUG > 1:
print("after".center(80, '-'))
print(func_ir.dump())
def rewrite_semantic_constants(func_ir, called_args):
"""
This rewrites values known to be constant by their semantics as ir.Const
nodes, this is to give branch pruning the best chance possible of killing
branches. An example might be rewriting len(tuple) as the literal length.
func_ir is the IR
called_args are the actual arguments with which the function is called
"""
DEBUG = 0
if DEBUG > 1:
print(("rewrite_semantic_constants: " +
func_ir.func_id.func_name).center(80, '-'))
print("before".center(80, '*'))
func_ir.dump()
def rewrite_statement(func_ir, stmt, new_val):
"""
Rewrites the stmt as a ir.Const new_val and fixes up the entries in
func_ir._definitions
"""
stmt.value = ir.Const(new_val, stmt.loc)
defns = func_ir._definitions[stmt.target.name]
repl_idx = defns.index(val)
defns[repl_idx] = stmt.value
def rewrite_array_ndim(val, func_ir, called_args):
# rewrite Array.ndim as const(ndim)
if getattr(val, 'op', None) == 'getattr':
if val.attr == 'ndim':
arg_def = guard(get_definition, func_ir, val.value)
if isinstance(arg_def, ir.Arg):
argty = called_args[arg_def.index]
if isinstance(argty, types.Array):
rewrite_statement(func_ir, stmt, argty.ndim)
def rewrite_tuple_len(val, func_ir, called_args):
# rewrite len(tuple) as const(len(tuple))
if getattr(val, 'op', None) == 'call':
func = guard(get_definition, func_ir, val.func)
if (func is not None and isinstance(func, ir.Global) and
getattr(func, 'value', None) is len):
(arg,) = val.args
arg_def = guard(get_definition, func_ir, arg)
if isinstance(arg_def, ir.Arg):
argty = called_args[arg_def.index]
if isinstance(argty, types.BaseTuple):
rewrite_statement(func_ir, stmt, argty.count)
elif (isinstance(arg_def, ir.Expr) and
arg_def.op == 'typed_getitem'):
argty = arg_def.dtype
if isinstance(argty, types.BaseTuple):
rewrite_statement(func_ir, stmt, argty.count)
from numba.core.ir_utils import get_definition, guard
for blk in func_ir.blocks.values():
for stmt in blk.body:
if isinstance(stmt, ir.Assign):
val = stmt.value
if isinstance(val, ir.Expr):
rewrite_array_ndim(val, func_ir, called_args)
rewrite_tuple_len(val, func_ir, called_args)
if DEBUG > 1:
print("after".center(80, '*'))
func_ir.dump()
print('-' * 80)
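# Editor's illustration (hedged, not in the original source): for a function
# whose argument is typed as a 3-tuple, the Expr for `len(t)` is rewritten in
# place as ir.Const(3), which gives dead_branch_prune above a chance to kill
# branches such as `if len(t) == 3: ...` before type inference runs.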
def find_literally_calls(func_ir, argtypes):
"""An analysis to find `numba.literally` call inside the given IR.
When an unsatisfied literal typing request is found, a `ForceLiteralArg`
exception is raised.
Parameters
----------
func_ir : numba.ir.FunctionIR
argtypes : Sequence[numba.types.Type]
The argument types.
"""
from numba.core import ir_utils
marked_args = set()
first_loc = {}
# Scan for literally calls
for blk in func_ir.blocks.values():
for assign in blk.find_exprs(op='call'):
var = ir_utils.guard(ir_utils.get_definition, func_ir, assign.func)
if isinstance(var, (ir.Global, ir.FreeVar)):
fnobj = var.value
else:
fnobj = ir_utils.guard(ir_utils.resolve_func_from_module,
func_ir, var)
if fnobj is special.literally:
# Found
[arg] = assign.args
defarg = func_ir.get_definition(arg)
if isinstance(defarg, ir.Arg):
argindex = defarg.index
marked_args.add(argindex)
first_loc.setdefault(argindex, assign.loc)
# Signal the dispatcher to force literal typing
for pos in marked_args:
query_arg = argtypes[pos]
do_raise = (isinstance(query_arg, types.InitialValue) and
query_arg.initial_value is None)
if do_raise:
loc = first_loc[pos]
raise errors.ForceLiteralArg(marked_args, loc=loc)
if not isinstance(query_arg, (types.Literal, types.InitialValue)):
loc = first_loc[pos]
raise errors.ForceLiteralArg(marked_args, loc=loc)
ir_extension_use_alloca = {}
def must_use_alloca(blocks):
"""
Analyzes a dictionary of blocks to find variables that must be
stack allocated with alloca. For each statement in the blocks,
determine if that statement requires certain variables to be
stack allocated. This function uses the extension point
ir_extension_use_alloca to allow other IR node types like parfors
to register to be processed by this analysis function. At the
moment, parfors are the only IR node types that may require
something to be stack allocated.
"""
use_alloca_vars = set()
for ir_block in blocks.values():
for stmt in ir_block.body:
if type(stmt) in ir_extension_use_alloca:
func = ir_extension_use_alloca[type(stmt)]
func(stmt, use_alloca_vars)
continue
return use_alloca_vars
"""
This module implements code highlighting of numba function annotations.
"""
from warnings import warn
warn("The pretty_annotate functionality is experimental and might change API",
FutureWarning)
def hllines(code, style):
    """Given a code string, return a list of html-highlighted lines."""
    try:
        from pygments import highlight
        from pygments.lexers import PythonLexer
        from pygments.formatters import HtmlFormatter
    except ImportError:
        raise ImportError("please install the 'pygments' package")
    pylex = PythonLexer()
    hf = HtmlFormatter(noclasses=True, style=style, nowrap=True)
    res = highlight(code, pylex, hf)
    return res.splitlines()
def htlines(code, style):
    """Given a code string, return a list of ANSI-highlighted lines."""
    try:
        from pygments import highlight
        from pygments.lexers import PythonLexer
        # TerminalFormatter does not support themes; Terminal256 should,
        # but does not seem to work.
        from pygments.formatters import TerminalFormatter
    except ImportError:
        raise ImportError("please install the 'pygments' package")
    pylex = PythonLexer()
    hf = TerminalFormatter(style=style)
    res = highlight(code, pylex, hf)
    return res.splitlines()
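# Editor's sketch (not part of the original module; assumes 'pygments' is
# installed): both helpers take source text plus a pygments style name.
def _example_highlight():
    src = "x = 1\ny = x + 1"
    for line in hllines(src, "default"):
        print(line)         # HTML markup, one entry per source line
    for line in htlines(src, "default"):
        print(line)         # the same lines with ANSI escape codes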
def get_ansi_template():
try:
from jinja2 import Template
except ImportError:
raise ImportError("please install the 'jinja2' package")
return Template("""
{%- for func_key in func_data.keys() -%}
Function name: \x1b[34m{{func_data[func_key]['funcname']}}\x1b[39;49;00m
{%- if func_data[func_key]['filename'] -%}
{{'\n'}}In file: \x1b[34m{{func_data[func_key]['filename'] -}}\x1b[39;49;00m
{%- endif -%}
{{'\n'}}With signature: \x1b[34m{{func_key[1]}}\x1b[39;49;00m
{{- "\n" -}}
{%- for num, line, hl, hc in func_data[func_key]['pygments_lines'] -%}
{{-'\n'}}{{ num}}: {{hc-}}
{%- if func_data[func_key]['ir_lines'][num] -%}
{%- for ir_line, ir_line_type in func_data[func_key]['ir_lines'][num] %}
{{-'\n'}}--{{- ' '*func_data[func_key]['python_indent'][num]}}
{{- ' '*(func_data[func_key]['ir_indent'][num][loop.index0]+4)
}}{{ir_line }}\x1b[41m{{ir_line_type-}}\x1b[39;49;00m
{%- endfor -%}
{%- endif -%}
{%- endfor -%}
{%- endfor -%}
""")
def get_html_template():
try:
from jinja2 import Template
except ImportError:
raise ImportError("please install the 'jinja2' package")
return Template("""
<html>
<head>
<style>
.annotation_table {
color: #000000;
font-family: monospace;
margin: 5px;
width: 100%;
}
/* override JupyterLab style */
.annotation_table td {
text-align: left;
background-color: transparent;
padding: 1px;
}
.annotation_table tbody tr:nth-child(even) {
background: white;
}
.annotation_table code
{
background-color: transparent;
white-space: normal;
}
/* End override JupyterLab style */
tr:hover {
background-color: rgba(92, 200, 249, 0.25);
}
td.object_tag summary ,
td.lifted_tag summary{
font-weight: bold;
display: list-item;
}
span.lifted_tag {
color: #00cc33;
}
span.object_tag {
color: #cc3300;
}
td.lifted_tag {
background-color: #cdf7d8;
}
td.object_tag {
background-color: #fef5c8;
}
code.ir_code {
color: grey;
font-style: italic;
}
.metadata {
border-bottom: medium solid black;
display: inline-block;
padding: 5px;
width: 100%;
}
.annotations {
padding: 5px;
}
.hidden {
display: none;
}
.buttons {
padding: 10px;
cursor: pointer;
}
</style>
</head>
<body>
{% for func_key in func_data.keys() %}
<div class="metadata">
Function name: {{func_data[func_key]['funcname']}}<br />
{% if func_data[func_key]['filename'] %}
in file: {{func_data[func_key]['filename']|escape}}<br />
{% endif %}
with signature: {{func_key[1]|e}}
</div>
<div class="annotations">
<table class="annotation_table tex2jax_ignore">
{%- for num, line, hl, hc in func_data[func_key]['pygments_lines'] -%}
{%- if func_data[func_key]['ir_lines'][num] %}
<tr><td style="text-align:left;" class="{{func_data[func_key]['python_tags'][num]}}">
<details>
<summary>
<code>
{{num}}:
{{'&nbsp;'*func_data[func_key]['python_indent'][num]}}{{hl}}
</code>
</summary>
<table class="annotation_table">
<tbody>
{%- for ir_line, ir_line_type in func_data[func_key]['ir_lines'][num] %}
<tr class="ir_code">
<td style="text-align: left;"><code>
&nbsp;
{{- '&nbsp;'*func_data[func_key]['python_indent'][num]}}
{{ '&nbsp;'*func_data[func_key]['ir_indent'][num][loop.index0]}}{{ir_line|e -}}
<span class="object_tag">{{ir_line_type}}</span>
</code>
</td>
</tr>
{%- endfor -%}
</tbody>
</table>
</details>
</td></tr>
{% else -%}
<tr><td style="text-align:left; padding-left: 22px;" class="{{func_data[func_key]['python_tags'][num]}}">
<code>
{{num}}:
{{'&nbsp;'*func_data[func_key]['python_indent'][num]}}{{hl}}
</code>
</td></tr>
{%- endif -%}
{%- endfor -%}
</table>
</div>
{% endfor %}
</body>
</html>
""")
def reform_code(annotation):
    """
    Extract the code from the Numba annotation datastructure.
    Pygments can only highlight full multi-line strings; the Numba
    annotation is a list of single lines, with indentation removed.
    """
    ident_dict = annotation['python_indent']
    s = ''
    for n, l in annotation['python_lines']:
        s = s + ' ' * ident_dict[n] + l + '\n'
    return s
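# Editor's sketch (not part of the original module): reform_code rebuilds
# indented source text from the flat annotation structure.
def _example_reform_code():
    ann = {
        'python_indent': {1: 0, 2: 4},
        'python_lines': [(1, 'def f():'), (2, 'return 1')],
    }
    assert reform_code(ann) == 'def f():\n    return 1\n'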
class Annotate:
"""
Construct syntax highlighted annotation for a given jitted function:
Example:
>>> import numba
>>> from numba.pretty_annotate import Annotate
>>> @numba.jit
... def test(q):
... res = 0
... for i in range(q):
... res += i
... return res
...
>>> test(10)
45
>>> Annotate(test)
The last line will return an HTML and/or ANSI representation that will be
displayed accordingly in Jupyter/IPython.
Function annotations persist across compilation for newly encountered
type signatures and as a result annotations are shown for all signatures
by default.
Annotations for a specific signature can be shown by using the
``signature`` parameter.
>>> @numba.jit
... def add(x, y):
... return x + y
...
>>> add(1, 2)
3
>>> add(1.3, 5.7)
7.0
>>> add.signatures
[(int64, int64), (float64, float64)]
>>> Annotate(add, signature=add.signatures[1]) # annotation for (float64, float64)
"""
def __init__(self, function, signature=None, **kwargs):
style = kwargs.get('style', 'default')
if not function.signatures:
            raise ValueError('function needs to be jitted for at least one signature')
ann = function.get_annotation_info(signature=signature)
self.ann = ann
        for k, v in ann.items():
            res = hllines(reform_code(v), style)
            rest = htlines(reform_code(v), style)
            v['pygments_lines'] = [(a, b, c, d) for (a, b), c, d
                                   in zip(v['python_lines'], res, rest)]
def _repr_html_(self):
return get_html_template().render(func_data=self.ann)
def __repr__(self):
return get_ansi_template().render(func_data=self.ann)
<html>
<head>
<style>
.annotation_table {
color: #000000;
font-family: monospace;
margin: 5px;
width: 100%;
}
/* override JupyterLab style */
.annotation_table td {
text-align: left;
background-color: transparent;
padding: 1px;
}
.annotation_table code
{
background-color: transparent;
white-space: normal;
}
/* End override JupyterLab style */
tr:hover {
background-color: rgba(92, 200, 249, 0.25);
}
td.object_tag summary ,
td.lifted_tag summary{
font-weight: bold;
display: list-item;
}
span.lifted_tag {
color: #00cc33;
}
span.object_tag {
color: #cc3300;
}
td.lifted_tag {
background-color: #cdf7d8;
}
td.object_tag {
background-color: #ffd3d3;
}
code.ir_code {
color: grey;
font-style: italic;
}
.metadata {
border-bottom: medium solid black;
display: inline-block;
padding: 5px;
width: 100%;
}
.annotations {
padding: 5px;
}
.hidden {
display: none;
}
.buttons {
padding: 10px;
cursor: pointer;
}
</style>
</head>
<body>
{% for func_key in func_data.keys() %}
{% set loop1 = loop %}
<div class="metadata">
Function name: {{func_data[func_key]['funcname']}}<br />
in file: {{func_data[func_key]['filename']}}<br />
with signature: {{func_key[1]|e}}
</div>
<div class="annotations">
<table class="annotation_table tex2jax_ignore">
{%- for num, line in func_data[func_key]['python_lines'] -%}
{%- if func_data[func_key]['ir_lines'][num] %}
<tr><td class="{{func_data[func_key]['python_tags'][num]}}">
<details>
<summary>
<code>
{{num}}:
{{func_data[func_key]['python_indent'][num]}}{{line|e}}
</code>
</summary>
<table class="annotation_table">
<tbody>
{%- for ir_line, ir_line_type in func_data[func_key]['ir_lines'][num] %}
<tr class="ir_code func{{loop1.index0}}_ir">
<td><code>&nbsp;
{{- func_data[func_key]['python_indent'][num]}}
{{func_data[func_key]['ir_indent'][num][loop.index0]}}{{ir_line|e -}}
<span class="object_tag">{{ir_line_type}}</span>
</code>
</td>
</tr>
{%- endfor -%}
</tbody>
</table>
</details>
</td></tr>
{% else -%}
<tr><td style=" padding-left: 22px;" class="{{func_data[func_key]['python_tags'][num]}}">
<code>
{{num}}:
{{func_data[func_key]['python_indent'][num]}}{{line|e}}
</code>
</td></tr>
{%- endif -%}
{%- endfor -%}
</table>
</div>
<br /><br /><br />
{% endfor %}
</body>
</html>
from collections import defaultdict, OrderedDict
from collections.abc import Mapping
from contextlib import closing
import copy
import inspect
import os
import re
import sys
import textwrap
from io import StringIO
import numba.core.dispatcher
from numba.core import ir
class SourceLines(Mapping):
def __init__(self, func):
try:
lines, startno = inspect.getsourcelines(func)
except OSError:
self.lines = ()
self.startno = 0
else:
self.lines = textwrap.dedent(''.join(lines)).splitlines()
self.startno = startno
def __getitem__(self, lineno):
try:
return self.lines[lineno - self.startno].rstrip()
except IndexError:
return ''
def __iter__(self):
return iter((self.startno + i) for i in range(len(self.lines)))
def __len__(self):
return len(self.lines)
@property
def avail(self):
return bool(self.lines)
class TypeAnnotation(object):
# func_data dict stores annotation data for all functions that are
# compiled. We store the data in the TypeAnnotation class since a new
# TypeAnnotation instance is created for each function that is compiled.
# For every function that is compiled, we add the type annotation data to
# this dict and write the html annotation file to disk (rewrite the html
# file for every function since we don't know if this is the last function
# to be compiled).
func_data = OrderedDict()
def __init__(self, func_ir, typemap, calltypes, lifted, lifted_from,
args, return_type, html_output=None):
self.func_id = func_ir.func_id
self.blocks = func_ir.blocks
self.typemap = typemap
self.calltypes = calltypes
self.filename = func_ir.loc.filename
self.linenum = str(func_ir.loc.line)
self.signature = str(args) + ' -> ' + str(return_type)
# lifted loop information
self.lifted = lifted
self.num_lifted_loops = len(lifted)
        # If this is a lifted loop function that is being compiled, lifted_from
        # points to the annotation data of the function that this lifted loop
        # function was lifted from. This is used to stick lifted loop
        # annotations back into the original function.
self.lifted_from = lifted_from
def prepare_annotations(self):
# Prepare annotations
groupedinst = defaultdict(list)
found_lifted_loop = False
#for blkid, blk in self.blocks.items():
for blkid in sorted(self.blocks.keys()):
blk = self.blocks[blkid]
groupedinst[blk.loc.line].append("label %s" % blkid)
for inst in blk.body:
lineno = inst.loc.line
if isinstance(inst, ir.Assign):
if found_lifted_loop:
atype = 'XXX Lifted Loop XXX'
found_lifted_loop = False
elif (isinstance(inst.value, ir.Expr) and
inst.value.op == 'call'):
atype = self.calltypes[inst.value]
elif (isinstance(inst.value, ir.Const) and
isinstance(inst.value.value, numba.core.dispatcher.LiftedLoop)):
atype = 'XXX Lifted Loop XXX'
found_lifted_loop = True
else:
# TODO: fix parfor lowering so that typemap is valid.
atype = self.typemap.get(inst.target.name, "<missing>")
aline = "%s = %s :: %s" % (inst.target, inst.value, atype)
elif isinstance(inst, ir.SetItem):
atype = self.calltypes[inst]
aline = "%s :: %s" % (inst, atype)
else:
aline = "%s" % inst
groupedinst[lineno].append(" %s" % aline)
return groupedinst
def annotate(self):
source = SourceLines(self.func_id.func)
# if not source.avail:
# return "Source code unavailable"
groupedinst = self.prepare_annotations()
# Format annotations
io = StringIO()
with closing(io):
if source.avail:
print("# File: %s" % self.filename, file=io)
for num in source:
srcline = source[num]
ind = _getindent(srcline)
print("%s# --- LINE %d --- " % (ind, num), file=io)
for inst in groupedinst[num]:
print('%s# %s' % (ind, inst), file=io)
print(file=io)
print(srcline, file=io)
print(file=io)
if self.lifted:
print("# The function contains lifted loops", file=io)
for loop in self.lifted:
print("# Loop at line %d" % loop.get_source_location(),
file=io)
print("# Has %d overloads" % len(loop.overloads),
file=io)
for cres in loop.overloads.values():
print(cres.type_annotation, file=io)
else:
print("# Source code unavailable", file=io)
for num in groupedinst:
for inst in groupedinst[num]:
print('%s' % (inst,), file=io)
print(file=io)
return io.getvalue()
def html_annotate(self, outfile):
# ensure that annotation information is assembled
self.annotate_raw()
# make a deep copy ahead of the pending mutations
func_data = copy.deepcopy(self.func_data)
key = 'python_indent'
for this_func in func_data.values():
if key in this_func:
idents = {}
for line, amount in this_func[key].items():
idents[line] = '&nbsp;' * amount
this_func[key] = idents
key = 'ir_indent'
for this_func in func_data.values():
if key in this_func:
idents = {}
for line, ir_id in this_func[key].items():
idents[line] = ['&nbsp;' * amount for amount in ir_id]
this_func[key] = idents
try:
from jinja2 import Template
except ImportError:
raise ImportError("please install the 'jinja2' package")
root = os.path.join(os.path.dirname(__file__))
template_filename = os.path.join(root, 'template.html')
with open(template_filename, 'r') as template:
html = template.read()
template = Template(html)
rendered = template.render(func_data=func_data)
outfile.write(rendered)
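    # A minimal usage sketch, assuming `ta` is a populated TypeAnnotation
    # instance and jinja2 is installed (hypothetical output path):
    #     with open("annotations.html", "w") as out:
    #         ta.html_annotate(out)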
def annotate_raw(self):
"""
This returns "raw" annotation information i.e. it has no output format
specific markup included.
"""
python_source = SourceLines(self.func_id.func)
ir_lines = self.prepare_annotations()
line_nums = [num for num in python_source]
lifted_lines = [l.get_source_location() for l in self.lifted]
def add_ir_line(func_data, line):
line_str = line.strip()
line_type = ''
if line_str.endswith('pyobject'):
line_str = line_str.replace('pyobject', '')
line_type = 'pyobject'
func_data['ir_lines'][num].append((line_str, line_type))
indent_len = len(_getindent(line))
func_data['ir_indent'][num].append(indent_len)
func_key = (self.func_id.filename + ':' + str(self.func_id.firstlineno + 1),
self.signature)
if self.lifted_from is not None and self.lifted_from[1]['num_lifted_loops'] > 0:
            # This is a lifted loop function that is being compiled. Get the
            # Numba IR for the lines in the loop function, to use when
            # annotating the original Python function that the loop was
            # lifted from.
func_data = self.lifted_from[1]
for num in line_nums:
if num not in ir_lines.keys():
continue
func_data['ir_lines'][num] = []
func_data['ir_indent'][num] = []
for line in ir_lines[num]:
add_ir_line(func_data, line)
if line.strip().endswith('pyobject'):
func_data['python_tags'][num] = 'object_tag'
                    # If any pyobject line is found, make sure the original
                    # Python line that was marked as a lifted loop start line
                    # is tagged as an object line instead. Lifted loop start
                    # lines should only be marked as lifted loop lines if the
                    # lifted loop was successfully compiled in nopython mode.
func_data['python_tags'][self.lifted_from[0]] = 'object_tag'
        # We're done with this lifted loop, so decrement the lifted loop
        # counter. When the counter hits zero, we're ready to write the
        # annotations out to the HTML file.
self.lifted_from[1]['num_lifted_loops'] -= 1
elif func_key not in TypeAnnotation.func_data.keys():
TypeAnnotation.func_data[func_key] = {}
func_data = TypeAnnotation.func_data[func_key]
for i, loop in enumerate(self.lifted):
# Make sure that when we process each lifted loop function later,
# we'll know where it originally came from.
loop.lifted_from = (lifted_lines[i], func_data)
func_data['num_lifted_loops'] = self.num_lifted_loops
func_data['filename'] = self.filename
func_data['funcname'] = self.func_id.func_name
func_data['python_lines'] = []
func_data['python_indent'] = {}
func_data['python_tags'] = {}
func_data['ir_lines'] = {}
func_data['ir_indent'] = {}
for num in line_nums:
func_data['python_lines'].append((num, python_source[num].strip()))
indent_len = len(_getindent(python_source[num]))
func_data['python_indent'][num] = indent_len
func_data['python_tags'][num] = ''
func_data['ir_lines'][num] = []
func_data['ir_indent'][num] = []
for line in ir_lines[num]:
add_ir_line(func_data, line)
if num in lifted_lines:
func_data['python_tags'][num] = 'lifted_tag'
elif line.strip().endswith('pyobject'):
func_data['python_tags'][num] = 'object_tag'
return self.func_data
def __str__(self):
return self.annotate()
re_longest_white_prefix = re.compile(r'^\s*')
def _getindent(text):
m = re_longest_white_prefix.match(text)
if not m:
return ''
else:
return ' ' * len(m.group(0))
from collections import defaultdict
import copy
import sys
from itertools import permutations, takewhile
from contextlib import contextmanager
from functools import cached_property
from llvmlite import ir as llvmir
from llvmlite.ir import Constant
import llvmlite.binding as ll
from numba.core import types, utils, datamodel, debuginfo, funcdesc, config, cgutils, imputils
from numba.core import event, errors, targetconfig
from numba import _dynfunc, _helperlib
from numba.core.compiler_lock import global_compiler_lock
from numba.core.pythonapi import PythonAPI
from numba.core.imputils import (user_function, user_generator,
builtin_registry, impl_ret_borrowed,
RegistryLoader)
from numba.cpython import builtins
GENERIC_POINTER = llvmir.PointerType(llvmir.IntType(8))
PYOBJECT = GENERIC_POINTER
void_ptr = GENERIC_POINTER
class OverloadSelector(object):
"""
An object matching an actual signature against a registry of formal
signatures and choosing the best candidate, if any.
In the current implementation:
- a "signature" is a tuple of type classes or type instances
- the "best candidate" is the most specific match
"""
def __init__(self):
# A list of (formal args tuple, value)
self.versions = []
self._cache = {}
def find(self, sig):
out = self._cache.get(sig)
if out is None:
out = self._find(sig)
self._cache[sig] = out
return out
def _find(self, sig):
candidates = self._select_compatible(sig)
if candidates:
return candidates[self._best_signature(candidates)]
else:
raise errors.NumbaNotImplementedError(f'{self}, {sig}')
def _select_compatible(self, sig):
"""
Select all compatible signatures and their implementation.
"""
out = {}
for ver_sig, impl in self.versions:
if self._match_arglist(ver_sig, sig):
out[ver_sig] = impl
return out
def _best_signature(self, candidates):
"""
Returns the best signature out of the candidates
"""
ordered, genericity = self._sort_signatures(candidates)
# check for ambiguous signatures
if len(ordered) > 1:
firstscore = genericity[ordered[0]]
same = list(takewhile(lambda x: genericity[x] == firstscore,
ordered))
if len(same) > 1:
msg = ["{n} ambiguous signatures".format(n=len(same))]
for sig in same:
msg += ["{0} => {1}".format(sig, candidates[sig])]
raise errors.NumbaTypeError('\n'.join(msg))
return ordered[0]
def _sort_signatures(self, candidates):
"""
Sort signatures in ascending level of genericity.
Returns a 2-tuple:
* ordered list of signatures
* dictionary containing genericity scores
"""
# score by genericity
genericity = defaultdict(int)
for this, other in permutations(candidates.keys(), r=2):
matched = self._match_arglist(formal_args=this, actual_args=other)
if matched:
                # genericity score +1 for every other compatible signature
genericity[this] += 1
# order candidates in ascending level of genericity
ordered = sorted(candidates.keys(), key=lambda x: genericity[x])
return ordered, genericity
def _match_arglist(self, formal_args, actual_args):
"""
Returns True if the signature is "matching".
A formal signature is "matching" if the actual signature matches exactly
or if the formal signature is a compatible generic signature.
"""
# normalize VarArg
if formal_args and isinstance(formal_args[-1], types.VarArg):
ndiff = len(actual_args) - len(formal_args) + 1
formal_args = formal_args[:-1] + (formal_args[-1].dtype,) * ndiff
if len(formal_args) != len(actual_args):
return False
for formal, actual in zip(formal_args, actual_args):
if not self._match(formal, actual):
return False
return True
def _match(self, formal, actual):
if formal == actual:
# formal argument matches actual arguments
return True
elif types.Any == formal:
# formal argument is any
return True
elif isinstance(formal, type) and issubclass(formal, types.Type):
if isinstance(actual, type) and issubclass(actual, formal):
# formal arg is a type class and actual arg is a subclass
return True
elif isinstance(actual, formal):
# formal arg is a type class of which actual arg is an instance
return True
def append(self, value, sig):
"""
Add a formal signature and its associated value.
"""
assert isinstance(sig, tuple), (value, sig)
self.versions.append((sig, value))
self._cache.clear()
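# A minimal sketch of how OverloadSelector resolves the most specific
# candidate; the string values are placeholders standing in for real
# lowering implementations, and the function is illustrative only.
def _example_overload_selector():
    sel = OverloadSelector()
    sel.append("generic impl", (types.Any, types.Any))           # matches anything
    sel.append("integer impl", (types.Integer, types.Integer))   # type class
    sel.append("int64 impl", (types.int64, types.int64))         # exact instance
    # The exact-instance signature is the least generic candidate and wins.
    assert sel.find((types.int64, types.int64)) == "int64 impl"
    # Other integer widths only match the class-based and Any-based
    # signatures; the class-based one is less generic and wins.
    assert sel.find((types.int32, types.int32)) == "integer impl"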
@utils.runonce
def _load_global_helpers():
"""
Execute once to install special symbols into the LLVM symbol table.
"""
# This is Py_None's real C name
ll.add_symbol("_Py_NoneStruct", id(None))
# Add Numba C helper functions
for c_helpers in (_helperlib.c_helpers, _dynfunc.c_helpers):
for py_name, c_address in c_helpers.items():
c_name = "numba_" + py_name
ll.add_symbol(c_name, c_address)
# Add all built-in exception classes
for obj in utils.builtins.__dict__.values():
if isinstance(obj, type) and issubclass(obj, BaseException):
ll.add_symbol("PyExc_%s" % (obj.__name__), id(obj))
class BaseContext(object):
"""
Notes on Structure
------------------
    Most objects are lowered as plain-old-data structures in the generated
    LLVM IR. They are passed around by reference (a pointer to the
    structure). Only POD structures can live across function boundaries,
    by copying the data.
"""
    # True if the target requires strict alignment.
    # Causes an exception to be raised if record members are not aligned.
strict_alignment = False
# Force powi implementation as math.pow call
implement_powi_as_math_call = False
implement_pow_as_math_call = False
# Emit Debug info
enable_debuginfo = False
DIBuilder = debuginfo.DIBuilder
# Bound checking
@property
def enable_boundscheck(self):
if config.BOUNDSCHECK is not None:
return config.BOUNDSCHECK
return self._boundscheck
@enable_boundscheck.setter
def enable_boundscheck(self, value):
self._boundscheck = value
# NRT
enable_nrt = False
# Auto parallelization
auto_parallel = False
# PYCC
aot_mode = False
# Error model for various operations (only FP exceptions currently)
error_model = None
    # Whether dynamic globals (CPU runtime addresses) are allowed
allow_dynamic_globals = False
# Fast math flags
fastmath = False
# python execution environment
environment = None
# the function descriptor
fndesc = None
def __init__(self, typing_context, target):
_load_global_helpers()
self.address_size = utils.MACHINE_BITS
self.typing_context = typing_context
from numba.core.target_extension import target_registry
self.target_name = target
self.target = target_registry[target]
# A mapping of installed registries to their loaders
self._registries = {}
# Declarations loaded from registries and other sources
self._defns = defaultdict(OverloadSelector)
self._getattrs = defaultdict(OverloadSelector)
self._setattrs = defaultdict(OverloadSelector)
self._casts = OverloadSelector()
self._get_constants = OverloadSelector()
# Other declarations
self._generators = {}
self.special_ops = {}
self.cached_internal_func = {}
self._pid = None
self._codelib_stack = []
self._boundscheck = False
self.data_model_manager = datamodel.default_manager
# Initialize
self.init()
def init(self):
"""
For subclasses to add initializer
"""
def refresh(self):
"""
Refresh context with new declarations from known registries.
Useful for third-party extensions.
"""
# load target specific registries
self.load_additional_registries()
        # Populate the builtin registry; this has to happen after loading
        # the additional registries, as some of the "additional" registries
        # write their implementations into the builtin_registry and would be
        # missed if this ran first.
self.install_registry(builtin_registry)
# Also refresh typing context, since @overload declarations can
# affect it.
self.typing_context.refresh()
def load_additional_registries(self):
"""
Load target-specific registries. Can be overridden by subclasses.
"""
def mangler(self, name, types, *, abi_tags=(), uid=None):
"""
Perform name mangling.
"""
return funcdesc.default_mangler(name, types, abi_tags=abi_tags, uid=uid)
def get_env_name(self, fndesc):
"""Get the environment name given a FunctionDescriptor.
Use this instead of the ``fndesc.env_name`` so that the target-context
can provide necessary mangling of the symbol to meet ABI requirements.
"""
return fndesc.env_name
def declare_env_global(self, module, envname):
"""Declare the Environment pointer as a global of the module.
The pointer is initialized to NULL. It must be filled by the runtime
with the actual address of the Env before the associated function
can be executed.
Parameters
----------
module :
The LLVM Module
envname : str
The name of the global variable.
"""
if envname not in module.globals:
gv = llvmir.GlobalVariable(module, cgutils.voidptr_t, name=envname)
gv.linkage = 'common'
gv.initializer = cgutils.get_null_value(gv.type.pointee)
return module.globals[envname]
def get_arg_packer(self, fe_args):
return datamodel.ArgPacker(self.data_model_manager, fe_args)
def get_data_packer(self, fe_types):
return datamodel.DataPacker(self.data_model_manager, fe_types)
@property
def target_data(self):
raise NotImplementedError
@cached_property
def nonconst_module_attrs(self):
"""
All module attrs are constant for targets using BaseContext.
"""
return tuple()
@cached_property
def nrt(self):
from numba.core.runtime.context import NRTContext
return NRTContext(self, self.enable_nrt)
def subtarget(self, **kws):
obj = copy.copy(self) # shallow copy
for k, v in kws.items():
if not hasattr(obj, k):
raise NameError("unknown option {0!r}".format(k))
setattr(obj, k, v)
if obj.codegen() is not self.codegen():
# We can't share functions across different codegens
obj.cached_internal_func = {}
return obj
def install_registry(self, registry):
"""
        Install a *registry* (an imputils.Registry instance) of function
and attribute implementations.
"""
try:
loader = self._registries[registry]
except KeyError:
loader = RegistryLoader(registry)
self._registries[registry] = loader
self.insert_func_defn(loader.new_registrations('functions'))
self._insert_getattr_defn(loader.new_registrations('getattrs'))
self._insert_setattr_defn(loader.new_registrations('setattrs'))
self._insert_cast_defn(loader.new_registrations('casts'))
self._insert_get_constant_defn(loader.new_registrations('constants'))
def insert_func_defn(self, defns):
for impl, func, sig in defns:
self._defns[func].append(impl, sig)
def _insert_getattr_defn(self, defns):
for impl, attr, sig in defns:
self._getattrs[attr].append(impl, sig)
def _insert_setattr_defn(self, defns):
for impl, attr, sig in defns:
self._setattrs[attr].append(impl, sig)
def _insert_cast_defn(self, defns):
for impl, sig in defns:
self._casts.append(impl, sig)
def _insert_get_constant_defn(self, defns):
for impl, sig in defns:
self._get_constants.append(impl, sig)
def insert_user_function(self, func, fndesc, libs=()):
impl = user_function(fndesc, libs)
self._defns[func].append(impl, impl.signature)
def insert_generator(self, genty, gendesc, libs=()):
assert isinstance(genty, types.Generator)
impl = user_generator(gendesc, libs)
self._generators[genty] = gendesc, impl
def remove_user_function(self, func):
"""
Remove user function *func*.
KeyError is raised if the function isn't known to us.
"""
del self._defns[func]
def get_external_function_type(self, fndesc):
argtypes = [self.get_argument_type(aty)
for aty in fndesc.argtypes]
# don't wrap in pointer
restype = self.get_argument_type(fndesc.restype)
fnty = llvmir.FunctionType(restype, argtypes)
return fnty
def declare_function(self, module, fndesc):
fnty = self.call_conv.get_function_type(fndesc.restype, fndesc.argtypes)
fn = cgutils.get_or_insert_function(module, fnty, fndesc.mangled_name)
self.call_conv.decorate_function(fn, fndesc.args, fndesc.argtypes, noalias=fndesc.noalias)
if fndesc.inline:
fn.attributes.add('alwaysinline')
# alwaysinline overrides optnone
fn.attributes.discard('noinline')
fn.attributes.discard('optnone')
return fn
def declare_external_function(self, module, fndesc):
fnty = self.get_external_function_type(fndesc)
fn = cgutils.get_or_insert_function(module, fnty, fndesc.mangled_name)
assert fn.is_declaration
for ak, av in zip(fndesc.args, fn.args):
av.name = "arg.%s" % ak
return fn
def insert_const_string(self, mod, string):
"""
Insert constant *string* (a str object) into module *mod*.
"""
stringtype = GENERIC_POINTER
name = ".const.%s" % string
text = cgutils.make_bytearray(string.encode("utf-8") + b"\x00")
gv = self.insert_unique_const(mod, name, text)
return Constant.bitcast(gv, stringtype)
def insert_const_bytes(self, mod, bytes, name=None):
"""
        Insert constant *bytes* (a `bytes` object) into module *mod*.
"""
stringtype = GENERIC_POINTER
name = ".bytes.%s" % (name or hash(bytes))
text = cgutils.make_bytearray(bytes)
gv = self.insert_unique_const(mod, name, text)
return Constant.bitcast(gv, stringtype)
def insert_unique_const(self, mod, name, val):
"""
Insert a unique internal constant named *name*, with LLVM value
*val*, into module *mod*.
"""
try:
gv = mod.get_global(name)
except KeyError:
return cgutils.global_constant(mod, name, val)
else:
return gv
def get_argument_type(self, ty):
return self.data_model_manager[ty].get_argument_type()
def get_return_type(self, ty):
return self.data_model_manager[ty].get_return_type()
def get_data_type(self, ty):
"""
Get a LLVM data representation of the Numba type *ty* that is safe
for storage. Record data are stored as byte array.
The return value is a llvmlite.ir.Type object, or None if the type
is an opaque pointer (???).
"""
return self.data_model_manager[ty].get_data_type()
def get_value_type(self, ty):
return self.data_model_manager[ty].get_value_type()
def pack_value(self, builder, ty, value, ptr, align=None):
"""
Pack value into the array storage at *ptr*.
If *align* is given, it is the guaranteed alignment for *ptr*
(by default, the standard ABI alignment).
"""
dataval = self.data_model_manager[ty].as_data(builder, value)
builder.store(dataval, ptr, align=align)
def unpack_value(self, builder, ty, ptr, align=None):
"""
Unpack value from the array storage at *ptr*.
If *align* is given, it is the guaranteed alignment for *ptr*
(by default, the standard ABI alignment).
"""
dm = self.data_model_manager[ty]
return dm.load_from_data_pointer(builder, ptr, align)
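    # A minimal usage sketch (hypothetical names): round-trip a value
    # through data storage using the two methods above, given a pointer
    # `ptr` of the matching data type.
    #     context.pack_value(builder, ty, val, ptr)
    #     val2 = context.unpack_value(builder, ty, ptr)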
def get_constant_generic(self, builder, ty, val):
"""
Return a LLVM constant representing value *val* of Numba type *ty*.
"""
try:
impl = self._get_constants.find((ty,))
return impl(self, builder, ty, val)
except NotImplementedError:
raise NotImplementedError("Cannot lower constant of type '%s'" % (ty,))
def get_constant(self, ty, val):
"""
Same as get_constant_generic(), but without specifying *builder*.
Works only for simple types.
"""
# HACK: pass builder=None to preserve get_constant() API
return self.get_constant_generic(None, ty, val)
def get_constant_undef(self, ty):
lty = self.get_value_type(ty)
return Constant(lty, llvmir.Undefined)
def get_constant_null(self, ty):
lty = self.get_value_type(ty)
return Constant(lty, None)
def get_function(self, fn, sig, _firstcall=True):
"""
Return the implementation of function *fn* for signature *sig*.
The return value is a callable with the signature (builder, args).
"""
assert sig is not None
sig = sig.as_function()
if isinstance(fn, types.Callable):
key = fn.get_impl_key(sig)
overloads = self._defns[key]
else:
key = fn
overloads = self._defns[key]
try:
return _wrap_impl(overloads.find(sig.args), self, sig)
except errors.NumbaNotImplementedError:
pass
if isinstance(fn, types.Type):
# It's a type instance => try to find a definition for the type class
try:
return self.get_function(type(fn), sig)
except NotImplementedError:
# Raise exception for the type instance, for a better error message
pass
# Automatically refresh the context to load new registries if we are
# calling the first time.
if _firstcall:
self.refresh()
return self.get_function(fn, sig, _firstcall=False)
raise NotImplementedError("No definition for lowering %s%s" % (key, sig))
def get_generator_desc(self, genty):
"""
"""
return self._generators[genty][0]
def get_generator_impl(self, genty):
"""
"""
res = self._generators[genty][1]
self.add_linking_libs(getattr(res, 'libs', ()))
return res
def get_bound_function(self, builder, obj, ty):
assert self.get_value_type(ty) == obj.type
return obj
def get_getattr(self, typ, attr):
"""
Get the getattr() implementation for the given type and attribute name.
The return value is a callable with the signature
(context, builder, typ, val, attr).
"""
const_attr = (typ, attr) not in self.nonconst_module_attrs
is_module = isinstance(typ, types.Module)
if is_module and const_attr:
# Implement getattr for module-level globals that we treat as
# constants.
# XXX We shouldn't have to retype this
attrty = self.typing_context.resolve_module_constants(typ, attr)
if attrty is None or isinstance(attrty, types.Dummy):
# No implementation required for dummies (functions, modules...),
# which are dealt with later
return None
else:
pyval = getattr(typ.pymod, attr)
def imp(context, builder, typ, val, attr):
llval = self.get_constant_generic(builder, attrty, pyval)
return impl_ret_borrowed(context, builder, attrty, llval)
return imp
# Lookup specific getattr implementation for this type and attribute
overloads = self._getattrs[attr]
try:
return overloads.find((typ,))
except errors.NumbaNotImplementedError:
pass
# Lookup generic getattr implementation for this type
overloads = self._getattrs[None]
try:
return overloads.find((typ,))
except errors.NumbaNotImplementedError:
pass
raise NotImplementedError("No definition for lowering %s.%s" % (typ, attr))
def get_setattr(self, attr, sig):
"""
Get the setattr() implementation for the given attribute name
and signature.
The return value is a callable with the signature (builder, args).
"""
assert len(sig.args) == 2
typ = sig.args[0]
valty = sig.args[1]
def wrap_setattr(impl):
def wrapped(builder, args):
return impl(self, builder, sig, args, attr)
return wrapped
# Lookup specific setattr implementation for this type and attribute
overloads = self._setattrs[attr]
try:
return wrap_setattr(overloads.find((typ, valty)))
except errors.NumbaNotImplementedError:
pass
# Lookup generic setattr implementation for this type
overloads = self._setattrs[None]
try:
return wrap_setattr(overloads.find((typ, valty)))
except errors.NumbaNotImplementedError:
pass
raise NotImplementedError("No definition for lowering %s.%s = %s"
% (typ, attr, valty))
def get_argument_value(self, builder, ty, val):
"""
Argument representation to local value representation
"""
return self.data_model_manager[ty].from_argument(builder, val)
def get_returned_value(self, builder, ty, val):
"""
Return value representation to local value representation
"""
return self.data_model_manager[ty].from_return(builder, val)
def get_return_value(self, builder, ty, val):
"""
Local value representation to return type representation
"""
return self.data_model_manager[ty].as_return(builder, val)
def get_value_as_argument(self, builder, ty, val):
"""Prepare local value representation as argument type representation
"""
return self.data_model_manager[ty].as_argument(builder, val)
def get_value_as_data(self, builder, ty, val):
return self.data_model_manager[ty].as_data(builder, val)
def get_data_as_value(self, builder, ty, val):
return self.data_model_manager[ty].from_data(builder, val)
def pair_first(self, builder, val, ty):
"""
Extract the first element of a heterogeneous pair.
"""
pair = self.make_helper(builder, ty, val)
return pair.first
def pair_second(self, builder, val, ty):
"""
Extract the second element of a heterogeneous pair.
"""
pair = self.make_helper(builder, ty, val)
return pair.second
def cast(self, builder, val, fromty, toty):
"""
Cast a value of type *fromty* to type *toty*.
This implements implicit conversions as can happen due to the
granularity of the Numba type system, or lax Python semantics.
"""
if fromty == toty or toty == types.Any:
return val
try:
impl = self._casts.find((fromty, toty))
return impl(self, builder, fromty, toty, val)
except errors.NumbaNotImplementedError:
raise errors.NumbaNotImplementedError(
"Cannot cast %s to %s: %s" % (fromty, toty, val))
def generic_compare(self, builder, key, argtypes, args):
"""
Compare the given LLVM values of the given Numba types using
the comparison *key* (e.g. '=='). The values are first cast to
a common safe conversion type.
"""
at, bt = argtypes
av, bv = args
ty = self.typing_context.unify_types(at, bt)
assert ty is not None
cav = self.cast(builder, av, at, ty)
cbv = self.cast(builder, bv, bt, ty)
fnty = self.typing_context.resolve_value_type(key)
        # the signature is homogeneous in the unified cast type
cmpsig = fnty.get_call_type(self.typing_context, (ty, ty), {})
cmpfunc = self.get_function(fnty, cmpsig)
self.add_linking_libs(getattr(cmpfunc, 'libs', ()))
return cmpfunc(builder, (cav, cbv))
def make_optional_none(self, builder, valtype):
optval = self.make_helper(builder, types.Optional(valtype))
optval.valid = cgutils.false_bit
return optval._getvalue()
def make_optional_value(self, builder, valtype, value):
optval = self.make_helper(builder, types.Optional(valtype))
optval.valid = cgutils.true_bit
optval.data = value
return optval._getvalue()
def is_true(self, builder, typ, val):
"""
Return the truth value of a value of the given Numba type.
"""
fnty = self.typing_context.resolve_value_type(bool)
sig = fnty.get_call_type(self.typing_context, (typ,), {})
impl = self.get_function(fnty, sig)
return impl(builder, (val,))
def get_c_value(self, builder, typ, name, dllimport=False):
"""
Get a global value through its C-accessible *name*, with the given
LLVM type.
If *dllimport* is true, the symbol will be marked as imported
from a DLL (necessary for AOT compilation under Windows).
"""
module = builder.function.module
try:
gv = module.globals[name]
except KeyError:
gv = cgutils.add_global_variable(module, typ, name)
if dllimport and self.aot_mode and sys.platform == 'win32':
gv.storage_class = "dllimport"
return gv
def call_external_function(self, builder, callee, argtys, args):
args = [self.get_value_as_argument(builder, ty, arg)
for ty, arg in zip(argtys, args)]
retval = builder.call(callee, args)
return retval
def get_function_pointer_type(self, typ):
return self.data_model_manager[typ].get_data_type()
def call_function_pointer(self, builder, funcptr, args, cconv=None):
return builder.call(funcptr, args, cconv=cconv)
def print_string(self, builder, text):
mod = builder.module
cstring = GENERIC_POINTER
fnty = llvmir.FunctionType(llvmir.IntType(32), [cstring])
puts = cgutils.get_or_insert_function(mod, fnty, "puts")
return builder.call(puts, [text])
def debug_print(self, builder, text):
mod = builder.module
cstr = self.insert_const_string(mod, str(text))
self.print_string(builder, cstr)
def printf(self, builder, format_string, *args):
mod = builder.module
if isinstance(format_string, str):
cstr = self.insert_const_string(mod, format_string)
else:
cstr = format_string
fnty = llvmir.FunctionType(llvmir.IntType(32), (GENERIC_POINTER,), var_arg=True)
fn = cgutils.get_or_insert_function(mod, fnty, "printf")
return builder.call(fn, (cstr,) + tuple(args))
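    # A minimal usage sketch (hypothetical names): emit a printf call from
    # generated code; `idx` must be an LLVM value matching the format spec.
    #     context.printf(builder, "iteration %d\n", idx)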
def get_struct_type(self, struct):
"""
Get the LLVM struct type for the given Structure class *struct*.
"""
fields = [self.get_value_type(v) for _, v in struct._fields]
return llvmir.LiteralStructType(fields)
def get_dummy_value(self):
return Constant(self.get_dummy_type(), None)
def get_dummy_type(self):
return GENERIC_POINTER
def _compile_subroutine_no_cache(self, builder, impl, sig, locals={},
flags=None):
"""
Invoke the compiler to compile a function to be used inside a
nopython function, but without generating code to call that
function.
Note this context's flags are not inherited.
"""
# Compile
from numba.core import compiler
with global_compiler_lock:
codegen = self.codegen()
library = codegen.create_library(impl.__name__)
if flags is None:
cstk = targetconfig.ConfigStack()
flags = compiler.Flags()
if cstk:
tls_flags = cstk.top()
if tls_flags.is_set("nrt") and tls_flags.nrt:
flags.nrt = True
flags.no_compile = True
flags.no_cpython_wrapper = True
flags.no_cfunc_wrapper = True
cres = compiler.compile_internal(self.typing_context, self,
library,
impl, sig.args,
sig.return_type, flags,
locals=locals)
# Allow inlining the function inside callers.
self.active_code_library.add_linking_library(cres.library)
return cres
def compile_subroutine(self, builder, impl, sig, locals={}, flags=None,
caching=True):
"""
Compile the function *impl* for the given *sig* (in nopython mode).
Return an instance of CompileResult.
If *caching* evaluates True, the function keeps the compiled function
for reuse in *.cached_internal_func*.
"""
cache_key = (impl.__code__, sig, type(self.error_model))
if not caching:
cached = None
else:
if impl.__closure__:
# XXX This obviously won't work if a cell's value is
# unhashable.
cache_key += tuple(c.cell_contents for c in impl.__closure__)
cached = self.cached_internal_func.get(cache_key)
if cached is None:
cres = self._compile_subroutine_no_cache(builder, impl, sig,
locals=locals,
flags=flags)
self.cached_internal_func[cache_key] = cres
cres = self.cached_internal_func[cache_key]
# Allow inlining the function inside callers.
self.active_code_library.add_linking_library(cres.library)
return cres
def compile_internal(self, builder, impl, sig, args, locals={}):
"""
Like compile_subroutine(), but also call the function with the given
*args*.
"""
cres = self.compile_subroutine(builder, impl, sig, locals)
return self.call_internal(builder, cres.fndesc, sig, args)
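    # A minimal usage sketch (hypothetical names): compile a pure-Python
    # helper in nopython mode and call it during lowering.
    #     from numba.core.typing import signature
    #     def add_one(x):
    #         return x + 1
    #     sig = signature(types.int64, types.int64)
    #     res = context.compile_internal(builder, add_one, sig, (arg,))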
def call_internal(self, builder, fndesc, sig, args):
"""
Given the function descriptor of an internally compiled function,
emit a call to that function with the given arguments.
"""
status, res = self.call_internal_no_propagate(builder, fndesc, sig, args)
with cgutils.if_unlikely(builder, status.is_error):
self.call_conv.return_status_propagate(builder, status)
res = imputils.fix_returning_optional(self, builder, sig, status, res)
return res
def call_internal_no_propagate(self, builder, fndesc, sig, args):
"""Similar to `.call_internal()` but does not handle or propagate
the return status automatically.
"""
# Add call to the generated function
llvm_mod = builder.module
fn = self.declare_function(llvm_mod, fndesc)
status, res = self.call_conv.call_function(builder, fn, sig.return_type,
sig.args, args)
return status, res
def call_unresolved(self, builder, name, sig, args):
"""
Insert a function call to an unresolved symbol with the given *name*.
        Note: this is used for recursive calls.
In the mutual recursion case::
@njit
def foo():
... # calls bar()
@njit
def bar():
... # calls foo()
foo()
        When foo() is called, the compilation of bar() is fully completed
        (codegen'ed and loaded) before foo()'s is. Since MCJIT's eager
        compilation doesn't allow loading modules with declare-only functions
        (which would be needed for the reference to foo() inside bar()),
        call_unresolved injects a global variable that the "linker" can
        update even after the module has been loaded by MCJIT. The linker
        allocates space for the global variable before the bar() module is
        loaded, and when the foo() module is later defined, it updates
        bar()'s reference to foo().
        The legacy lazy JIT and the new ORC JIT would allow a declare-only
        function to be used in a module as long as it is defined by the time
        of its first use.
"""
# Insert an unresolved reference to the function being called.
codegen = self.codegen()
fnty = self.call_conv.get_function_type(sig.return_type, sig.args)
fn = codegen.insert_unresolved_ref(builder, fnty, name)
# Normal call sequence
status, res = self.call_conv.call_function(builder, fn, sig.return_type,
sig.args, args)
with cgutils.if_unlikely(builder, status.is_error):
self.call_conv.return_status_propagate(builder, status)
res = imputils.fix_returning_optional(self, builder, sig, status, res)
return res
def get_executable(self, func, fndesc, env):
raise NotImplementedError
def get_python_api(self, builder):
return PythonAPI(self, builder)
def sentry_record_alignment(self, rectyp, attr):
"""
Assumes offset starts from a properly aligned location
"""
if self.strict_alignment:
offset = rectyp.offset(attr)
elemty = rectyp.typeof(attr)
if isinstance(elemty, types.NestedArray):
# For a NestedArray we need to consider the data type of
# elements of the array for alignment, not the array structure
# itself
elemty = elemty.dtype
align = self.get_abi_alignment(self.get_data_type(elemty))
if offset % align:
msg = "{rec}.{attr} of type {type} is not aligned".format(
rec=rectyp, attr=attr, type=elemty)
raise TypeError(msg)
def get_helper_class(self, typ, kind='value'):
"""
Get a helper class for the given *typ*.
"""
# XXX handle all types: complex, array, etc.
# XXX should it be a method on the model instead? this would allow a default kind...
return cgutils.create_struct_proxy(typ, kind)
def _make_helper(self, builder, typ, value=None, ref=None, kind='value'):
cls = self.get_helper_class(typ, kind)
return cls(self, builder, value=value, ref=ref)
def make_helper(self, builder, typ, value=None, ref=None):
"""
Get a helper object to access the *typ*'s members,
for the given value or reference.
"""
return self._make_helper(builder, typ, value, ref, kind='value')
def make_data_helper(self, builder, typ, ref=None):
"""
As make_helper(), but considers the value as stored in memory,
rather than a live value.
"""
return self._make_helper(builder, typ, ref=ref, kind='data')
def make_array(self, typ):
from numba.np import arrayobj
return arrayobj.make_array(typ)
def populate_array(self, arr, **kwargs):
"""
Populate array structure.
"""
from numba.np import arrayobj
return arrayobj.populate_array(arr, **kwargs)
def make_complex(self, builder, typ, value=None):
"""
        Get a helper object to access the given complex number's members.
"""
assert isinstance(typ, types.Complex), typ
return self.make_helper(builder, typ, value)
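    # A minimal usage sketch (hypothetical value): unpack a complex128 into
    # its members through the helper proxy.
    #     z = context.make_complex(builder, types.complex128, value=val)
    #     real, imag = z.real, z.imag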
def make_tuple(self, builder, typ, values):
"""
Create a tuple of the given *typ* containing the *values*.
"""
tup = self.get_constant_undef(typ)
for i, val in enumerate(values):
tup = builder.insert_value(tup, val, i)
return tup
def make_constant_array(self, builder, typ, ary):
"""
Create an array structure reifying the given constant array.
A low-level contiguous array constant is created in the LLVM IR.
"""
datatype = self.get_data_type(typ.dtype)
        # Don't freeze arrays that are non-contiguous or bigger than 1MB
size_limit = 10**6
if (self.allow_dynamic_globals and
(typ.layout not in 'FC' or ary.nbytes > size_limit)):
# get pointer from the ary
dataptr = ary.ctypes.data
data = self.add_dynamic_addr(builder, dataptr, info=str(type(dataptr)))
rt_addr = self.add_dynamic_addr(builder, id(ary), info=str(type(ary)))
else:
# Handle data: reify the flattened array in "C" or "F" order as a
# global array of bytes.
flat = ary.flatten(order=typ.layout)
# Note: we use `bytearray(flat.data)` instead of `bytearray(flat)` to
# workaround issue #1850 which is due to numpy issue #3147
consts = cgutils.create_constant_array(llvmir.IntType(8), bytearray(flat.data))
data = cgutils.global_constant(builder, ".const.array.data", consts)
# Ensure correct data alignment (issue #1933)
data.align = self.get_abi_alignment(datatype)
# No reference to parent ndarray
rt_addr = None
# Handle shape
llintp = self.get_value_type(types.intp)
shapevals = [self.get_constant(types.intp, s) for s in ary.shape]
cshape = cgutils.create_constant_array(llintp, shapevals)
# Handle strides
stridevals = [self.get_constant(types.intp, s) for s in ary.strides]
cstrides = cgutils.create_constant_array(llintp, stridevals)
# Create array structure
cary = self.make_array(typ)(self, builder)
intp_itemsize = self.get_constant(types.intp, ary.dtype.itemsize)
self.populate_array(cary,
data=builder.bitcast(data, cary.data.type),
shape=cshape,
strides=cstrides,
itemsize=intp_itemsize,
parent=rt_addr,
meminfo=None)
return cary._getvalue()
def add_dynamic_addr(self, builder, intaddr, info):
"""
Returns dynamic address as a void pointer `i8*`.
Internally, a global variable is added to inform the lowerer about
the usage of dynamic addresses. Caching will be disabled.
"""
assert self.allow_dynamic_globals, "dyn globals disabled in this target"
assert isinstance(intaddr, int), 'dyn addr not of int type'
mod = builder.module
llvoidptr = self.get_value_type(types.voidptr)
addr = self.get_constant(types.uintp, intaddr).inttoptr(llvoidptr)
# Use a unique name by embedding the address value
symname = 'numba.dynamic.globals.{:x}'.format(intaddr)
gv = cgutils.add_global_variable(mod, llvoidptr, symname)
# Use linkonce linkage to allow merging with other GV of the same name.
# And, avoid optimization from assuming its value.
gv.linkage = 'linkonce'
gv.initializer = addr
return builder.load(gv)
def get_abi_sizeof(self, ty):
"""
Get the ABI size of LLVM type *ty*.
"""
assert isinstance(ty, llvmir.Type), "Expected LLVM type"
return ty.get_abi_size(self.target_data)
def get_abi_alignment(self, ty):
"""
Get the ABI alignment of LLVM type *ty*.
"""
assert isinstance(ty, llvmir.Type), "Expected LLVM type"
return ty.get_abi_alignment(self.target_data)
def get_preferred_array_alignment(context, ty):
"""
Get preferred array alignment for Numba type *ty*.
"""
# AVX prefers 32-byte alignment
return 32
def post_lowering(self, mod, library):
"""Run target specific post-lowering transformation here.
"""
def create_module(self, name):
"""Create a LLVM module
The default implementation in BaseContext always raises a
``NotImplementedError`` exception. Subclasses should implement
this method.
"""
raise NotImplementedError
@property
def active_code_library(self):
"""Get the active code library
"""
return self._codelib_stack[-1]
@contextmanager
def push_code_library(self, lib):
"""Push the active code library for the context
"""
self._codelib_stack.append(lib)
try:
yield
finally:
self._codelib_stack.pop()
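    # A minimal usage sketch (hypothetical library object):
    #     with context.push_code_library(lib):
    #         ...  # code generated here links against `lib`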
def add_linking_libs(self, libs):
"""Add iterable of linking libraries to the *active_code_library*.
"""
colib = self.active_code_library
for lib in libs:
colib.add_linking_library(lib)
def get_ufunc_info(self, ufunc_key):
"""Get the ufunc implementation for a given ufunc object.
The default implementation in BaseContext always raises a
``NotImplementedError`` exception. Subclasses may raise ``KeyError``
to signal that the given ``ufunc_key`` is not available.
Parameters
----------
ufunc_key : NumPy ufunc
Returns
-------
res : dict[str, callable]
A mapping of a NumPy ufunc type signature to a lower-level
implementation.
"""
raise NotImplementedError(f"{self} does not support ufunc")
class _wrap_impl(object):
"""
A wrapper object to call an implementation function with some predefined
(context, signature) arguments.
    The wrapper also forwards attribute queries, which is important for
    e.g. retrieving the implementation's ``libs`` attribute.
"""
def __init__(self, imp, context, sig):
self._callable = _wrap_missing_loc(imp)
self._imp = self._callable()
self._context = context
self._sig = sig
def __call__(self, builder, args, loc=None):
res = self._imp(self._context, builder, self._sig, args, loc=loc)
self._context.add_linking_libs(getattr(self, 'libs', ()))
return res
def __getattr__(self, item):
return getattr(self._imp, item)
def __repr__(self):
return "<wrapped %s>" % repr(self._callable)
def _has_loc(fn):
"""Does function *fn* take ``loc`` argument?
"""
sig = utils.pysignature(fn)
return 'loc' in sig.parameters
class _wrap_missing_loc(object):
def __init__(self, fn):
self.func = fn # store this to help with debug
def __call__(self):
"""Wrap function for missing ``loc`` keyword argument.
Otherwise, return the original *fn*.
"""
fn = self.func
if not _has_loc(fn):
def wrapper(*args, **kwargs):
kwargs.pop('loc') # drop unused loc
return fn(*args, **kwargs)
            # Copy the following attributes from the wrapped function,
            # similar to functools.wraps, but ignore attributes that are
            # not available (i.e. fix py2.7)
attrs = '__name__', 'libs'
for attr in attrs:
try:
val = getattr(fn, attr)
except AttributeError:
pass
else:
setattr(wrapper, attr, val)
return wrapper
else:
return fn
def __repr__(self):
return "<wrapped %s>" % self.func
@utils.runonce
def _initialize_llvm_lock_event():
"""Initial event triggers for LLVM lock
"""
def enter_fn():
event.start_event("numba:llvm_lock")
def exit_fn():
event.end_event("numba:llvm_lock")
ll.ffi.register_lock_callback(enter_fn, exit_fn)
_initialize_llvm_lock_event()