Commit 5a3e2552 authored by pkufool

Add DeviceGuard

parent d53e923b
@@ -11,6 +11,7 @@ if(FT_WITH_CUDA)
   set(cuda_srcs mutual_information_cuda.cu)
   add_library(mutual_information_core_cuda ${cuda_srcs})
   target_link_libraries(mutual_information_core_cuda PUBLIC ${TORCH_LIBRARIES})
+  # for <torch/extension.h>
   target_include_directories(mutual_information_core_cuda PUBLIC ${PYTHON_INCLUDE_DIRS})
   target_link_libraries(mutual_information_core PUBLIC mutual_information_core_cuda)
 endif()
fast_rnnt/csrc/device_guard.h (new file)
/**
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef FAST_RNNT_CSRC_DEVICE_GUARD_H_
#define FAST_RNNT_CSRC_DEVICE_GUARD_H_

#include <cuda_runtime_api.h>  // declares cudaGetDevice() / cudaSetDevice()
#include <torch/script.h>

// This file is modified from
// https://github.com/k2-fsa/k2/blob/master/k2/csrc/device_guard.h
namespace fast_rnnt {

// DeviceGuard is an RAII class. It sets the current default CUDA device
// on construction (when given a CUDA device) and restores the previous
// default device on destruction, so a scope can switch devices safely.
class DeviceGuard {
 public:
  explicit DeviceGuard(torch::Device device) {
    if (device.type() == torch::kCUDA) {
      old_device_ = GetDevice();
      new_device_ = device.index();
      if (old_device_ != new_device_) SetDevice(new_device_);
    }
    // else do nothing
  }

  explicit DeviceGuard(int32_t new_device) : new_device_(new_device) {
    if (new_device != -1) {
      old_device_ = GetDevice();
      if (old_device_ != new_device) SetDevice(new_device);
    }
  }

  ~DeviceGuard() {
    if (old_device_ != -1 && old_device_ != new_device_) {
      // restore the previous device
      SetDevice(old_device_);
    }
    // else it was either a CPU context or the device IDs were the same
  }

  DeviceGuard(const DeviceGuard &) = delete;
  DeviceGuard &operator=(const DeviceGuard &) = delete;
  DeviceGuard(DeviceGuard &&) = delete;
  DeviceGuard &operator=(DeviceGuard &&) = delete;

 private:
  static int32_t GetDevice() {
    int32_t device;
    auto s = cudaGetDevice(&device);
    TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
    return device;
  }

  static void SetDevice(int32_t device) {
    auto s = cudaSetDevice(device);
    TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
  }

 private:
  int32_t old_device_ = -1;
  int32_t new_device_ = -1;
};

}  // namespace fast_rnnt
#endif // FAST_RNNT_CSRC_DEVICE_GUARD_H_
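For context, here is a minimal usage sketch of the new class. It is not part of the commit, and ScaleOnDevice is an invented name standing in for any wrapper that must run CUDA work on a particular tensor's device:

// Hypothetical usage sketch, not from this commit.
#include <torch/script.h>

#include "fast_rnnt/csrc/device_guard.h"

torch::Tensor ScaleOnDevice(torch::Tensor px) {
  // Make px's device the current CUDA device for this scope; the previous
  // default device is restored when `guard` is destroyed, even on early
  // return or exception. For CPU tensors the guard does nothing.
  fast_rnnt::DeviceGuard guard(px.device());
  return px * 2.0;  // any CUDA kernels here launch on px.device()
}

This is exactly the pattern the pybind11 wrappers below adopt: construct the guard first, then dispatch to the CPU or CUDA implementation.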
@@ -18,6 +18,7 @@
  * limitations under the License.
  */
+#include <iostream>
 #include "fast_rnnt/csrc/mutual_information.h"
 namespace fast_rnnt {
@@ -241,11 +242,11 @@ MutualInformationBackwardCpu(torch::Tensor px, torch::Tensor py,
     if (ans_grad_a[b] != 0.0) {
       float grad_ratio = p_grad_a[b][s_begin][t_begin] / ans_grad_a[b];
       if (fabs(grad_ratio - 1.0) > 0.01) {
-        // K2_LOG(WARNING)
-        //     << "Warning: mutual_information backprop: expected these "
-        //     << "numbers to be the same:"
-        //     << static_cast<float>(p_grad_a[b][s_begin][t_begin]) << " vs "
-        //     << static_cast<float>(ans_grad_a[b]);
+        std::cout
+            << "Warning: mutual_information backprop: expected these "
+            << "numbers to be the same:"
+            << static_cast<float>(p_grad_a[b][s_begin][t_begin]) << " vs "
+            << static_cast<float>(ans_grad_a[b]);
       }
     }
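Context for the check above: the backward pass propagates ans_grad through the lattice, so the gradient accumulated at the start cell, p_grad_a[b][s_begin][t_begin], should reproduce ans_grad[b] up to rounding; the ratio test warns when they disagree by more than 1%. The commit replaces the commented-out K2_LOG(WARNING) with a plain std::cout (hence the new #include <iostream>), presumably because fast_rnnt does not carry over k2's logging framework.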
@@ -18,6 +18,7 @@
  * limitations under the License.
  */
+#include "fast_rnnt/csrc/device_guard.h"
 #include "fast_rnnt/csrc/mutual_information.h"
 #include "fast_rnnt/python/csrc/mutual_information.h"
@@ -29,14 +30,15 @@ PYBIND11_MODULE(_fast_rnnt, m) {
       [](torch::Tensor px, torch::Tensor py,
          torch::optional<torch::Tensor> boundary,
          torch::Tensor p) -> torch::Tensor {
+        fast_rnnt::DeviceGuard guard(px.device());
         if (px.device().is_cpu()) {
           return fast_rnnt::MutualInformationCpu(px, py, boundary, p);
         } else {
 #ifdef FT_WITH_CUDA
           return fast_rnnt::MutualInformationCuda(px, py, boundary, p);
 #else
-          // K2_LOG(FATAL) << "Failed to find native CUDA module, make sure "
-          //               << "that you compiled the code with K2_WITH_CUDA.";
+          TORCH_CHECK(false, "Failed to find native CUDA module, make sure "
+                             "that you compiled the code with FT_WITH_CUDA.");
           return torch::Tensor();
 #endif
         }
@@ -48,6 +50,7 @@ PYBIND11_MODULE(_fast_rnnt, m) {
       [](torch::Tensor px, torch::Tensor py,
          torch::optional<torch::Tensor> boundary, torch::Tensor p,
          torch::Tensor ans_grad) -> std::vector<torch::Tensor> {
+        fast_rnnt::DeviceGuard guard(px.device());
         if (px.device().is_cpu()) {
           return fast_rnnt::MutualInformationBackwardCpu(px, py, boundary, p,
                                                          ans_grad);
@@ -56,8 +59,8 @@ PYBIND11_MODULE(_fast_rnnt, m) {
           return fast_rnnt::MutualInformationBackwardCuda(px, py, boundary, p,
                                                           ans_grad, true);
 #else
-          // K2_LOG(FATAL) << "Failed to find native CUDA module, make sure "
-          //               << "that you compiled the code with K2_WITH_CUDA.";
+          TORCH_CHECK(false, "Failed to find native CUDA module, make sure "
+                             "that you compiled the code with FT_WITH_CUDA.");
           return std::vector<torch::Tensor>();
 #endif
         }
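A note on the error-handling change in both wrappers: TORCH_CHECK(false, ...) throws a c10::Error, which pybind11 surfaces to Python as a RuntimeError, so the unreachable return statements after it exist only to keep the compiler happy. A minimal standalone sketch of that failure path (not from the commit):

// Standalone illustration of the TORCH_CHECK failure path; not part of
// the commit.
#include <torch/script.h>

#include <iostream>

int main() {
  try {
    TORCH_CHECK(false, "Failed to find native CUDA module, make sure "
                       "that you compiled the code with FT_WITH_CUDA.");
  } catch (const c10::Error &e) {
    // Print the message without the C++ stack trace.
    std::cout << e.what_without_backtrace() << "\n";
  }
  return 0;
}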