Commit 515e1eca authored by PanZezhong, committed by wooway777
Browse files

issue/1033 support stream guard

parent 8ab073b4
#pragma once
#include "../context/context.hpp"
#include "../tensor.hpp"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
namespace infinicore::adaptor {
inline at::ScalarType to_at_dtype(DataType dtype) {
switch (dtype) {
......@@ -32,4 +36,6 @@ inline at::Device to_at_device(const Device &device) {
}
// Builds an at::Tensor from an infinicore tensor. The definition lives in the
// adaptor source file.
at::Tensor to_aten_tensor(const infinicore::Tensor &t);
// Returns the stream reported by infinicore::context::getStream() as a
// c10::cuda::CUDAStream, e.g. for use with c10::cuda::CUDAStreamGuard.
c10::cuda::CUDAStream get_cuda_stream();
} // namespace infinicore::adaptor
\ No newline at end of file
......@@ -2,7 +2,6 @@
namespace infinicore::adaptor {
at::Tensor to_aten_tensor(const infinicore::Tensor &t) {
void *data_ptr = (void *)(t->data());
......@@ -31,4 +30,9 @@ at::Tensor to_aten_tensor(const infinicore::Tensor &t) {
deleter_,
options);
}
// Wraps the stream reported by infinicore::context::getStream() as a
// c10::cuda::CUDAStream tied to the index of infinicore::context::getDevice(),
// using c10::cuda::getStreamFromExternal.
c10::cuda::CUDAStream get_cuda_stream() {
    // Convert the infinicore stream handle to the CUDA runtime handle type.
    const auto native_stream = cudaStream_t(infinicore::context::getStream());
    const auto device_index = infinicore::context::getDevice().getIndex();
    return c10::cuda::getStreamFromExternal(native_stream, device_index);
}
} // namespace infinicore::adaptor
\ No newline at end of file
......@@ -38,6 +38,7 @@ void *plan(Tensor out,
}
void run(void *planned_meta) {
c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
auto q = infinicore::adaptor::to_aten_tensor(p->q);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment