Unverified Commit 9e30b806 authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Merge pull request #799 from InfiniTensor/issue/798

issue/798 - fix operator device handling
parents fb5e36d2 3720127c
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
namespace infinicore { namespace infinicore {
namespace context { namespace context {
void setDevice(Device device, bool force_cpu = false); void setDevice(Device device);
Device getDevice(); Device getDevice();
size_t getDeviceCount(Device::Type type); size_t getDeviceCount(Device::Type type);
......
...@@ -36,6 +36,10 @@ public: ...@@ -36,6 +36,10 @@ public:
return cache_vector[device_index]; return cache_vector[device_index];
} }
BaseCache &getCache(Device device) {
return getCache(device.getType(), device.getIndex());
}
void setCapacity(size_t capacity) { void setCapacity(size_t capacity) {
capacity_ = capacity; capacity_ = capacity;
for (auto &vec : caches_) { for (auto &vec : caches_) {
......
...@@ -23,13 +23,13 @@ def get_device_count(device_type): ...@@ -23,13 +23,13 @@ def get_device_count(device_type):
return _infinicore.get_device_count(infinicore.device(device_type)._underlying.type) return _infinicore.get_device_count(infinicore.device(device_type)._underlying.type)
def set_device(device, force_cpu=False): def set_device(device):
"""Set the current active device. """Set the current active device.
Args: Args:
device: The device to set as active device: The device to set as active
""" """
_infinicore.set_device(device._underlying, force_cpu) _infinicore.set_device(device._underlying)
def sync_stream(): def sync_stream():
......
...@@ -33,15 +33,11 @@ Runtime *ContextImpl::getCpuRuntime() { ...@@ -33,15 +33,11 @@ Runtime *ContextImpl::getCpuRuntime() {
return runtime_table_[int(Device::Type::CPU)][0].get(); return runtime_table_[int(Device::Type::CPU)][0].get();
} }
void ContextImpl::setDevice(Device device, bool force_cpu) { void ContextImpl::setDevice(Device device) {
if (device == getCurrentRuntime()->device()) { if (device == getCurrentRuntime()->device()) {
// Do nothing if the device is already set. // Do nothing if the device is already set.
return; return;
} }
if (device == Device(Device::Type::CPU, 0) && !force_cpu) {
// if not forced, no need to switch to CPU device runtime
return;
}
if (runtime_table_[int(device.getType())][device.getIndex()] == nullptr) { if (runtime_table_[int(device.getType())][device.getIndex()] == nullptr) {
// Lazy initialization of runtime if never set before. // Lazy initialization of runtime if never set before.
...@@ -87,8 +83,8 @@ ContextImpl::ContextImpl() { ...@@ -87,8 +83,8 @@ ContextImpl::ContextImpl() {
namespace context { namespace context {
void setDevice(Device device, bool force_cpu) { void setDevice(Device device) {
ContextImpl::singleton().setDevice(device, force_cpu); ContextImpl::singleton().setDevice(device);
} }
Device getDevice() { Device getDevice() {
......
...@@ -21,7 +21,7 @@ public: ...@@ -21,7 +21,7 @@ public:
Runtime *getCpuRuntime(); Runtime *getCpuRuntime();
void setDevice(Device, bool force_cpu = false); void setDevice(Device);
size_t getDeviceCount(Device::Type type); size_t getDeviceCount(Device::Type type);
......
...@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopAddDescriptor_t> caches( ...@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopAddDescriptor_t> caches(
void calculate(Tensor c, Tensor a, Tensor b) { void calculate(Tensor c, Tensor a, Tensor b) {
size_t seed = hash_combine(c, b, a); size_t seed = hash_combine(c, b, a);
auto device_type = context::getDevice().getType(); auto device = context::getDevice();
auto device_index = context::getDevice().getIndex(); auto &cache = caches.getCache(device);
auto &cache = caches.getCache(device_type, device_index);
auto desc_opt = cache.get(seed); auto desc_opt = cache.get(seed);
infiniopAddDescriptor_t desc = nullptr; infiniopAddDescriptor_t desc = nullptr;
if (!desc_opt) { if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateAddDescriptor( INFINICORE_CHECK_ERROR(infiniopCreateAddDescriptor(
context::getInfiniopHandle(c->device()), &desc, context::getInfiniopHandle(device), &desc,
c->desc(), a->desc(), b->desc())); c->desc(), a->desc(), b->desc()));
cache.put(seed, desc); cache.put(seed, desc);
} else { } else {
......
...@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopAttentionDescriptor_t> caches( ...@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopAttentionDescriptor_t> caches(
void calculate(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos) { void calculate(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos) {
size_t seed = hash_combine(out, q, k, v, k_cache, v_cache, pos); size_t seed = hash_combine(out, q, k, v, k_cache, v_cache, pos);
auto device_type = context::getDevice().getType(); auto device = context::getDevice();
auto device_index = context::getDevice().getIndex(); auto &cache = caches.getCache(device);
auto &cache = caches.getCache(device_type, device_index);
auto desc_opt = cache.get(seed); auto desc_opt = cache.get(seed);
infiniopAttentionDescriptor_t desc = nullptr; infiniopAttentionDescriptor_t desc = nullptr;
if (!desc_opt) { if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateAttentionDescriptor( INFINICORE_CHECK_ERROR(infiniopCreateAttentionDescriptor(
context::getInfiniopHandle(out->device()), &desc, context::getInfiniopHandle(device), &desc,
out->desc(), q->desc(), k->desc(), v->desc(), out->desc(), q->desc(), k->desc(), v->desc(),
k_cache->desc(), v_cache->desc(), pos)); k_cache->desc(), v_cache->desc(), pos));
cache.put(seed, desc); cache.put(seed, desc);
......
...@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopCausalSoftmaxDescriptor_t> caches( ...@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopCausalSoftmaxDescriptor_t> caches(
void calculate(Tensor output, Tensor input) { void calculate(Tensor output, Tensor input) {
size_t seed = hash_combine(output, input); size_t seed = hash_combine(output, input);
auto device_type = context::getDevice().getType(); auto device = context::getDevice();
auto device_index = context::getDevice().getIndex(); auto &cache = caches.getCache(device);
auto &cache = caches.getCache(device_type, device_index);
auto desc_opt = cache.get(seed); auto desc_opt = cache.get(seed);
infiniopCausalSoftmaxDescriptor_t desc = nullptr; infiniopCausalSoftmaxDescriptor_t desc = nullptr;
if (!desc_opt) { if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateCausalSoftmaxDescriptor( INFINICORE_CHECK_ERROR(infiniopCreateCausalSoftmaxDescriptor(
context::getInfiniopHandle(output->device()), &desc, context::getInfiniopHandle(device), &desc,
output->desc(), input->desc())); output->desc(), input->desc()));
cache.put(seed, desc); cache.put(seed, desc);
} else { } else {
......
...@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopGemmDescriptor_t> caches( ...@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopGemmDescriptor_t> caches(
void calculate(Tensor c, Tensor a, Tensor b, float alpha, float beta) { void calculate(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
size_t seed = hash_combine(c, b, a, alpha, beta); size_t seed = hash_combine(c, b, a, alpha, beta);
auto device_type = context::getDevice().getType(); auto device = context::getDevice();
auto device_index = context::getDevice().getIndex(); auto &cache = caches.getCache(device);
auto &cache = caches.getCache(device_type, device_index);
auto desc_opt = cache.get(seed); auto desc_opt = cache.get(seed);
infiniopGemmDescriptor_t desc = nullptr; infiniopGemmDescriptor_t desc = nullptr;
if (!desc_opt) { if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateGemmDescriptor( INFINICORE_CHECK_ERROR(infiniopCreateGemmDescriptor(
context::getInfiniopHandle(c->device()), &desc, context::getInfiniopHandle(device), &desc,
c->desc(), a->desc(), b->desc())); c->desc(), a->desc(), b->desc()));
cache.put(seed, desc); cache.put(seed, desc);
} else { } else {
......
...@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopMulDescriptor_t> caches( ...@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopMulDescriptor_t> caches(
void calculate(Tensor c, Tensor a, Tensor b) { void calculate(Tensor c, Tensor a, Tensor b) {
size_t seed = hash_combine(c, b, a); size_t seed = hash_combine(c, b, a);
auto device_type = context::getDevice().getType(); auto device = context::getDevice();
auto device_index = context::getDevice().getIndex(); auto &cache = caches.getCache(device);
auto &cache = caches.getCache(device_type, device_index);
auto desc_opt = cache.get(seed); auto desc_opt = cache.get(seed);
infiniopMulDescriptor_t desc = nullptr; infiniopMulDescriptor_t desc = nullptr;
if (!desc_opt) { if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateMulDescriptor( INFINICORE_CHECK_ERROR(infiniopCreateMulDescriptor(
context::getInfiniopHandle(c->device()), &desc, context::getInfiniopHandle(device), &desc,
c->desc(), a->desc(), b->desc())); c->desc(), a->desc(), b->desc()));
cache.put(seed, desc); cache.put(seed, desc);
} else { } else {
......
...@@ -25,17 +25,15 @@ static void calculate( ...@@ -25,17 +25,15 @@ static void calculate(
// cache per (result desc + logits desc) on device // cache per (result desc + logits desc) on device
size_t seed = hash_combine(indices, logits); size_t seed = hash_combine(indices, logits);
auto device_type = context::getDevice().getType(); auto device = context::getDevice();
auto device_index = context::getDevice().getIndex(); auto &cache = caches.getCache(device);
auto &cache = caches.getCache(device_type, device_index);
auto desc_opt = cache.get(seed); auto desc_opt = cache.get(seed);
infiniopRandomSampleDescriptor_t desc = nullptr; infiniopRandomSampleDescriptor_t desc = nullptr;
if (!desc_opt) { if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateRandomSampleDescriptor( INFINICORE_CHECK_ERROR(infiniopCreateRandomSampleDescriptor(
context::getInfiniopHandle(indices->device()), &desc, context::getInfiniopHandle(device), &desc,
indices->desc(), logits->desc())); indices->desc(), logits->desc()));
cache.put(seed, desc); cache.put(seed, desc);
} else { } else {
......
...@@ -18,16 +18,14 @@ thread_local common::OpCache<size_t, infiniopRearrangeDescriptor_t> caches( ...@@ -18,16 +18,14 @@ thread_local common::OpCache<size_t, infiniopRearrangeDescriptor_t> caches(
void calculate(Tensor y, Tensor x) { void calculate(Tensor y, Tensor x) {
size_t seed = hash_combine(y, x); size_t seed = hash_combine(y, x);
auto device_type = y->device().getType(); auto device = context::getDevice();
auto device_index = y->device().getIndex(); auto &cache = caches.getCache(device);
auto &cache = caches.getCache(device_type, device_index);
auto desc_opt = cache.get(seed); auto desc_opt = cache.get(seed);
infiniopRearrangeDescriptor_t desc = nullptr; infiniopRearrangeDescriptor_t desc = nullptr;
if (!desc_opt) { if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateRearrangeDescriptor(context::getInfiniopHandle(y->device()), &desc, y->desc(), x->desc())); INFINICORE_CHECK_ERROR(infiniopCreateRearrangeDescriptor(context::getInfiniopHandle(device), &desc, y->desc(), x->desc()));
cache.put(seed, desc); cache.put(seed, desc);
} else { } else {
desc = *desc_opt; desc = *desc_opt;
......
...@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopRMSNormDescriptor_t> caches( ...@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopRMSNormDescriptor_t> caches(
void calculate(Tensor y, Tensor x, Tensor weight, float epsilon) { void calculate(Tensor y, Tensor x, Tensor weight, float epsilon) {
size_t seed = hash_combine(y, x, weight, epsilon); size_t seed = hash_combine(y, x, weight, epsilon);
auto device_type = context::getDevice().getType(); auto device = context::getDevice();
auto device_index = context::getDevice().getIndex(); auto &cache = caches.getCache(device);
auto &cache = caches.getCache(device_type, device_index);
auto desc_opt = cache.get(seed); auto desc_opt = cache.get(seed);
infiniopRMSNormDescriptor_t desc = nullptr; infiniopRMSNormDescriptor_t desc = nullptr;
if (!desc_opt) { if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateRMSNormDescriptor( INFINICORE_CHECK_ERROR(infiniopCreateRMSNormDescriptor(
context::getInfiniopHandle(y->device()), &desc, context::getInfiniopHandle(device), &desc,
y->desc(), x->desc(), weight->desc(), epsilon)); y->desc(), x->desc(), weight->desc(), epsilon));
cache.put(seed, desc); cache.put(seed, desc);
} else { } else {
......
...@@ -33,16 +33,15 @@ void calculate(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &s ...@@ -33,16 +33,15 @@ void calculate(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &s
size_t key = hash_combine(x_out, x, pos, sin_cache, cos_cache); size_t key = hash_combine(x_out, x, pos, sin_cache, cos_cache);
hash_combine(key, std::hash<int>()(static_cast<int>(infiniop_algo))); hash_combine(key, std::hash<int>()(static_cast<int>(infiniop_algo)));
auto device_type = context::getDevice().getType(); auto device = context::getDevice();
auto device_index = context::getDevice().getIndex(); auto &cache = caches.getCache(device);
auto &cache = caches.getCache(device_type, device_index);
auto desc_opt = cache.get(key); auto desc_opt = cache.get(key);
infiniopRoPEDescriptor_t desc = nullptr; infiniopRoPEDescriptor_t desc = nullptr;
if (!desc_opt) { if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateRoPEDescriptor( INFINICORE_CHECK_ERROR(infiniopCreateRoPEDescriptor(
context::getInfiniopHandle(x_out->device()), &desc, context::getInfiniopHandle(device), &desc,
x_out->desc(), x->desc(), x_out->desc(), x->desc(),
pos->desc(), sin_cache->desc(), cos_cache->desc(), pos->desc(), sin_cache->desc(), cos_cache->desc(),
infiniop_algo)); infiniop_algo));
......
...@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopSiluDescriptor_t> caches( ...@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopSiluDescriptor_t> caches(
void calculate(Tensor output, Tensor input) { void calculate(Tensor output, Tensor input) {
size_t seed = hash_combine(output, input); size_t seed = hash_combine(output, input);
auto device_type = context::getDevice().getType(); auto device = context::getDevice();
auto device_index = context::getDevice().getIndex(); auto &cache = caches.getCache(device);
auto &cache = caches.getCache(device_type, device_index);
auto desc_opt = cache.get(seed); auto desc_opt = cache.get(seed);
infiniopSiluDescriptor_t desc = nullptr; infiniopSiluDescriptor_t desc = nullptr;
if (!desc_opt) { if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateSiluDescriptor( INFINICORE_CHECK_ERROR(infiniopCreateSiluDescriptor(
context::getInfiniopHandle(output->device()), &desc, context::getInfiniopHandle(device), &desc,
output->desc(), input->desc())); output->desc(), input->desc()));
cache.put(seed, desc); cache.put(seed, desc);
} else { } else {
......
...@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopSwiGLUDescriptor_t> caches( ...@@ -18,17 +18,15 @@ thread_local common::OpCache<size_t, infiniopSwiGLUDescriptor_t> caches(
void calculate(Tensor c, Tensor a, Tensor b) { void calculate(Tensor c, Tensor a, Tensor b) {
size_t seed = hash_combine(c, b, a); size_t seed = hash_combine(c, b, a);
auto device_type = context::getDevice().getType(); auto device = context::getDevice();
auto device_index = context::getDevice().getIndex(); auto &cache = caches.getCache(device);
auto &cache = caches.getCache(device_type, device_index);
auto desc_opt = cache.get(seed); auto desc_opt = cache.get(seed);
infiniopSwiGLUDescriptor_t desc = nullptr; infiniopSwiGLUDescriptor_t desc = nullptr;
if (!desc_opt) { if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateSwiGLUDescriptor( INFINICORE_CHECK_ERROR(infiniopCreateSwiGLUDescriptor(
context::getInfiniopHandle(c->device()), &desc, context::getInfiniopHandle(device), &desc,
c->desc(), a->desc(), b->desc())); c->desc(), a->desc(), b->desc()));
cache.put(seed, desc); cache.put(seed, desc);
} else { } else {
......
...@@ -16,8 +16,7 @@ inline void bind(py::module &m) { ...@@ -16,8 +16,7 @@ inline void bind(py::module &m) {
py::arg("device_type")); py::arg("device_type"));
m.def("set_device", &setDevice, m.def("set_device", &setDevice,
"Set the current active device", "Set the current active device",
py::arg("device"), py::arg("device"));
py::arg("force_cpu"));
// Stream and handle management // Stream and handle management
m.def("get_stream", &getStream, "Get the current stream"); m.def("get_stream", &getStream, "Get the current stream");
......
...@@ -31,6 +31,7 @@ void TensorImpl::copy_from(Tensor src) { ...@@ -31,6 +31,7 @@ void TensorImpl::copy_from(Tensor src) {
// Use nbytes() to get the actual tensor size, not the full memory size // Use nbytes() to get the actual tensor size, not the full memory size
size_t copy_size = std::min(this->nbytes(), src->nbytes()); size_t copy_size = std::min(this->nbytes(), src->nbytes());
if (this->device().getType() == Device::Type::CPU) { if (this->device().getType() == Device::Type::CPU) {
context::setDevice(src->device());
if (this->is_contiguous()) { if (this->is_contiguous()) {
context::memcpyD2H(this->data(), src->data(), copy_size); context::memcpyD2H(this->data(), src->data(), copy_size);
} else { } else {
...@@ -39,7 +40,7 @@ void TensorImpl::copy_from(Tensor src) { ...@@ -39,7 +40,7 @@ void TensorImpl::copy_from(Tensor src) {
op::rearrange_(Tensor(const_cast<TensorImpl *>(this)->shared_from_this()), local_src); op::rearrange_(Tensor(const_cast<TensorImpl *>(this)->shared_from_this()), local_src);
} }
} else if (src->device().getType() == Device::Type::CPU) { } else if (src->device().getType() == Device::Type::CPU) {
context::setDevice(this->device());
if (this->is_contiguous()) { if (this->is_contiguous()) {
context::memcpyH2D(this->data(), src->data(), copy_size); context::memcpyH2D(this->data(), src->data(), copy_size);
} else { } else {
......
...@@ -10,6 +10,8 @@ ...@@ -10,6 +10,8 @@
thread_local infiniDevice_t CURRENT_DEVICE_TYPE = INFINI_DEVICE_CPU; thread_local infiniDevice_t CURRENT_DEVICE_TYPE = INFINI_DEVICE_CPU;
thread_local int CURRENT_DEVICE_ID = 0; thread_local int CURRENT_DEVICE_ID = 0;
thread_local infiniDevice_t PREVIOUS_NON_CPU_DEVICE_TYPE = INFINI_DEVICE_TYPE_COUNT;
thread_local int PREVIOUS_NON_CPU_DEVICE_ID = 0;
__C infiniStatus_t infinirtInit() { __C infiniStatus_t infinirtInit() {
#ifdef ENABLE_ASCEND_API #ifdef ENABLE_ASCEND_API
...@@ -96,6 +98,16 @@ __C infiniStatus_t infinirtGetDeviceCount(infiniDevice_t device, int *count) { ...@@ -96,6 +98,16 @@ __C infiniStatus_t infinirtGetDeviceCount(infiniDevice_t device, int *count) {
} }
__C infiniStatus_t infinirtSetDevice(infiniDevice_t device, int device_id) { __C infiniStatus_t infinirtSetDevice(infiniDevice_t device, int device_id) {
bool skip_set = CURRENT_DEVICE_TYPE == INFINI_DEVICE_CPU && device == PREVIOUS_NON_CPU_DEVICE_TYPE && device_id == PREVIOUS_NON_CPU_DEVICE_ID;
if (CURRENT_DEVICE_TYPE != INFINI_DEVICE_CPU) {
PREVIOUS_NON_CPU_DEVICE_TYPE = CURRENT_DEVICE_TYPE;
PREVIOUS_NON_CPU_DEVICE_ID = CURRENT_DEVICE_ID;
}
if (skip_set) {
CURRENT_DEVICE_TYPE = device;
CURRENT_DEVICE_ID = device_id;
return INFINI_STATUS_SUCCESS;
}
INFINIRT_CALL_DEVICE_API_AND(device, setDevice, (device_id), INFINIRT_CALL_DEVICE_API_AND(device, setDevice, (device_id),
{ CURRENT_DEVICE_TYPE = device; { CURRENT_DEVICE_TYPE = device;
CURRENT_DEVICE_ID = device_id; }); CURRENT_DEVICE_ID = device_id; });
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment