Commit 16854aed authored by wooway777's avatar wooway777
Browse files

issue/591 - fix operator context mismatch

parent a311e9c8
...@@ -16,7 +16,7 @@ Device getDevice(); ...@@ -16,7 +16,7 @@ Device getDevice();
size_t getDeviceCount(Device::Type type); size_t getDeviceCount(Device::Type type);
infinirtStream_t getStream(); infinirtStream_t getStream();
infiniopHandle_t getInfiniopHandle(); infiniopHandle_t getInfiniopHandle(Device device);
void syncStream(); void syncStream();
void syncDevice(); void syncDevice();
......
...@@ -99,7 +99,13 @@ infinirtStream_t getStream() { ...@@ -99,7 +99,13 @@ infinirtStream_t getStream() {
return ContextImpl::singleton().getCurrentRuntime()->stream(); return ContextImpl::singleton().getCurrentRuntime()->stream();
} }
infiniopHandle_t getInfiniopHandle() { infiniopHandle_t getInfiniopHandle(Device device) {
if (device.getType() == Device::Type::CPU) {
return ContextImpl::singleton().getCpuRuntime()->infiniopHandle();
}
if (device != getDevice()) {
throw std::runtime_error("Requested device doesn't match current runtime.");
}
return ContextImpl::singleton().getCurrentRuntime()->infiniopHandle(); return ContextImpl::singleton().getCurrentRuntime()->infiniopHandle();
} }
......
...@@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b) { ...@@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b) {
if (!desc_opt) { if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateAddDescriptor( INFINICORE_CHECK_ERROR(infiniopCreateAddDescriptor(
context::getInfiniopHandle(), &desc, context::getInfiniopHandle(c->device()), &desc,
c->desc(), a->desc(), b->desc())); c->desc(), a->desc(), b->desc()));
cache.put(seed, desc); cache.put(seed, desc);
} else { } else {
......
...@@ -28,7 +28,7 @@ void calculate(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor ...@@ -28,7 +28,7 @@ void calculate(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor
if (!desc_opt) { if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateAttentionDescriptor( INFINICORE_CHECK_ERROR(infiniopCreateAttentionDescriptor(
context::getInfiniopHandle(), &desc, context::getInfiniopHandle(out->device()), &desc,
out->desc(), q->desc(), k->desc(), v->desc(), out->desc(), q->desc(), k->desc(), v->desc(),
k_cache->desc(), v_cache->desc(), pos)); k_cache->desc(), v_cache->desc(), pos));
cache.put(seed, desc); cache.put(seed, desc);
......
...@@ -28,7 +28,7 @@ void calculate(Tensor output, Tensor input) { ...@@ -28,7 +28,7 @@ void calculate(Tensor output, Tensor input) {
if (!desc_opt) { if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateCausalSoftmaxDescriptor( INFINICORE_CHECK_ERROR(infiniopCreateCausalSoftmaxDescriptor(
context::getInfiniopHandle(), &desc, context::getInfiniopHandle(output->device()), &desc,
output->desc(), input->desc())); output->desc(), input->desc()));
cache.put(seed, desc); cache.put(seed, desc);
} else { } else {
......
...@@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b, float alpha, float beta) { ...@@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
if (!desc_opt) { if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateGemmDescriptor( INFINICORE_CHECK_ERROR(infiniopCreateGemmDescriptor(
context::getInfiniopHandle(), &desc, context::getInfiniopHandle(c->device()), &desc,
c->desc(), a->desc(), b->desc())); c->desc(), a->desc(), b->desc()));
cache.put(seed, desc); cache.put(seed, desc);
} else { } else {
......
...@@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b) { ...@@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b) {
if (!desc_opt) { if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateMulDescriptor( INFINICORE_CHECK_ERROR(infiniopCreateMulDescriptor(
context::getInfiniopHandle(), &desc, context::getInfiniopHandle(c->device()), &desc,
c->desc(), a->desc(), b->desc())); c->desc(), a->desc(), b->desc()));
cache.put(seed, desc); cache.put(seed, desc);
} else { } else {
......
...@@ -27,7 +27,7 @@ void calculate(Tensor y, Tensor x) { ...@@ -27,7 +27,7 @@ void calculate(Tensor y, Tensor x) {
infiniopRearrangeDescriptor_t desc = nullptr; infiniopRearrangeDescriptor_t desc = nullptr;
if (!desc_opt) { if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateRearrangeDescriptor(context::getInfiniopHandle(), &desc, y->desc(), x->desc())); INFINICORE_CHECK_ERROR(infiniopCreateRearrangeDescriptor(context::getInfiniopHandle(y->device()), &desc, y->desc(), x->desc()));
cache.put(seed, desc); cache.put(seed, desc);
} else { } else {
desc = *desc_opt; desc = *desc_opt;
......
...@@ -28,7 +28,7 @@ void calculate(Tensor y, Tensor x, Tensor weight, float epsilon) { ...@@ -28,7 +28,7 @@ void calculate(Tensor y, Tensor x, Tensor weight, float epsilon) {
if (!desc_opt) { if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateRMSNormDescriptor( INFINICORE_CHECK_ERROR(infiniopCreateRMSNormDescriptor(
context::getInfiniopHandle(), &desc, context::getInfiniopHandle(y->device()), &desc,
y->desc(), x->desc(), weight->desc(), epsilon)); y->desc(), x->desc(), weight->desc(), epsilon));
cache.put(seed, desc); cache.put(seed, desc);
} else { } else {
......
...@@ -42,7 +42,7 @@ void calculate(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &s ...@@ -42,7 +42,7 @@ void calculate(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &s
if (!desc_opt) { if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateRoPEDescriptor( INFINICORE_CHECK_ERROR(infiniopCreateRoPEDescriptor(
context::getInfiniopHandle(), &desc, context::getInfiniopHandle(x_out->device()), &desc,
x_out->desc(), x->desc(), x_out->desc(), x->desc(),
pos->desc(), sin_cache->desc(), cos_cache->desc(), pos->desc(), sin_cache->desc(), cos_cache->desc(),
infiniop_algo)); infiniop_algo));
......
...@@ -28,7 +28,7 @@ void calculate(Tensor output, Tensor input) { ...@@ -28,7 +28,7 @@ void calculate(Tensor output, Tensor input) {
if (!desc_opt) { if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateSiluDescriptor( INFINICORE_CHECK_ERROR(infiniopCreateSiluDescriptor(
context::getInfiniopHandle(), &desc, context::getInfiniopHandle(output->device()), &desc,
output->desc(), input->desc())); output->desc(), input->desc()));
cache.put(seed, desc); cache.put(seed, desc);
} else { } else {
......
...@@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b) { ...@@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b) {
if (!desc_opt) { if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreateSwiGLUDescriptor( INFINICORE_CHECK_ERROR(infiniopCreateSwiGLUDescriptor(
context::getInfiniopHandle(), &desc, context::getInfiniopHandle(c->device()), &desc,
c->desc(), a->desc(), b->desc())); c->desc(), a->desc(), b->desc()));
cache.put(seed, desc); cache.put(seed, desc);
} else { } else {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment