Commit 4a4b6e8e authored by PanZezhong's avatar PanZezhong
Browse files

issue/811 remove shortcut for cpu runtime

parent 180674dc
...@@ -29,18 +29,16 @@ Runtime *ContextImpl::getCurrentRuntime() { ...@@ -29,18 +29,16 @@ Runtime *ContextImpl::getCurrentRuntime() {
return current_runtime_; return current_runtime_;
} }
Runtime *ContextImpl::getCpuRuntime() {
return runtime_table_[int(Device::Type::CPU)][0].get();
}
void ContextImpl::setDevice(Device device) { void ContextImpl::setDevice(Device device) {
if (device == getCurrentRuntime()->device()) { if (device == getCurrentRuntime()->device()) {
// Do nothing if the device is already set. // Do nothing if the device is already set.
return; return;
} }
if (getCurrentRuntime()->isGraphRecording()) { thread_local bool warn_switch_runtime = false;
if (getCurrentRuntime()->isGraphRecording() && !warn_switch_runtime) {
spdlog::warn("Switching device runtime during graph recording may break the graph!"); spdlog::warn("Switching device runtime during graph recording may break the graph!");
warn_switch_runtime = true;
} }
if (runtime_table_[int(device.getType())][device.getIndex()] == nullptr) { if (runtime_table_[int(device.getType())][device.getIndex()] == nullptr) {
...@@ -104,11 +102,8 @@ infinirtStream_t getStream() { ...@@ -104,11 +102,8 @@ infinirtStream_t getStream() {
} }
/// Returns the infiniop handle for `device`.
/// If `device` is not the currently active device, the context is switched
/// to it first; the handle of the then-current runtime is returned.
/// NOTE(review): the device switch persists after this call — confirm
/// callers expect that side effect.
infiniopHandle_t getInfiniopHandle(Device device) {
    const bool device_is_current = (device == getDevice());
    if (!device_is_current) {
        setDevice(device);
    }
    return ContextImpl::singleton().getCurrentRuntime()->infiniopHandle();
}
...@@ -127,7 +122,7 @@ std::shared_ptr<Memory> allocateMemory(size_t size) { ...@@ -127,7 +122,7 @@ std::shared_ptr<Memory> allocateMemory(size_t size) {
std::shared_ptr<Memory> allocateHostMemory(size_t size) { std::shared_ptr<Memory> allocateHostMemory(size_t size) {
setDevice(Device::cpu()); setDevice(Device::cpu());
return ContextImpl::singleton().getCpuRuntime()->allocateMemory(size); return allocateMemory(size);
} }
std::shared_ptr<Memory> allocatePinnedHostMemory(size_t size) { std::shared_ptr<Memory> allocatePinnedHostMemory(size_t size) {
...@@ -147,7 +142,8 @@ void memcpyD2D(void *dst, const void *src, size_t size, bool async) { ...@@ -147,7 +142,8 @@ void memcpyD2D(void *dst, const void *src, size_t size, bool async) {
} }
void memcpyH2H(void *dst, const void *src, size_t size) { void memcpyH2H(void *dst, const void *src, size_t size) {
return ContextImpl::singleton().getCpuRuntime()->memcpyD2D(dst, src, size); setDevice(Device::cpu());
return ContextImpl::singleton().getCurrentRuntime()->memcpyD2D(dst, src, size);
} }
// Timing API implementations // Timing API implementations
......
...@@ -19,8 +19,6 @@ protected: ...@@ -19,8 +19,6 @@ protected:
public: public:
Runtime *getCurrentRuntime(); Runtime *getCurrentRuntime();
Runtime *getCpuRuntime();
void setDevice(Device); void setDevice(Device);
size_t getDeviceCount(Device::Type type); size_t getDeviceCount(Device::Type type);
......
...@@ -19,7 +19,8 @@ Tensor TensorImpl::to(Device device) const { ...@@ -19,7 +19,8 @@ Tensor TensorImpl::to(Device device) const {
void TensorImpl::copy_from(Tensor src) { void TensorImpl::copy_from(Tensor src) {
if (src->shape() != this->shape()) { if (src->shape() != this->shape()) {
throw std::runtime_error("Cannot copy from tensor with different shape"); throw std::runtime_error(
"Cannot copy from tensor with different shape. Src: " + src->info() + " Dst: " + this->info());
} }
if (this->device() == src->device()) { if (this->device() == src->device()) {
op::rearrange_(Tensor(const_cast<TensorImpl *>(this)->shared_from_this()), src); op::rearrange_(Tensor(const_cast<TensorImpl *>(this)->shared_from_this()), src);
...@@ -31,11 +32,12 @@ void TensorImpl::copy_from(Tensor src) { ...@@ -31,11 +32,12 @@ void TensorImpl::copy_from(Tensor src) {
// Use nbytes() to get the actual tensor size, not the full memory size // Use nbytes() to get the actual tensor size, not the full memory size
size_t copy_size = std::min(this->nbytes(), src->nbytes()); size_t copy_size = std::min(this->nbytes(), src->nbytes());
if (this->device().getType() == Device::Type::CPU) { if (this->device().getType() == Device::Type::CPU) {
context::setDevice(src->device());
if (this->is_contiguous()) { if (this->is_contiguous()) {
context::setDevice(src->device());
context::memcpyD2H(this->data(), src->data(), copy_size); context::memcpyD2H(this->data(), src->data(), copy_size);
} else { } else {
auto local_src = Tensor::empty(this->shape(), this->dtype(), this->device()); auto local_src = Tensor::empty(this->shape(), this->dtype(), this->device());
context::setDevice(src->device());
context::memcpyD2H(local_src->data(), src->data(), this->data_.memory->size()); context::memcpyD2H(local_src->data(), src->data(), this->data_.memory->size());
op::rearrange_(Tensor(const_cast<TensorImpl *>(this)->shared_from_this()), local_src); op::rearrange_(Tensor(const_cast<TensorImpl *>(this)->shared_from_this()), local_src);
} }
......
...@@ -29,7 +29,10 @@ inline struct SpdlogInitializer { ...@@ -29,7 +29,10 @@ inline struct SpdlogInitializer {
infiniStatus_t ret = (call); \ infiniStatus_t ret = (call); \
SPDLOG_DEBUG("Exiting `" #call "` at `" __FILE__ ":" STRINGIZE(__LINE__) "`."); \ SPDLOG_DEBUG("Exiting `" #call "` at `" __FILE__ ":" STRINGIZE(__LINE__) "`."); \
if (ret != INFINI_STATUS_SUCCESS) { \ if (ret != INFINI_STATUS_SUCCESS) { \
throw std::runtime_error(#call " failed with error: " + std::string(infini_status_string(ret))); \ throw std::runtime_error("`" #call "` failed with error: " + std::string(infini_status_string(ret)) \
+ " from " + std::string(__func__) \
+ " at " + std::string(__FILE__) \
+ ":" + std::to_string(__LINE__) + "."); \
} \ } \
} while (false) } while (false)
......
...@@ -4,6 +4,14 @@ ...@@ -4,6 +4,14 @@
#define CHECK_CUDART(RT_API) CHECK_INTERNAL(RT_API, cudaSuccess) #define CHECK_CUDART(RT_API) CHECK_INTERNAL(RT_API, cudaSuccess)
// Invokes a CUDA runtime API call and, on failure, returns
// INFINI_STATUS_INTERNAL_ERROR from the enclosing function.
// NOTE(review): presumably intended for teardown paths (stream/event/memory
// destruction) where the throwing CHECK variant would be unsafe — confirm.
// Fix: removed the redundant nested brace block around the `return` and the
// superfluous parentheses around `cudaSuccess`.
#define RUN_CUDART(RT_API) \
    do { \
        auto api_result_ = (RT_API); \
        if (api_result_ != cudaSuccess) { \
            return INFINI_STATUS_INTERNAL_ERROR; \
        } \
    } while (0)
// 根据宏定义选择命名空间并实现 // 根据宏定义选择命名空间并实现
#if defined(ENABLE_NVIDIA_API) #if defined(ENABLE_NVIDIA_API)
namespace infinirt::cuda { namespace infinirt::cuda {
...@@ -40,7 +48,7 @@ infiniStatus_t streamCreate(infinirtStream_t *stream_ptr) { ...@@ -40,7 +48,7 @@ infiniStatus_t streamCreate(infinirtStream_t *stream_ptr) {
} }
/// Destroys the given CUDA stream.
/// Uses the status-returning RUN_CUDART wrapper so a destruction failure is
/// reported as an error status rather than escaping from a teardown path.
infiniStatus_t streamDestroy(infinirtStream_t stream) {
    // static_cast over a C-style cast: assumes infinirtStream_t is a
    // void-pointer-style handle (static_cast is valid for void* conversions)
    // — TODO confirm the typedef.
    RUN_CUDART(cudaStreamDestroy(static_cast<cudaStream_t>(stream)));
    return INFINI_STATUS_SUCCESS;
}
...@@ -105,7 +113,7 @@ infiniStatus_t eventSynchronize(infinirtEvent_t event) { ...@@ -105,7 +113,7 @@ infiniStatus_t eventSynchronize(infinirtEvent_t event) {
} }
/// Destroys the given CUDA event.
/// Uses the status-returning RUN_CUDART wrapper so a destruction failure is
/// reported as an error status rather than escaping from a teardown path.
infiniStatus_t eventDestroy(infinirtEvent_t event) {
    // static_cast over a C-style cast: assumes infinirtEvent_t is a
    // void-pointer-style handle (static_cast is valid for void* conversions)
    // — TODO confirm the typedef.
    RUN_CUDART(cudaEventDestroy(static_cast<cudaEvent_t>(event)));
    return INFINI_STATUS_SUCCESS;
}
...@@ -125,12 +133,12 @@ infiniStatus_t mallocHost(void **p_ptr, size_t size) { ...@@ -125,12 +133,12 @@ infiniStatus_t mallocHost(void **p_ptr, size_t size) {
} }
/// Releases device memory previously obtained from the CUDA allocator.
/// Failure is reported via the returned status (RUN_CUDART) rather than by
/// throwing, matching the other destruction/free wrappers in this namespace.
infiniStatus_t freeDevice(void *ptr) {
    RUN_CUDART(cudaFree(ptr));
    return INFINI_STATUS_SUCCESS;
}
/// Releases host memory previously obtained via the CUDA host allocator.
/// Failure is reported via the returned status (RUN_CUDART) rather than by
/// throwing, matching the other destruction/free wrappers in this namespace.
infiniStatus_t freeHost(void *ptr) {
    RUN_CUDART(cudaFreeHost(ptr));
    return INFINI_STATUS_SUCCESS;
}
...@@ -165,7 +173,7 @@ infiniStatus_t mallocAsync(void **p_ptr, size_t size, infinirtStream_t stream) { ...@@ -165,7 +173,7 @@ infiniStatus_t mallocAsync(void **p_ptr, size_t size, infinirtStream_t stream) {
} }
/// Asynchronously releases device memory on the given stream.
/// Uses the status-returning RUN_CUDART wrapper so a failure is reported as
/// an error status rather than escaping from a teardown path.
infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) {
    // static_cast over a C-style cast: assumes infinirtStream_t is a
    // void-pointer-style handle — TODO confirm the typedef.
    RUN_CUDART(cudaFreeAsync(ptr, static_cast<cudaStream_t>(stream)));
    return INFINI_STATUS_SUCCESS;
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment