Commit 16854aed authored by wooway777's avatar wooway777
Browse files

issue/591 - fix operator context mismatch

parent a311e9c8
......@@ -16,7 +16,7 @@ Device getDevice();
size_t getDeviceCount(Device::Type type);
infinirtStream_t getStream();
infiniopHandle_t getInfiniopHandle();
infiniopHandle_t getInfiniopHandle(Device device);
void syncStream();
void syncDevice();
......
......@@ -99,7 +99,13 @@ infinirtStream_t getStream() {
return ContextImpl::singleton().getCurrentRuntime()->stream();
}
infiniopHandle_t getInfiniopHandle() {
infiniopHandle_t getInfiniopHandle(Device device) {
if (device.getType() == Device::Type::CPU) {
return ContextImpl::singleton().getCpuRuntime()->infiniopHandle();
}
if (device != getDevice()) {
throw std::runtime_error("Requested device doesn't match current runtime.");
}
return ContextImpl::singleton().getCurrentRuntime()->infiniopHandle();
}
......
......@@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b) {
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateAddDescriptor(
context::getInfiniopHandle(), &desc,
context::getInfiniopHandle(c->device()), &desc,
c->desc(), a->desc(), b->desc()));
cache.put(seed, desc);
} else {
......
......@@ -28,7 +28,7 @@ void calculate(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateAttentionDescriptor(
context::getInfiniopHandle(), &desc,
context::getInfiniopHandle(out->device()), &desc,
out->desc(), q->desc(), k->desc(), v->desc(),
k_cache->desc(), v_cache->desc(), pos));
cache.put(seed, desc);
......
......@@ -28,7 +28,7 @@ void calculate(Tensor output, Tensor input) {
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateCausalSoftmaxDescriptor(
context::getInfiniopHandle(), &desc,
context::getInfiniopHandle(output->device()), &desc,
output->desc(), input->desc()));
cache.put(seed, desc);
} else {
......
......@@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateGemmDescriptor(
context::getInfiniopHandle(), &desc,
context::getInfiniopHandle(c->device()), &desc,
c->desc(), a->desc(), b->desc()));
cache.put(seed, desc);
} else {
......
......@@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b) {
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateMulDescriptor(
context::getInfiniopHandle(), &desc,
context::getInfiniopHandle(c->device()), &desc,
c->desc(), a->desc(), b->desc()));
cache.put(seed, desc);
} else {
......
......@@ -27,7 +27,7 @@ void calculate(Tensor y, Tensor x) {
infiniopRearrangeDescriptor_t desc = nullptr;
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateRearrangeDescriptor(context::getInfiniopHandle(), &desc, y->desc(), x->desc()));
INFINICORE_CHECK_ERROR(infiniopCreateRearrangeDescriptor(context::getInfiniopHandle(y->device()), &desc, y->desc(), x->desc()));
cache.put(seed, desc);
} else {
desc = *desc_opt;
......
......@@ -28,7 +28,7 @@ void calculate(Tensor y, Tensor x, Tensor weight, float epsilon) {
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateRMSNormDescriptor(
context::getInfiniopHandle(), &desc,
context::getInfiniopHandle(y->device()), &desc,
y->desc(), x->desc(), weight->desc(), epsilon));
cache.put(seed, desc);
} else {
......
......@@ -42,7 +42,7 @@ void calculate(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &s
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateRoPEDescriptor(
context::getInfiniopHandle(), &desc,
context::getInfiniopHandle(x_out->device()), &desc,
x_out->desc(), x->desc(),
pos->desc(), sin_cache->desc(), cos_cache->desc(),
infiniop_algo));
......
......@@ -28,7 +28,7 @@ void calculate(Tensor output, Tensor input) {
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateSiluDescriptor(
context::getInfiniopHandle(), &desc,
context::getInfiniopHandle(output->device()), &desc,
output->desc(), input->desc()));
cache.put(seed, desc);
} else {
......
......@@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b) {
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateSwiGLUDescriptor(
context::getInfiniopHandle(), &desc,
context::getInfiniopHandle(c->device()), &desc,
c->desc(), a->desc(), b->desc()));
cache.put(seed, desc);
} else {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment