Commit 515e1eca authored by PanZezhong's avatar PanZezhong Committed by wooway777
Browse files

issue/1033 support stream guard

parent 8ab073b4
#pragma once #pragma once
#include "../context/context.hpp"
#include "../tensor.hpp" #include "../tensor.hpp"
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
namespace infinicore::adaptor { namespace infinicore::adaptor {
inline at::ScalarType to_at_dtype(DataType dtype) { inline at::ScalarType to_at_dtype(DataType dtype) {
switch (dtype) { switch (dtype) {
...@@ -32,4 +36,6 @@ inline at::Device to_at_device(const Device &device) { ...@@ -32,4 +36,6 @@ inline at::Device to_at_device(const Device &device) {
} }
at::Tensor to_aten_tensor(const infinicore::Tensor &t); at::Tensor to_aten_tensor(const infinicore::Tensor &t);
c10::cuda::CUDAStream get_cuda_stream();
} // namespace infinicore::adaptor } // namespace infinicore::adaptor
\ No newline at end of file
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
namespace infinicore::adaptor { namespace infinicore::adaptor {
at::Tensor to_aten_tensor(const infinicore::Tensor &t) { at::Tensor to_aten_tensor(const infinicore::Tensor &t) {
void *data_ptr = (void *)(t->data()); void *data_ptr = (void *)(t->data());
...@@ -31,4 +30,9 @@ at::Tensor to_aten_tensor(const infinicore::Tensor &t) { ...@@ -31,4 +30,9 @@ at::Tensor to_aten_tensor(const infinicore::Tensor &t) {
deleter_, deleter_,
options); options);
} }
c10::cuda::CUDAStream get_cuda_stream() {
return c10::cuda::getStreamFromExternal(
cudaStream_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex());
}
} // namespace infinicore::adaptor } // namespace infinicore::adaptor
\ No newline at end of file
...@@ -38,6 +38,7 @@ void *plan(Tensor out, ...@@ -38,6 +38,7 @@ void *plan(Tensor out,
} }
void run(void *planned_meta) { void run(void *planned_meta) {
c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
auto *p = reinterpret_cast<PlannedMeta *>(planned_meta); auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
auto q = infinicore::adaptor::to_aten_tensor(p->q); auto q = infinicore::adaptor::to_aten_tensor(p->q);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment