OpenDAS / FAST-RNNT / Commits / 5a3e2552

Commit 5a3e2552 authored Mar 08, 2022 by pkufool

Add DeviceGuard

parent d53e923b

Showing 4 changed files with 102 additions and 11 deletions (+102 -11):

fast_rnnt/csrc/CMakeLists.txt                +1  -0
fast_rnnt/csrc/device_guard.h                +86 -0
fast_rnnt/csrc/mutual_information_cpu.cc     +6  -5
fast_rnnt/python/csrc/mutual_information.cu  +9  -6
fast_rnnt/csrc/CMakeLists.txt

...
@@ -11,6 +11,7 @@ if(FT_WITH_CUDA)
  set(cuda_srcs mutual_information_cuda.cu)
  add_library(mutual_information_core_cuda ${cuda_srcs})
  target_link_libraries(mutual_information_core_cuda PUBLIC ${TORCH_LIBRARIES})
  # for <torch/extension.h>
  target_include_directories(mutual_information_core_cuda PUBLIC ${PYTHON_INCLUDE_DIRS})
  target_link_libraries(mutual_information_core PUBLIC mutual_information_core_cuda)
endif()
fast_rnnt/csrc/device_guard.h (new file, mode 100644)

/**
 * Copyright      2022  Xiaomi Corporation (authors: Wei Kang)
 *
 * See LICENSE for clarification regarding multiple authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef FAST_RNNT_CSRC_DEVICE_GUARD_H_
#define FAST_RNNT_CSRC_DEVICE_GUARD_H_

#include <torch/script.h>

// This file is modified from
// https://github.com/k2-fsa/k2/blob/master/k2/csrc/device_guard.h

namespace fast_rnnt {

// DeviceGuard is an RAII class. Its sole purpose is to restore
// the previous default cuda device if a CUDA context changes the
// current default cuda device.
class DeviceGuard {
 public:
  explicit DeviceGuard(torch::Device device) {
    if (device.type() == torch::kCUDA) {
      old_device_ = GetDevice();
      new_device_ = device.index();
      if (old_device_ != new_device_) SetDevice(new_device_);
    }
    // else do nothing
  }

  explicit DeviceGuard(int32_t new_device) : new_device_(new_device) {
    if (new_device != -1) {
      old_device_ = GetDevice();
      if (old_device_ != new_device) SetDevice(new_device);
    }
  }

  ~DeviceGuard() {
    if (old_device_ != -1 && old_device_ != new_device_) {
      // restore the previous device
      SetDevice(old_device_);
    }
    // else it was either a CPU context or the device IDs were the same
  }

  DeviceGuard(const DeviceGuard &) = delete;
  DeviceGuard &operator=(const DeviceGuard &) = delete;
  DeviceGuard(DeviceGuard &&) = delete;
  DeviceGuard &operator=(DeviceGuard &&) = delete;

 private:
  static int32_t GetDevice() {
    int32_t device;
    auto s = cudaGetDevice(&device);
    TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
    return device;
  }

  static void SetDevice(int32_t device) {
    auto s = cudaSetDevice(device);
    TORCH_CHECK(s == cudaSuccess, cudaGetErrorString(s));
  }

 private:
  int32_t old_device_ = -1;
  int32_t new_device_ = -1;
};

}  // namespace fast_rnnt

#endif  // FAST_RNNT_CSRC_DEVICE_GUARD_H_
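
A quick usage sketch of the guard (my illustration, not part of the commit; it assumes a CUDA build with at least two visible devices):

#include <cassert>
#include <cuda_runtime.h>

#include "fast_rnnt/csrc/device_guard.h"

void Demo() {
  int before = 0;
  cudaGetDevice(&before);  // the caller's current device, e.g. 0
  {
    fast_rnnt::DeviceGuard guard(/*new_device=*/1);  // switch to GPU 1
    // ... run kernels or build tensors on GPU 1 here ...
  }  // ~DeviceGuard restores the previous device on scope exit
  int after = 0;
  cudaGetDevice(&after);
  assert(before == after);  // the caller's device is unchanged
}

Copy and move are deleted, so a guard stays pinned to the scope that created it and the restore in the destructor runs exactly once.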
fast_rnnt/csrc/mutual_information_cpu.cc

...
@@ -18,6 +18,7 @@
 * limitations under the License.
 */

#include <iostream>

#include "fast_rnnt/csrc/mutual_information.h"

namespace fast_rnnt {
...
@@ -241,11 +242,11 @@ MutualInformationBackwardCpu(torch::Tensor px, torch::Tensor py,
      if (ans_grad_a[b] != 0.0) {
        float grad_ratio = p_grad_a[b][s_begin][t_begin] / ans_grad_a[b];
        if (fabs(grad_ratio - 1.0) > 0.01) {
          // K2_LOG(WARNING)
          //     << "Warning: mutual_information backprop: expected these "
          //     << "numbers to be the same:"
          //     << static_cast<float>(p_grad_a[b][s_begin][t_begin]) << " vs "
          //     << static_cast<float>(ans_grad_a[b]);
          std::cout << "Warning: mutual_information backprop: expected these "
                    << "numbers to be the same:"
                    << static_cast<float>(p_grad_a[b][s_begin][t_begin])
                    << " vs " << static_cast<float>(ans_grad_a[b]);
        }
      }
    }
...
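
For context (my reading, not stated in the commit): the backward recursion is normalized so that the gradient accumulated at the starting cell p[b][s_begin][t_begin] should reproduce the incoming output gradient ans_grad[b]; a relative deviation above 1% signals numerical trouble. A standalone restatement of that tolerance check, with a hypothetical helper name:

#include <cmath>
#include <iostream>

// Hypothetical helper mirroring the check above: warn when `got`
// deviates from `expected` by more than `rel_tol`, relatively.
inline void WarnIfNotClose(float got, float expected, float rel_tol = 0.01f) {
  if (expected == 0.0f) return;  // mirrors the ans_grad_a[b] != 0.0 guard
  float ratio = got / expected;
  if (std::fabs(ratio - 1.0f) > rel_tol) {
    std::cout << "Warning: expected these numbers to be the same: "
              << got << " vs " << expected << "\n";
  }
}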
fast_rnnt/python/csrc/mutual_information.cu

...
@@ -18,6 +18,7 @@
 * limitations under the License.
 */

#include "fast_rnnt/csrc/device_guard.h"
#include "fast_rnnt/csrc/mutual_information.h"
#include "fast_rnnt/python/csrc/mutual_information.h"
...
@@ -29,14 +30,15 @@ PYBIND11_MODULE(_fast_rnnt, m) {
      [](torch::Tensor px, torch::Tensor py,
         torch::optional<torch::Tensor> boundary,
         torch::Tensor p) -> torch::Tensor {
        fast_rnnt::DeviceGuard guard(px.device());
        if (px.device().is_cpu()) {
          return fast_rnnt::MutualInformationCpu(px, py, boundary, p);
        } else {
#ifdef FT_WITH_CUDA
          return fast_rnnt::MutualInformationCuda(px, py, boundary, p);
#else
          // K2_LOG(FATAL) << "Failed to find native CUDA module, make sure "
          //              << "that you compiled the code with K2_WITH_CUDA.";
          TORCH_CHECK(false,
                      "Failed to find native CUDA module, make sure "
                      "that you compiled the code with K2_WITH_CUDA.");
          return torch::Tensor();
#endif
        }
...
@@ -48,16 +50,17 @@ PYBIND11_MODULE(_fast_rnnt, m) {
      [](torch::Tensor px, torch::Tensor py,
         torch::optional<torch::Tensor> boundary, torch::Tensor p,
         torch::Tensor ans_grad) -> std::vector<torch::Tensor> {
        fast_rnnt::DeviceGuard guard(px.device());
        if (px.device().is_cpu()) {
          return fast_rnnt::MutualInformationBackwardCpu(px, py, boundary, p,
                                                         ans_grad);
        } else {
#ifdef FT_WITH_CUDA
          return fast_rnnt::MutualInformationBackwardCuda(px, py, boundary, p,
                                                          ans_grad, true);
#else
          // K2_LOG(FATAL) << "Failed to find native CUDA module, make sure "
          //              << "that you compiled the code with K2_WITH_CUDA.";
          TORCH_CHECK(false,
                      "Failed to find native CUDA module, make sure "
                      "that you compiled the code with K2_WITH_CUDA.");
          return std::vector<torch::Tensor>();
#endif
        }
...
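
One property worth noting (my observation, not stated in the commit): because DeviceGuard is RAII, the saved device is restored even when the wrapped call throws, e.g. when TORCH_CHECK fires in the non-CUDA branch above. A minimal sketch, using a hypothetical GuardedCompute wrapper:

#include <stdexcept>

#include "fast_rnnt/csrc/device_guard.h"

// Hypothetical wrapper: if the computation throws, ~DeviceGuard still
// runs during stack unwinding and restores the caller's CUDA device.
torch::Tensor GuardedCompute(torch::Tensor px) {
  fast_rnnt::DeviceGuard guard(px.device());  // switches only for CUDA tensors
  if (!px.is_cuda())
    throw std::runtime_error("expected a CUDA tensor");  // guard still restores
  return px * 2;  // placeholder for the real computation
}

TORCH_CHECK throws a c10::Error, which PyTorch's pybind11 bindings surface to Python as a RuntimeError, so the guard's destructor runs before the exception crosses the language boundary.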