Unverified Commit e9440acb authored by Jinjing Zhou's avatar Jinjing Zhou Committed by GitHub
Browse files

[TF] TF backend fix and new logic to choose backend (#1393)



* TF backend fix and new logic to choose backend

* fix

* fix

* fix

* fix

* fix backend

* fix

* dlpack alignment

* add flag

* flag

* lint

* lint

* remove unused

* several fixes
Co-authored-by: default avatarMinjie Wang <wmjlyjemaine@gmail.com>
parent 4b4186f8
...@@ -406,3 +406,4 @@ DGL_REGISTER_GLOBAL("_GetDeviceAttr") ...@@ -406,3 +406,4 @@ DGL_REGISTER_GLOBAL("_GetDeviceAttr")
DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret); DeviceAPIManager::Get(ctx)->GetAttr(ctx, kind, ret);
} }
}); });
...@@ -350,10 +350,24 @@ int DGLArrayFromDLPack(DLManagedTensor* from, ...@@ -350,10 +350,24 @@ int DGLArrayFromDLPack(DLManagedTensor* from,
API_END(); API_END();
} }
int DGLArrayToDLPack(DGLArrayHandle from, inline bool is_aligned(const void* ptr, std::uintptr_t alignment) noexcept {
DLManagedTensor** out) { auto iptr = reinterpret_cast<std::uintptr_t>(ptr);
return !(iptr % alignment);
}
int DGLArrayToDLPack(DGLArrayHandle from, DLManagedTensor** out,
int alignment) {
API_BEGIN(); API_BEGIN();
*out = NDArray::Internal::ToDLPack(reinterpret_cast<NDArray::Container*>(from)); auto* nd_container = reinterpret_cast<NDArray::Container*>(from);
DLTensor* nd = &(nd_container->dl_tensor);
if (alignment != 0 && !is_aligned(nd->data, alignment)) {
std::vector<int64_t> shape_vec(nd->shape, nd->shape + nd->ndim);
NDArray copy_ndarray = NDArray::Empty(shape_vec, nd->dtype, nd->ctx);
copy_ndarray.CopyFrom(nd);
*out = copy_ndarray.ToDLPack();
} else {
*out = NDArray::Internal::ToDLPack(nd_container);
}
API_END(); API_END();
} }
......
...@@ -6,8 +6,7 @@ import importlib ...@@ -6,8 +6,7 @@ import importlib
import sys import sys
import numpy as np import numpy as np
mod_name = os.environ.get('DGLBACKEND', 'pytorch').lower() mod = importlib.import_module('.%s' % backend_name, __name__)
mod = importlib.import_module('.%s' % mod_name, __name__)
thismod = sys.modules[__name__] thismod = sys.modules[__name__]
for api in backend_unittest.__dict__.keys(): for api in backend_unittest.__dict__.keys():
...@@ -17,7 +16,6 @@ for api in backend_unittest.__dict__.keys(): ...@@ -17,7 +16,6 @@ for api in backend_unittest.__dict__.keys():
# Tensor APIs used in unit tests MUST be supported across all backends # Tensor APIs used in unit tests MUST be supported across all backends
globals()[api] = mod.__dict__[api] globals()[api] = mod.__dict__[api]
# Tensor creation with default dtype and context # Tensor creation with default dtype and context
_zeros = zeros _zeros = zeros
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment