Commit 4a4b6e8e authored by PanZezhong's avatar PanZezhong
Browse files

issue/811 remove shortcut for cpu runtime

parent 180674dc
......@@ -29,18 +29,16 @@ Runtime *ContextImpl::getCurrentRuntime() {
return current_runtime_;
}
Runtime *ContextImpl::getCpuRuntime() {
    // The CPU runtime is always stored at slot 0 of the CPU row of the
    // runtime table, so it can be fetched without going through setDevice.
    auto &cpu_runtimes = runtime_table_[static_cast<int>(Device::Type::CPU)];
    return cpu_runtimes[0].get();
}
void ContextImpl::setDevice(Device device) {
if (device == getCurrentRuntime()->device()) {
// Do nothing if the device is already set.
return;
}
if (getCurrentRuntime()->isGraphRecording()) {
thread_local bool warn_switch_runtime = false;
if (getCurrentRuntime()->isGraphRecording() && !warn_switch_runtime) {
spdlog::warn("Switching device runtime during graph recording may break the graph!");
warn_switch_runtime = true;
}
if (runtime_table_[int(device.getType())][device.getIndex()] == nullptr) {
......@@ -104,11 +102,8 @@ infinirtStream_t getStream() {
}
infiniopHandle_t getInfiniopHandle(Device device) {
    // No special-case shortcut for the CPU runtime: every device, CPU
    // included, goes through the generic path. If the requested device is not
    // the current one, switch to it instead of throwing, then return the
    // current runtime's infiniop handle.
    //
    // NOTE(review): the scraped diff residue showed an old `throw` followed by
    // the new `setDevice(device)`; the post-change behavior (switch, don't
    // throw) is kept here — confirm against the upstream commit.
    if (device != getDevice()) {
        setDevice(device);
    }
    return ContextImpl::singleton().getCurrentRuntime()->infiniopHandle();
}
......@@ -127,7 +122,7 @@ std::shared_ptr<Memory> allocateMemory(size_t size) {
std::shared_ptr<Memory> allocateHostMemory(size_t size) {
    // Host allocation uses the generic path: make the CPU the current runtime
    // and delegate to allocateMemory(), which allocates on the current
    // runtime. (The direct getCpuRuntime() shortcut was removed; the diff
    // residue had left its old `return` line in place, making this line
    // unreachable.)
    setDevice(Device::cpu());
    return allocateMemory(size);
}
std::shared_ptr<Memory> allocatePinnedHostMemory(size_t size) {
......@@ -147,7 +142,8 @@ void memcpyD2D(void *dst, const void *src, size_t size, bool async) {
}
void memcpyH2H(void *dst, const void *src, size_t size) {
    // Host-to-host copy: switch the current runtime to CPU, then perform the
    // copy as a device-to-device copy on the (now current) CPU runtime. The
    // old getCpuRuntime() shortcut line left behind by the diff residue made
    // these statements unreachable; only the post-change path is kept.
    setDevice(Device::cpu());
    return ContextImpl::singleton().getCurrentRuntime()->memcpyD2D(dst, src, size);
}
// Timing API implementations
......
......@@ -19,8 +19,6 @@ protected:
public:
Runtime *getCurrentRuntime();
Runtime *getCpuRuntime();
void setDevice(Device);
size_t getDeviceCount(Device::Type type);
......
......@@ -19,7 +19,8 @@ Tensor TensorImpl::to(Device device) const {
void TensorImpl::copy_from(Tensor src) {
if (src->shape() != this->shape()) {
throw std::runtime_error("Cannot copy from tensor with different shape");
throw std::runtime_error(
"Cannot copy from tensor with different shape. Src: " + src->info() + " Dst: " + this->info());
}
if (this->device() == src->device()) {
op::rearrange_(Tensor(const_cast<TensorImpl *>(this)->shared_from_this()), src);
......@@ -31,11 +32,12 @@ void TensorImpl::copy_from(Tensor src) {
// Use nbytes() to get the actual tensor size, not the full memory size
size_t copy_size = std::min(this->nbytes(), src->nbytes());
if (this->device().getType() == Device::Type::CPU) {
context::setDevice(src->device());
if (this->is_contiguous()) {
context::setDevice(src->device());
context::memcpyD2H(this->data(), src->data(), copy_size);
} else {
auto local_src = Tensor::empty(this->shape(), this->dtype(), this->device());
context::setDevice(src->device());
context::memcpyD2H(local_src->data(), src->data(), this->data_.memory->size());
op::rearrange_(Tensor(const_cast<TensorImpl *>(this)->shared_from_this()), local_src);
}
......
......@@ -29,7 +29,10 @@ inline struct SpdlogInitializer {
infiniStatus_t ret = (call); \
SPDLOG_DEBUG("Exiting `" #call "` at `" __FILE__ ":" STRINGIZE(__LINE__) "`."); \
if (ret != INFINI_STATUS_SUCCESS) { \
throw std::runtime_error(#call " failed with error: " + std::string(infini_status_string(ret))); \
throw std::runtime_error("`" #call "` failed with error: " + std::string(infini_status_string(ret)) \
+ " from " + std::string(__func__) \
+ " at " + std::string(__FILE__) \
+ ":" + std::to_string(__LINE__) + "."); \
} \
} while (false)
......
......@@ -4,6 +4,14 @@
// Checks a CUDA runtime API call via CHECK_INTERNAL (defined elsewhere),
// comparing its result against cudaSuccess. Presumably CHECK_INTERNAL throws
// or propagates an error on mismatch — confirm at its definition.
#define CHECK_CUDART(RT_API) CHECK_INTERNAL(RT_API, cudaSuccess)
// Runs a CUDA runtime API call and, on any result other than cudaSuccess,
// returns INFINI_STATUS_INTERNAL_ERROR from the ENCLOSING function instead of
// throwing — suitable for destroy/free teardown paths that must not throw.
// (Cleaned up the redundant doubled braces around the early return.)
#define RUN_CUDART(RT_API) \
    do { \
        auto api_result_ = (RT_API); \
        if (api_result_ != (cudaSuccess)) { \
            return INFINI_STATUS_INTERNAL_ERROR; \
        } \
    } while (0)
// Select the namespace and provide the implementation according to the macro definitions below.
#if defined(ENABLE_NVIDIA_API)
namespace infinirt::cuda {
......@@ -40,7 +48,7 @@ infiniStatus_t streamCreate(infinirtStream_t *stream_ptr) {
}
infiniStatus_t streamDestroy(infinirtStream_t stream) {
    // Destroy the stream exactly once. The diff residue contained both the old
    // CHECK_CUDART call and the new RUN_CUDART call — a double destroy of the
    // same stream; only the non-throwing RUN_CUDART path is kept.
    RUN_CUDART(cudaStreamDestroy((cudaStream_t)stream));
    return INFINI_STATUS_SUCCESS;
}
......@@ -105,7 +113,7 @@ infiniStatus_t eventSynchronize(infinirtEvent_t event) {
}
infiniStatus_t eventDestroy(infinirtEvent_t event) {
    // Destroy the event exactly once; RUN_CUDART maps a CUDA failure to
    // INFINI_STATUS_INTERNAL_ERROR without throwing. (The residual old
    // CHECK_CUDART line would have destroyed the event twice.)
    RUN_CUDART(cudaEventDestroy((cudaEvent_t)event));
    return INFINI_STATUS_SUCCESS;
}
......@@ -125,12 +133,12 @@ infiniStatus_t mallocHost(void **p_ptr, size_t size) {
}
infiniStatus_t freeDevice(void *ptr) {
    // Free device memory exactly once; failures become an error status rather
    // than an exception. (The residual old CHECK_CUDART line would have been
    // a double free of ptr.)
    RUN_CUDART(cudaFree(ptr));
    return INFINI_STATUS_SUCCESS;
}
infiniStatus_t freeHost(void *ptr) {
    // Free pinned host memory exactly once; failures become an error status
    // rather than an exception. (The residual old CHECK_CUDART line would
    // have been a double free of ptr.)
    RUN_CUDART(cudaFreeHost(ptr));
    return INFINI_STATUS_SUCCESS;
}
......@@ -165,7 +173,7 @@ infiniStatus_t mallocAsync(void **p_ptr, size_t size, infinirtStream_t stream) {
}
infiniStatus_t freeAsync(void *ptr, infinirtStream_t stream) {
    // Stream-ordered free, issued exactly once on the given stream; failures
    // become INFINI_STATUS_INTERNAL_ERROR rather than exceptions. (The
    // residual old CHECK_CUDART line would have freed ptr twice.)
    RUN_CUDART(cudaFreeAsync(ptr, (cudaStream_t)stream));
    return INFINI_STATUS_SUCCESS;
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment