Commit 5417e4fb (unverified) — OpenDAS / Torchaudio
Authored May 06, 2021 by Caroline Chen; committed by GitHub on May 06, 2021

Add GPU RNNT Loss (#1483)

Parent: 7d45851d
Changes: 14 files, +1486 / −0
CMakeLists.txt                                        +5    −0
build_tools/setup_helpers/extension.py                +2    −0
test/torchaudio_unittest/rnnt/rnnt_loss_cuda_test.py  +10   −0
torchaudio/csrc/CMakeLists.txt                        +15   −0
torchaudio/csrc/rnnt/gpu/compute.cu                   +105  −0
torchaudio/csrc/rnnt/gpu/compute_alphas.cu            +73   −0
torchaudio/csrc/rnnt/gpu/compute_betas.cu             +78   −0
torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh         +98   −0
torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh              +422  −0
torchaudio/csrc/rnnt/gpu/gpu_transducer.h             +395  −0
torchaudio/csrc/rnnt/gpu/half.cuh                     +38   −0
torchaudio/csrc/rnnt/gpu/kernel_utils.h               +66   −0
torchaudio/csrc/rnnt/gpu/kernels.h                    +131  −0
torchaudio/csrc/rnnt/gpu/math.cuh                     +48   −0
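For orientation before the per-file diffs (an explanatory note, not part of the commit): the kernels added here implement the standard RNN-T forward–backward recursions over a (T × U) lattice per sequence. Writing emit(t, u) and skip(t, u) for the log-probabilities of emitting the next target symbol or the blank at node (t, u) — the pairs that `ComputeLogProbs` produces — the quantities computed are:

```latex
\alpha(t,u) = \operatorname{lse}\big(\alpha(t{-}1,u) + \mathrm{skip}(t{-}1,u),\; \alpha(t,u{-}1) + \mathrm{emit}(t,u{-}1)\big), \qquad \alpha(0,0) = 0
\beta(t,u) = \operatorname{lse}\big(\beta(t{+}1,u) + \mathrm{skip}(t,u),\; \beta(t,u{+}1) + \mathrm{emit}(t,u)\big), \qquad \beta(T{-}1,U{-}1) = \mathrm{skip}(T{-}1,U{-}1)
\mathrm{loss} = -\beta(0,0), \qquad \operatorname{lse}(a,b) = \log(e^a + e^b)
```

The gradient code in kernels.h then combines α, β, and the per-node log-probabilities.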
CMakeLists.txt
@@ -59,6 +59,11 @@ option(BUILD_KALDI "Build kaldi statically" ON)

```cmake
option(BUILD_TRANSDUCER "Enable transducer" OFF)
option(BUILD_LIBTORCHAUDIO "Build C++ Library" ON)
option(BUILD_TORCHAUDIO_PYTHON_EXTENSION "Build Python extension" OFF)
option(USE_CUDA "Enable CUDA support" OFF)

if (USE_CUDA)
  enable_language(CUDA)
endif()

find_package(Torch REQUIRED)
```
build_tools/setup_helpers/extension.py
@@ -38,6 +38,7 @@ _BUILD_SOX = False if platform.system() == 'Windows' else _get_build("BUILD_SOX"

```python
_BUILD_KALDI = False if platform.system() == 'Windows' else _get_build("BUILD_KALDI", True)
_BUILD_TRANSDUCER = _get_build("BUILD_TRANSDUCER")
_USE_ROCM = _get_build("USE_ROCM")
_USE_CUDA = torch.cuda.is_available()


def get_ext_modules():
```

@@ -76,6 +77,7 @@ class CMakeBuild(build_ext):

```python
            "-DBUILD_TORCHAUDIO_PYTHON_EXTENSION:BOOL=ON",
            "-DBUILD_LIBTORCHAUDIO:BOOL=OFF",
            f"-DUSE_ROCM:BOOL={'ON' if _USE_ROCM else 'OFF'}",
            f"-DUSE_CUDA:BOOL={'ON' if _USE_CUDA else 'OFF'}",
        ]
        build_args = ['--target', 'install'
```
test/torchaudio_unittest/rnnt/rnnt_loss_cuda_test.py (new file, mode 100644)

```python
import torch
from .rnnt_loss_impl import RNNTLossTest
from torchaudio_unittest import common_utils
from .utils import skipIfNoTransducer


@skipIfNoTransducer
@common_utils.skipIfNoCuda
class TestRNNTLoss(RNNTLossTest, common_utils.PytorchTestCase):
    device = torch.device('cuda')
```
torchaudio/csrc/CMakeLists.txt
@@ -20,6 +20,17 @@ if(BUILD_TRANSDUCER)

```cmake
    rnnt/compute_betas.cpp
    rnnt/compute.cpp
    )
  if (USE_CUDA)
    set(
      CUDA_TRANSDUCER_SOURCES
      rnnt/gpu/compute_alphas.cu
      rnnt/gpu/compute_betas.cu
      rnnt/gpu/compute.cu
      )
    list(APPEND TRANSDUCER_SOURCES ${CUDA_TRANSDUCER_SOURCES})
  endif()
  list(APPEND LIBTORCHAUDIO_SOURCES ${TRANSDUCER_SOURCES})
endif()
```

@@ -105,6 +116,10 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)

```cmake
    target_compile_definitions(_torchaudio PRIVATE INCLUDE_KALDI)
  endif()
  if (USE_CUDA)
    target_compile_definitions(_torchaudio PRIVATE USE_CUDA)
  endif()
  target_include_directories(
    _torchaudio
    PRIVATE
```
torchaudio/csrc/rnnt/gpu/compute.cu (new file, mode 100644)

```cuda
#include <c10/cuda/CUDAStream.h>
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h>

namespace torchaudio {
namespace rnnt {
namespace gpu {

// Entry point into RNNT Loss
std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& src_lengths,
    const torch::Tensor& tgt_lengths,
    int64_t blank,
    double clamp,
    bool fused_log_smax = true,
    bool reuse_logits_for_grads = true) {
  Options options;
  options.batchSize_ = src_lengths.size(0);
  options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
  options.maxSrcLen_ = logits.size(1);
  options.maxTgtLen_ = logits.size(2);
  options.numTargets_ = logits.size(3);
  options.blank_ = blank;
  options.clamp_ = clamp;
  options.fusedLogSmax_ = fused_log_smax;

  CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
  options.stream_ = at::cuda::getCurrentCUDAStream();
  cudaSetDevice(logits.get_device());
  options.device_ = GPU;

  torch::Tensor costs = torch::empty(
      options.batchSize_ * options.nHypos_,
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
  c10::optional<torch::Tensor> gradients = c10::nullopt;
  if (logits.requires_grad()) {
    if (reuse_logits_for_grads) {
      gradients = logits;
    } else {
      gradients = torch::zeros_like(logits);
    }
  }

  torch::Tensor int_workspace = torch::empty(
      IntWorkspace::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Int));

  torch::Tensor float_workspace = torch::empty(
      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Float));

  Workspace<float> workspace(
      /*options=*/options,
      /*dtype_data=*/float_workspace.data_ptr<float>(),
      /*dtype_size=*/float_workspace.numel(),
      /*int_data=*/int_workspace.data_ptr<int>(),
      /*int_size=*/int_workspace.numel());

  switch (logits.scalar_type()) {
    case torch::ScalarType::Float: {
      Compute</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
          /*workspace=*/workspace,
          /*logits=*/logits.data_ptr<float>(),
          /*targets=*/targets.data_ptr<int>(),
          /*src_lengths=*/src_lengths.data_ptr<int>(),
          /*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
          /*costs=*/costs.data_ptr<float>(),
          /*gradients=*/
          (gradients == c10::nullopt) ? nullptr
                                      : gradients->data_ptr<float>());
      break;
    }
    case torch::ScalarType::Half: {
      Compute</*DTYPE=*/c10::Half, /*CAST_DTYPE=*/float>(
          /*workspace=*/workspace,
          /*logits=*/logits.data_ptr<c10::Half>(),
          /*targets=*/targets.data_ptr<int>(),
          /*src_lengths=*/src_lengths.data_ptr<int>(),
          /*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
          /*costs=*/costs.data_ptr<c10::Half>(),
          /*gradients=*/
          (gradients == c10::nullopt) ? nullptr
                                      : gradients->data_ptr<c10::Half>());
      break;
    }
    default: {
      break;
    }
  };

  return std::make_tuple(costs, gradients);
}

TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
  m.impl("rnnt_loss", &compute);
}

} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
```
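A minimal smoke-test sketch for the entry point above (illustrative only, not part of the diff): it assumes a declaration of `gpu::compute` is visible, the library was built with `USE_CUDA` and `BUILD_TRANSDUCER`, and a CUDA device is present. The shape values, `blank = 0`, and `clamp = -1.0` (disabled, per the `clamp > 0` guard in kernels.h) are arbitrary illustrative choices.

```cpp
// Hypothetical driver; assumes gpu::compute above is declared and linked in.
#include <iostream>
#include <torch/torch.h>

int main() {
  const int64_t B = 2, T = 10, U = 5, D = 29; // U = max target length + 1
  auto opts = torch::TensorOptions().device(torch::kCUDA);

  // (B, max_T, max_U, D) raw logits; with fused_log_smax the kernels
  // normalize internally via the log-sum-exp denominator over D.
  torch::Tensor logits = torch::randn({B, T, U, D}, opts.dtype(torch::kFloat));
  // (B, max_U - 1) int32 labels, avoiding the blank index 0.
  torch::Tensor targets =
      torch::randint(1, D, {B, U - 1}, opts.dtype(torch::kInt));
  torch::Tensor src_lengths = torch::full({B}, T, opts.dtype(torch::kInt));
  torch::Tensor tgt_lengths = torch::full({B}, U - 1, opts.dtype(torch::kInt));

  auto result = torchaudio::rnnt::gpu::compute(
      logits, targets, src_lengths, tgt_lengths, /*blank=*/0, /*clamp=*/-1.0);
  std::cout << std::get<0>(result) << std::endl; // per-sequence costs
}
```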
torchaudio/csrc/rnnt/gpu/compute_alphas.cu (new file, mode 100644)

```cuda
#include <c10/cuda/CUDAStream.h>
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h>

namespace torchaudio {
namespace rnnt {
namespace gpu {

torch::Tensor compute_alphas(
    const torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& src_lengths,
    const torch::Tensor& tgt_lengths,
    int64_t blank,
    double clamp) {
  Options options;
  options.batchSize_ = src_lengths.size(0);
  options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
  options.maxSrcLen_ = logits.size(1);
  options.maxTgtLen_ = logits.size(2);
  options.numTargets_ = logits.size(3);
  options.blank_ = blank;
  options.clamp_ = clamp;

  CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
  options.stream_ = at::cuda::getCurrentCUDAStream();
  cudaSetDevice(logits.get_device());
  options.device_ = GPU;

  torch::Tensor alphas = torch::zeros(
      {options.batchSize_ * options.nHypos_,
       options.maxSrcLen_,
       options.maxTgtLen_},
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));

  torch::Tensor int_workspace = torch::empty(
      IntWorkspace::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Int));

  torch::Tensor float_workspace = torch::empty(
      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Float));

  Workspace<float> workspace(
      /*options=*/options,
      /*dtype_data=*/float_workspace.data_ptr<float>(),
      /*dtype_size=*/float_workspace.numel(),
      /*int_data=*/int_workspace.data_ptr<int>(),
      /*int_size=*/int_workspace.numel());

  // Only support float, this is mainly to enable easy
  // unit-testing
  ComputeAlphas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
      /*workspace=*/workspace,
      /*logits=*/logits.data_ptr<float>(),
      /*targets=*/targets.data_ptr<int>(),
      /*src_lengths=*/src_lengths.data_ptr<int>(),
      /*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
      /*alphas=*/alphas.data_ptr<float>());
  return alphas;
}

TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
  m.impl("rnnt_loss_alphas", &compute_alphas);
}

} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
```
torchaudio/csrc/rnnt/gpu/compute_betas.cu (new file, mode 100644)

```cuda
#include <c10/cuda/CUDAStream.h>
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h>

namespace torchaudio {
namespace rnnt {
namespace gpu {

torch::Tensor compute_betas(
    const torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& src_lengths,
    const torch::Tensor& tgt_lengths,
    int64_t blank,
    double clamp) {
  Options options;
  options.batchSize_ = src_lengths.size(0);
  options.nHypos_ = tgt_lengths.size(0) / src_lengths.size(0);
  options.maxSrcLen_ = logits.size(1);
  options.maxTgtLen_ = logits.size(2);
  options.numTargets_ = logits.size(3);
  options.blank_ = blank;
  options.clamp_ = clamp;

  CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
  options.stream_ = at::cuda::getCurrentCUDAStream();
  cudaSetDevice(logits.get_device());
  options.device_ = GPU;

  torch::Tensor costs = torch::empty(
      tgt_lengths.size(0),
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));

  torch::Tensor betas = torch::zeros(
      {options.batchSize_ * options.nHypos_,
       options.maxSrcLen_,
       options.maxTgtLen_},
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));

  torch::Tensor int_workspace = torch::empty(
      IntWorkspace::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Int));

  torch::Tensor float_workspace = torch::empty(
      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Float));

  Workspace<float> workspace(
      /*options=*/options,
      /*dtype_data=*/float_workspace.data_ptr<float>(),
      /*dtype_size=*/float_workspace.numel(),
      /*int_data=*/int_workspace.data_ptr<int>(),
      /*int_size=*/int_workspace.numel());

  // Only support float, this is mainly to enable easy
  // unit-testing
  ComputeBetas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
      /*workspace=*/workspace,
      /*logits=*/logits.data_ptr<float>(),
      /*targets=*/targets.data_ptr<int>(),
      /*src_lengths=*/src_lengths.data_ptr<int>(),
      /*tgt_lengths=*/tgt_lengths.data_ptr<int>(),
      /*costs=*/costs.data_ptr<float>(),
      /*betas=*/betas.data_ptr<float>());
  return betas;
}

TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
  m.impl("rnnt_loss_betas", &compute_betas);
}

} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
```
torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh (new file, mode 100644)

```cuda
#pragma once

#ifdef USE_CUDA

#include <torchaudio/csrc/rnnt/gpu/math.cuh>

namespace torchaudio {
namespace rnnt {

template <int NUM_THREADS, typename DTYPE, typename CAST_DTYPE>
__global__ void ReduceMax2D(
    int dim,
    const DTYPE* inputs, // [N, dim]
    CAST_DTYPE* outputs) {
  __shared__ CAST_DTYPE shared[NUM_THREADS];

  // each thread reduces one matrix row
  int offset = blockIdx.x * dim; // [n, 0]
  CAST_DTYPE val = inputs[offset]; // default = inputs(n, 0)
  for (int d = threadIdx.x; d < dim; d += NUM_THREADS) {
    CAST_DTYPE next = inputs[offset + d];
    if (next > val) {
      val = next;
    }
  }
  shared[threadIdx.x] = val;
  __syncthreads();

  for (int stride = (NUM_THREADS >> 1); stride >= WARP_SIZE; stride >>= 1) {
    if (threadIdx.x < stride && threadIdx.x + stride < dim) {
      if (shared[threadIdx.x + stride] > shared[threadIdx.x]) {
        shared[threadIdx.x] = shared[threadIdx.x + stride];
        val = shared[threadIdx.x];
      }
    }
    __syncthreads();
  }

  CAST_DTYPE shf;
  for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) {
    shf = __shfl_down_sync(0xFFFFFFFF, val, stride);
    if (threadIdx.x < stride && threadIdx.x + stride < dim) {
      if (shf > val) {
        val = shf;
      }
    }
  }

  if (threadIdx.x == 0) {
    outputs[blockIdx.x] = val;
  }
}

template <int NUM_THREADS, typename DTYPE, typename CAST_DTYPE>
__global__ void ReduceLogSumExpGivenMax2D(
    int dim,
    const DTYPE* inputs, // [N, dim]
    CAST_DTYPE* outputs) { // in: max -> out: logsum
  __shared__ CAST_DTYPE shared[NUM_THREADS];

  CAST_DTYPE max = outputs[blockIdx.x];
  CAST_DTYPE val = 0;

  int offset = blockIdx.x * dim;
  for (int d = threadIdx.x; d < dim; d += NUM_THREADS) {
    val = val + std::exp(CAST_DTYPE(inputs[offset + d]) - max);
  }
  shared[threadIdx.x] = val;
  __syncthreads();

  for (int stride = (NUM_THREADS >> 1); stride >= WARP_SIZE; stride >>= 1) {
    if (threadIdx.x < stride && threadIdx.x + stride < dim) {
      val = shared[threadIdx.x] + shared[threadIdx.x + stride];
      shared[threadIdx.x] = val;
    }
    __syncthreads();
  }

  CAST_DTYPE shf;
  for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) {
    shf = __shfl_down_sync(0xFFFFFFFF, val, stride);
    if (threadIdx.x < stride && threadIdx.x + stride < dim) {
      val = val + shf;
    }
  }

  if (threadIdx.x == 0) {
    outputs[blockIdx.x] = max + std::log(val);
  }
}

} // namespace rnnt
} // namespace torchaudio

#endif // USE_CUDA
```
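The two kernels above compute a numerically stable log-sum-exp per row in two passes: first the row maximum, then log Σ exp(x − max), re-adding the max at the end. A plain sequential C++ reference of the same arithmetic (illustrative only, useful for checking the kernels against):

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

// Sequential reference of the two-pass reduction performed by
// ReduceMax2D + ReduceLogSumExpGivenMax2D for one row.
float log_sum_exp_row(const std::vector<float>& row) {
  // Pass 1: row maximum (what ReduceMax2D writes into `outputs`).
  float m = row[0];
  for (float v : row) m = std::max(m, v);
  // Pass 2: log of the max-shifted exponential sum, re-adding the max
  // (what ReduceLogSumExpGivenMax2D does, reading `outputs` as the max).
  double s = 0.0;
  for (float v : row) s += std::exp(v - m);
  return m + static_cast<float>(std::log(s));
}

int main() {
  std::vector<float> row = {1000.0f, 999.0f, 998.0f}; // naive exp() overflows
  std::printf("%f\n", log_sum_exp_row(row)); // ~1000.407612
}
```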
torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh (new file, mode 100644)

```cuda
#pragma once

#ifdef USE_CUDA

#include <cassert>

#include <torchaudio/csrc/rnnt/gpu/kernel_utils.h>
#include <torchaudio/csrc/rnnt/gpu/kernels.h>
#include <torchaudio/csrc/rnnt/gpu/math.cuh>

namespace torchaudio {
namespace rnnt {

template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeLogProbs(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    const CAST_DTYPE* denominators,
    CAST_DTYPE* logProbs,
    int H = 1,
    bool fusedLogSmax = true) {
  const int& maxT = maxSrcLen;
  const int& maxU = maxTgtLen;
  const int& D = numTargets;

  const int bTgt = blockIdx.z; // 0 <= b < B
  const int bSrc = bTgt / H;
  const int T = srcLengths[bSrc];
  const int U = tgtLengths[bTgt] + 1;

  const int t = blockIdx.x * blockDim.x + threadIdx.x;
  const int u = blockIdx.y;

  if (t >= T || u >= U) { // out of boundary.
    return;
  }

  Indexer3D indexer(maxT, maxU);

  int idx = indexer(bTgt, t, u);

  // skip: log_prob(b, t, u).skip() = logits(b, t, u, blank) - denom(b, t, u).
  logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
      CAST_DTYPE(logits[idx * D + blank]) - denominators[idx];

  if (!fusedLogSmax) {
    logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
        CAST_DTYPE(logits[idx * D + blank]);
  }

  if (u < U - 1) {
    // emit: log_prob(b, t, u).emit() = logits(b, t, u, tgt[u]) - denom(b, t, u).
    int target = targets[Indexer2D(maxU - 1)(bTgt, u)];
    logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
        CAST_DTYPE(logits[idx * D + target]) - denominators[idx];

    if (!fusedLogSmax) {
      logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
          CAST_DTYPE(logits[idx * D + target]);
    }
  }
}

template <typename DTYPE, typename CAST_DTYPE>
__device__ void ComputeAlphas(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    int* alpha_counters,
    volatile CAST_DTYPE* alphas,
    int H = 1) {
  const int& maxT = maxSrcLen;
  const int& maxU = maxTgtLen;

  const int bTgt = blockIdx.z; // 0 <= b < B
  const int bSrc = bTgt / H;
  const int T = srcLengths[bSrc];
  const int U = tgtLengths[bTgt] + 1;

  const int t = blockIdx.x * blockDim.x + threadIdx.x + 1;
  const int u = blockIdx.y + 1;

  if (t >= T || u >= U) { // out of boundary.
    return;
  }

  int* counter = alpha_counters + Indexer2D(maxU)(bTgt, blockIdx.y);

  Indexer3D idxr(maxT, maxU);

  if (t == 1 && u == 1) {
    alphas[idxr(bTgt, 0, 0)] = 0;
  }

  if (blockIdx.x > 0) { // wait for previous warp (in t-axis) is ready.
    while (atomicAdd(counter, 0) < blockIdx.x) {
    }
  }

  if (blockIdx.y > 0) { // wait for previous warp (in u-axis) is ready.
    while (atomicAdd(counter - 1, 0) <= blockIdx.x) {
    }
  }

  if (t == 1 && u < U) {
    // alpha(0, u) = alpha(0, u - 1) + logProbs(0, u - 1).emit().
    alphas[idxr(bTgt, 0, u)] = alphas[idxr(bTgt, 0, u - 1)] +
        logProbs[(idxr(bTgt, 0, u - 1) << 1) + LOG_PROBS_EMIT_IDX];
  }

  if (blockIdx.y == 0 && t < T) {
    CAST_DTYPE skip_prob =
        logProbs[(idxr(bTgt, t - 1, 0) << 1) + LOG_PROBS_SKIP_IDX];
    CAST_DTYPE val;

#pragma unroll
    for (int i = 1; i < warpSize; i <<= 1) {
      val = __shfl_up_sync(0xffffffff, skip_prob, i);
      if (i <= threadIdx.x) {
        skip_prob = skip_prob + val;
      }
    }

    val = alphas[idxr(bTgt, blockIdx.x * blockDim.x, 0)];
    alphas[idxr(bTgt, t, 0)] = skip_prob + val;
  }

  if (t < T && u < U) {
    CAST_DTYPE skip_prob =
        logProbs[(idxr(bTgt, t - 1, u) << 1) + LOG_PROBS_SKIP_IDX];
    CAST_DTYPE emit_prob =
        logProbs[(idxr(bTgt, t, u - 1) << 1) + LOG_PROBS_EMIT_IDX];

    CAST_DTYPE skip =
        alphas[idxr(bTgt, blockIdx.x * blockDim.x, u)] + skip_prob;
    CAST_DTYPE emit = alphas[idxr(bTgt, t, u - 1)] + emit_prob;

    CAST_DTYPE val = math::lse(skip, emit);
    CAST_DTYPE out = val;

    for (int i = 1; i < warpSize; ++i) {
      val = __shfl_up_sync(0xffffffff, val, 1);
      if (i == threadIdx.x) {
        val = math::lse(val + skip_prob, emit);
        out = val;
      }
    }

    alphas[idxr(bTgt, t, u)] = out;
  }

  if (threadIdx.x == 0) {
    __threadfence();
    atomicAdd(counter, 1);
  }
}

template <typename DTYPE, typename CAST_DTYPE>
__device__ void ComputeBetasCosts(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    int* betaCounters,
    volatile CAST_DTYPE* betas,
    DTYPE* costs,
    int H = 1) {
  const int& maxT = maxSrcLen;
  const int& maxU = maxTgtLen;

  const int bTgt = blockIdx.z; // 0 <= b < B
  const int bSrc = bTgt / H;
  const int T = srcLengths[bSrc];
  const int U = tgtLengths[bTgt] + 1;

  const int t = T - 2 - blockIdx.x * blockDim.x - threadIdx.x;
  const int u = U - 2 - blockIdx.y;

  if (t < 0 || u < 0) { // out of boundary.
    return;
  }

  int* counter = betaCounters + Indexer2D(maxU)(bTgt, blockIdx.y);

  Indexer3D idxr(maxT, maxU);

  if (t == T - 2 && u == U - 2) {
    betas[idxr(bTgt, T - 1, U - 1)] =
        logProbs[(idxr(bTgt, T - 1, U - 1) << 1) + LOG_PROBS_SKIP_IDX];
  }

  if (blockIdx.x > 0) { // wait for previous warp (in t-axis) is ready.
    while (atomicAdd(counter, 0) < blockIdx.x) {
    }
  }

  if (blockIdx.y > 0) { // wait for previous warp (in u-axis) is ready.
    while (atomicAdd(counter - 1, 0) <= blockIdx.x) {
    }
  }

  if (t == T - 2 && u >= 0) {
    betas[idxr(bTgt, T - 1, u)] = betas[idxr(bTgt, T - 1, u + 1)] +
        logProbs[(idxr(bTgt, T - 1, u) << 1) + LOG_PROBS_EMIT_IDX];
  }

  if (blockIdx.y == 0 && t >= 0) {
    CAST_DTYPE skip_prob =
        logProbs[(idxr(bTgt, t, U - 1) << 1) + LOG_PROBS_SKIP_IDX];
    CAST_DTYPE val;

#pragma unroll
    for (int i = 1; i < warpSize; i <<= 1) {
      val = __shfl_up_sync(0xffffffff, skip_prob, i);
      if (i <= threadIdx.x) {
        skip_prob = skip_prob + val;
      }
    }

    betas[idxr(bTgt, t, U - 1)] =
        betas[idxr(bTgt, T - 1 - blockIdx.x * blockDim.x, U - 1)] + skip_prob;
  }

  if (t >= 0 && u >= 0) {
    CAST_DTYPE skip_prob =
        logProbs[(idxr(bTgt, t, u) << 1) + LOG_PROBS_SKIP_IDX];
    CAST_DTYPE emit_prob =
        logProbs[(idxr(bTgt, t, u) << 1) + LOG_PROBS_EMIT_IDX];

    CAST_DTYPE skip = betas[idxr(bTgt, t + threadIdx.x + 1, u)] + skip_prob;
    CAST_DTYPE emit = betas[idxr(bTgt, t, u + 1)] + emit_prob;

    CAST_DTYPE val = math::lse(skip, emit);
    CAST_DTYPE out = val;

    for (int i = 1; i < warpSize; ++i) {
      val = __shfl_up_sync(0xffffffff, val, 1);
      if (i == threadIdx.x) {
        val = math::lse(val + skip_prob, emit);
        out = val;
      }
    }

    betas[idxr(bTgt, t, u)] = out;

    if (t == 0 && u == 0) { // use -beta(0, 0) as cost.
      costs[bTgt] = DTYPE(-out);
    }
  }

  if (threadIdx.x == 0) {
    __threadfence();
    atomicAdd(counter, 1);
  }
}

template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeAlphasBetasCosts(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    int* alpha_counters,
    volatile CAST_DTYPE* alphas,
    int* betaCounters,
    volatile CAST_DTYPE* betas,
    DTYPE* costs,
    int warpSize = 0,
    int numWarps = 0,
    int H = 1) {
  assert(threadIdx.y == 0 || threadIdx.y == 1);

  if (threadIdx.y == 0) {
    ComputeAlphas<DTYPE, CAST_DTYPE>(
        /*maxSrcLen=*/maxSrcLen,
        /*maxTgtLen=*/maxTgtLen,
        /*numTargets=*/numTargets,
        /*blank=*/blank,
        /*logProbs=*/logProbs,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*alpha_counters=*/alpha_counters,
        /*alphas=*/alphas,
        H);
  } else { // threadIdx.y == 1
    ComputeBetasCosts<DTYPE, CAST_DTYPE>(
        /*maxSrcLen=*/maxSrcLen,
        /*maxTgtLen=*/maxTgtLen,
        /*numTargets=*/numTargets,
        /*blank=*/blank,
        /*logProbs=*/logProbs,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*betaCounters=*/betaCounters,
        /*beta=*/betas,
        /*costs=*/costs,
        H);
  }
}

template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeGradients(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    CAST_DTYPE clamp,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    const CAST_DTYPE* denominators,
    const CAST_DTYPE* alphas,
    const CAST_DTYPE* betas,
    DTYPE* gradients,
    int H = 1,
    bool fusedLogSmax = true) {
  const int bTgt = blockIdx.z; // 0 <= b < B
  const int t = blockIdx.x * blockDim.x + threadIdx.x;
  const int u = blockIdx.y;

  ComputeGradientsElement(
      bTgt,
      t,
      u,
      maxSrcLen,
      maxTgtLen,
      numTargets,
      blank,
      clamp,
      logits,
      targets,
      srcLengths,
      tgtLengths,
      denominators,
      alphas,
      betas,
      gradients,
      H,
      fusedLogSmax);
}

// This is a __global__ wrapper around ComputeAlphas
// device kernel to enable unit testing
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeAlphasWrapper(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    int* alpha_counters,
    volatile CAST_DTYPE* alphas,
    int H = 1) {
  ComputeAlphas<DTYPE, CAST_DTYPE>(
      maxSrcLen,
      maxTgtLen,
      numTargets,
      blank,
      logProbs,
      srcLengths,
      tgtLengths,
      alpha_counters,
      alphas,
      H);
}

// This is a __global__ wrapper around ComputeBetas
// device kernel to enable unit testing
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeBetasWrapper(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    int* betaCounters,
    volatile CAST_DTYPE* betas,
    DTYPE* costs,
    int H = 1) {
  ComputeBetasCosts<DTYPE, CAST_DTYPE>(
      maxSrcLen,
      maxTgtLen,
      numTargets,
      blank,
      logProbs,
      srcLengths,
      tgtLengths,
      betaCounters,
      betas,
      costs,
      H);
}

// #undef LOG_PROBS_SKIP_IDX
// #undef LOG_PROBS_EMIT_IDX

} // namespace rnnt
} // namespace torchaudio

#endif // USE_CUDA
```
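`ComputeAlphas` above parallelizes a recursion with a diagonal data dependency — α(t, u) needs α(t−1, u) and α(t, u−1) — so warps along t and blocks along u spin on the `alpha_counters` until their predecessors have published results. The underlying recursion, written sequentially, looks like the following sketch (an illustrative CPU reference for one sequence, assuming per-node `skip`/`emit` tables as produced by `ComputeLogProbs`; `lse` mirrors math.cuh):

```cpp
#include <cmath>
#include <vector>

float lse(float x, float y) {
  return (y > x) ? y + std::log1p(std::exp(x - y))
                 : x + std::log1p(std::exp(y - x));
}

// Sequential reference of the alpha recursion for one (T x U) lattice.
// skip[t][u]: log-prob of the blank at (t, u); emit[t][u]: log-prob of
// emitting the next target symbol at (t, u).
std::vector<std::vector<float>> compute_alphas_ref(
    const std::vector<std::vector<float>>& skip,
    const std::vector<std::vector<float>>& emit) {
  const int T = skip.size(), U = skip[0].size();
  std::vector<std::vector<float>> alpha(T, std::vector<float>(U));
  alpha[0][0] = 0;
  for (int t = 1; t < T; ++t) // first column: blanks only
    alpha[t][0] = alpha[t - 1][0] + skip[t - 1][0];
  for (int u = 1; u < U; ++u) // first row: emissions only
    alpha[0][u] = alpha[0][u - 1] + emit[0][u - 1];
  for (int t = 1; t < T; ++t)
    for (int u = 1; u < U; ++u) // interior: combine the two incoming arcs
      alpha[t][u] = lse(alpha[t - 1][u] + skip[t - 1][u],
                        alpha[t][u - 1] + emit[t][u - 1]);
  return alpha;
}
```

The warp-level `__shfl_up_sync` loops in the kernel compute the same left-to-right dependence within a 32-wide strip of t values without touching shared memory.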
torchaudio/csrc/rnnt/gpu/gpu_transducer.h (new file, mode 100644)

```cuda
#pragma once

#ifdef USE_CUDA

#include <torchaudio/csrc/rnnt/workspace.h>
#include <torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh>
#include <torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh>

namespace torchaudio {
namespace rnnt {
namespace gpu {

#define gpuErrchk(ans) \
  { gpuAssert((ans), __FILE__, __LINE__); }

inline void gpuAssert(
    cudaError_t code,
    const char* file,
    int line,
    bool abort = true) {
  if (code != cudaSuccess) {
    fprintf(
        stderr,
        "\nGPUassert: %s %s %d\n",
        cudaGetErrorString(code),
        file,
        line);
    if (abort)
      exit(code);
  }
}

template <typename DTYPE, typename CAST_DTYPE>
status_t LogSumExp2D(
    cudaStream_t stream,
    int N,
    int D,
    const DTYPE* logits, // [N, D]
    CAST_DTYPE* outputs) {
  { // compute max among D.
    dim3 block_dims(N);
    dim3 thread_dims(REDUCE_THREADS);

    ReduceMax2D<REDUCE_THREADS, DTYPE, CAST_DTYPE>
        <<<block_dims, thread_dims, 0, stream>>>(
            /*dim=*/D,
            /*inputs=*/logits,
            /*outputs=*/outputs);

    // BUGBUG: These error codes are only accurate when launching with
    // blocking. Otherwise they usually reflect earlier errors.
    if (cudaGetLastError() != cudaSuccess) {
      return COMPUTE_DENOMINATOR_REDUCE_MAX_FAILED;
    }
  }

  { // compute log(sum(exp(d_i - max)))
    dim3 block_dims(N);
    dim3 thread_dims(REDUCE_THREADS);

    ReduceLogSumExpGivenMax2D<REDUCE_THREADS, DTYPE, CAST_DTYPE>
        <<<block_dims, thread_dims, 0, stream>>>(
            /*dim=*/D,
            /*inputs=*/logits,
            /*outputs=*/outputs);

    if (cudaGetLastError() != cudaSuccess) {
      return COMPUTE_DENOMINATOR_REDUCE_SUM_FAILED;
    }
  }

  return SUCCESS;
}

// Inputs:
//   workspace: workspace.
//   logits: pointer to (B, max_T, max_U, D) logits.
//   targets: pointer to (B, max_U - 1) targets in the batch.
//   srcLengths: pointer to (B, ) source lengths in the batch.
//   tgtLengths: pointer to (B, ) target lengths in the batch.
//
// Outputs:
//   costs: pointer to (B, ) costs in the batch.
//   gradients: pointer to (B, max_T, max_U, D) gradients in the batch.
template <typename DTYPE, typename CAST_DTYPE>
status_t Compute(
    const Workspace<CAST_DTYPE>& workspace,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    DTYPE* costs,
    DTYPE* gradients = nullptr) {
  const Options& options = workspace.GetOptions();

  const cudaStream_t& stream = options.stream_;
  const int& B = options.batchSize_;
  const int& H = options.nHypos_;
  const int& max_T = options.maxSrcLen_;
  const int& max_U = options.maxTgtLen_;
  const int& D = options.numTargets_;
  const int& blank = options.blank_;
  const CAST_DTYPE clamp = options.clamp_;
  const bool& fusedLogSmax = options.fusedLogSmax_;

  { // compute denominators.
    status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
        /*stream=*/stream,
        /*N=*/B * H * max_T * max_U,
        /*D=*/D,
        /*logits=*/logits,
        /*denominators=*/workspace.GetPointerToDenominators());

    if (status != SUCCESS) {
      return status;
    }
  }

  { // compute log probability pairs (blank and target).
    int num_segments =
        (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
    dim3 block_dims(num_segments, max_U, B * H);
    dim3 thread_dims(MAX_THREADS_PER_BLOCK);

    ComputeLogProbs<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
        /*max_src_len=*/max_T,
        /*max_tgt_len=*/max_U,
        /*num_targets=*/D,
        /*blank=*/blank,
        /*logits=*/logits,
        /*targets=*/targets,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*denominators=*/workspace.GetPointerToDenominators(),
        /*log_probs=*/workspace.GetPointerToLogProbs(),
        H,
        fusedLogSmax);

    if (cudaGetLastError() != cudaSuccess) {
      return COMPUTE_LOG_PROBS_FAILED;
    }
  }

  { // compute alphas, betas and costs.
    // warp is usually a group of threads (32)
    int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE;

    // each block is identified by 3 d tuple.
    // we are using num_warp * max_U * B * H blocks
    // where num_warp is division among Time axis
    dim3 block_dims(num_warps, max_U, B * H);

    // each thread is identified by a 2 d tuple
    // 2nd dim is 2. 1 for alpha, 1 for beta
    dim3 thread_dims(WARP_SIZE, 2);

    ComputeAlphasBetasCosts<DTYPE, CAST_DTYPE>
        <<<block_dims, thread_dims, 0, stream>>>(
            /*max_src_len=*/max_T,
            /*max_tgt_len=*/max_U,
            /*num_targets=*/D,
            /*blank=*/blank,
            /*log_probs=*/workspace.GetPointerToLogProbs(),
            /*srcLengths=*/srcLengths,
            /*tgtLengths=*/tgtLengths,
            /*alpha_counters=*/workspace.GetPointerToAlphaCounters(),
            /*alphas=*/workspace.GetPointerToAlphas(),
            /*beta_counters=*/workspace.GetPointerToBetaCounters(),
            /*betas=*/workspace.GetPointerToBetas(),
            /*costs=*/costs,
            /*warp_size=*/WARP_SIZE,
            /*num_warps=*/num_warps,
            H);

    if (cudaGetLastError() != cudaSuccess) {
      return COMPUTE_ALPHAS_BETAS_COSTS_FAILED;
    }
  }

  if (gradients != nullptr) { // compute gradients.
    // don't set gradients to zero to here as gradients might reuse memory from
    // logits
    int num_blocks =
        (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
    dim3 block_dims(num_blocks, max_U, B * H);
    dim3 thread_dims(MAX_THREADS_PER_BLOCK);

    ComputeGradients<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
        /*max_src_len=*/max_T,
        /*max_tgt_len=*/max_U,
        /*num_targets=*/D,
        /*blank=*/blank,
        /*clamp=*/clamp,
        /*logits=*/logits,
        /*targets=*/targets,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*denominators=*/workspace.GetPointerToDenominators(),
        /*alphas=*/workspace.GetPointerToAlphas(),
        /*betas=*/workspace.GetPointerToBetas(),
        /*gradients=*/gradients,
        H,
        fusedLogSmax);

    if (cudaGetLastError() != cudaSuccess) {
      return COMPUTE_GRADIENTS_FAILED;
    }
  }

  return SUCCESS;
}

template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeAlphas(
    const Workspace<CAST_DTYPE>& workspace,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    DTYPE* alphas) {
  const Options& options = workspace.GetOptions();

  const cudaStream_t& stream = options.stream_;
  const int& B = options.batchSize_;
  const int& H = options.nHypos_;
  const int& max_T = options.maxSrcLen_;
  const int& max_U = options.maxTgtLen_;
  const int& D = options.numTargets_;
  const int& blank = options.blank_;

  { // compute denominators.
    status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
        /*stream=*/stream,
        /*N=*/B * H * max_T * max_U,
        /*D=*/D,
        /*logits=*/logits,
        /*denominators=*/workspace.GetPointerToDenominators());

    if (status != SUCCESS) {
      return status;
    }
  }

  { // compute log probability pairs (blank and target).
    int num_segments =
        (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
    dim3 block_dims(num_segments, max_U, B * H);
    dim3 thread_dims(MAX_THREADS_PER_BLOCK);

    ComputeLogProbs<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
        /*max_src_len=*/max_T,
        /*max_tgt_len=*/max_U,
        /*num_targets=*/D,
        /*blank=*/blank,
        /*logits=*/logits,
        /*targets=*/targets,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*denominators=*/workspace.GetPointerToDenominators(),
        /*log_probs=*/workspace.GetPointerToLogProbs(),
        H);

    if (cudaGetLastError() != cudaSuccess) {
      return COMPUTE_LOG_PROBS_FAILED;
    }
  }

  { // compute alphas
    // warp is usually a group of threads (32)
    int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE;

    // each block is identified by 3 d tuple.
    // we are using num_warp * max_U * B blocks
    // where num_warp is division among Time axis
    dim3 block_dims(num_warps, max_U, B * H);

    // each thread is identified by a 2 d tuple
    // 2nd dim is 1 for alpha only
    dim3 thread_dims(WARP_SIZE, 1);

    ComputeAlphasWrapper<DTYPE, CAST_DTYPE>
        <<<block_dims, thread_dims, 0, stream>>>(
            /*max_src_len=*/max_T,
            /*max_tgt_len=*/max_U,
            /*num_targets=*/D,
            /*blank=*/blank,
            /*log_probs=*/workspace.GetPointerToLogProbs(),
            /*srcLengths=*/srcLengths,
            /*tgtLengths=*/tgtLengths,
            /*alpha_counters=*/workspace.GetPointerToAlphaCounters(),
            /*alphas=*/(volatile DTYPE*)alphas,
            H);

    if (cudaGetLastError() != cudaSuccess) {
      return COMPUTE_ALPHAS_BETAS_COSTS_FAILED;
    }
  }

  return SUCCESS;
}

template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeBetas(
    const Workspace<CAST_DTYPE>& workspace,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    DTYPE* costs,
    DTYPE* betas) {
  const Options& options = workspace.GetOptions();

  const cudaStream_t& stream = options.stream_;
  const int& B = options.batchSize_;
  const int& H = options.nHypos_;
  const int& max_T = options.maxSrcLen_;
  const int& max_U = options.maxTgtLen_;
  const int& D = options.numTargets_;
  const int& blank = options.blank_;

  { // compute denominators.
    status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
        /*stream=*/stream,
        /*N=*/B * H * max_T * max_U,
        /*D=*/D,
        /*logits=*/logits,
        /*denominators=*/workspace.GetPointerToDenominators());

    if (status != SUCCESS) {
      return status;
    }
  }

  { // compute log probability pairs (blank and target).
    int num_segments =
        (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
    dim3 block_dims(num_segments, max_U, B * H);
    dim3 thread_dims(MAX_THREADS_PER_BLOCK);

    ComputeLogProbs<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
        /*max_src_len=*/max_T,
        /*max_tgt_len=*/max_U,
        /*num_targets=*/D,
        /*blank=*/blank,
        /*logits=*/logits,
        /*targets=*/targets,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*denominators=*/workspace.GetPointerToDenominators(),
        /*log_probs=*/workspace.GetPointerToLogProbs(),
        H);

    if (cudaGetLastError() != cudaSuccess) {
      return COMPUTE_LOG_PROBS_FAILED;
    }
  }

  { // compute betas
    // warp is usually a group of threads (32)
    int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE;

    // each block is identified by 3 d tuple.
    // we are using num_warp * max_U * B blocks
    // where num_warp is division among Time axis
    dim3 block_dims(num_warps, max_U, B * H);

    // each thread is identified by a 2 d tuple
    // 2nd dim is 1 for betas only
    dim3 thread_dims(WARP_SIZE, 1);

    ComputeBetasWrapper<DTYPE, CAST_DTYPE>
        <<<block_dims, thread_dims, 0, stream>>>(
            /*max_src_len=*/max_T,
            /*max_tgt_len=*/max_U,
            /*num_targets=*/D,
            /*blank=*/blank,
            /*log_probs=*/workspace.GetPointerToLogProbs(),
            /*srcLengths=*/srcLengths,
            /*tgtLengths=*/tgtLengths,
            /*alpha_counters=*/workspace.GetPointerToBetaCounters(),
            /*alphas=*/(volatile DTYPE*)betas,
            costs,
            H);

    if (cudaGetLastError() != cudaSuccess) {
      return COMPUTE_ALPHAS_BETAS_COSTS_FAILED;
    }
  }

  return SUCCESS;
}

} // namespace gpu
} // namespace rnnt
} // namespace torchaudio

#endif // USE_CUDA
```
torchaudio/csrc/rnnt/gpu/half.cuh (new file, mode 100644)

```cuda
#pragma once

#ifdef USE_C10_HALF
#include "c10/util/Half.h"
#endif // USE_C10_HALF

#include <torchaudio/csrc/rnnt/macros.h>

namespace torchaudio {
namespace rnnt {

struct alignas(sizeof(__half)) Half {
  __half x;

  HOST_AND_DEVICE Half() = default;

  FORCE_INLINE HOST_AND_DEVICE Half(float f) {
    x = __float2half_rn(f);
    if (isinf(__half2float(x))) {
      x = __float2half_rz(f); // round toward 0.
    }
  }

  FORCE_INLINE HOST_AND_DEVICE operator float() const {
    return __half2float(x);
  }

  FORCE_INLINE HOST_AND_DEVICE Half(__half f) {
    x = f;
  }

  FORCE_INLINE HOST_AND_DEVICE operator __half() const {
    return x;
  }
};

} // namespace rnnt
} // namespace torchaudio
```
torchaudio/csrc/rnnt/gpu/kernel_utils.h (new file, mode 100644)

```cpp
#pragma once

#include <cassert>

#include <torchaudio/csrc/rnnt/gpu/math.cuh>

namespace torchaudio {
namespace rnnt {

inline HOST_AND_DEVICE bool in_range(
    int start,
    int end, // inclusive
    int val) {
  return start <= val && val <= end;
}

#define LOG_PROBS_SKIP_IDX 0
#define LOG_PROBS_EMIT_IDX 1

struct Indexer2D {
  const int& size2_;

  FORCE_INLINE HOST_AND_DEVICE Indexer2D(const int& size2) : size2_(size2) {}

  FORCE_INLINE HOST_AND_DEVICE int operator()(int index1, int index2) {
    return index1 * size2_ + index2;
  }
};

struct Indexer3D {
  const int& size2_;
  const int& size3_;

  FORCE_INLINE HOST_AND_DEVICE Indexer3D(const int& size2, const int& size3)
      : size2_(size2), size3_(size3) {}

  FORCE_INLINE HOST_AND_DEVICE int
  operator()(int index1, int index2, int index3) {
    return (index1 * size2_ + index2) * size3_ + index3;
  }
};

struct Indexer4D {
  const int& size2_;
  const int& size3_;
  const int& size4_;

  HOST_AND_DEVICE Indexer4D(
      const int& size2,
      const int& size3,
      const int& size4)
      : size2_(size2), size3_(size3), size4_(size4) {}

  HOST_AND_DEVICE int
  operator()(int index1, int index2, int index3, int index4) {
    return ((index1 * size2_ + index2) * size3_ + index3) * size4_ + index4;
  }
};

} // namespace rnnt
} // namespace torchaudio
```
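These indexers flatten row-major multi-dimensional coordinates, and the kernels store the (skip, emit) log-prob pair for each lattice node contiguously — which is why gpu_kernels.cuh addresses `logProbs` as `(idx << 1) + LOG_PROBS_{SKIP,EMIT}_IDX`. A small standalone illustration of the same arithmetic (illustrative only; `index3d` is a free-function stand-in for `Indexer3D`):

```cpp
#include <cstdio>

#define LOG_PROBS_SKIP_IDX 0
#define LOG_PROBS_EMIT_IDX 1

// Same row-major layout arithmetic as Indexer3D in kernel_utils.h.
int index3d(int size2, int size3, int i1, int i2, int i3) {
  return (i1 * size2 + i2) * size3 + i3;
}

int main() {
  const int maxT = 10, maxU = 5;
  int idx = index3d(maxT, maxU, /*b=*/1, /*t=*/3, /*u=*/2); // node (1, 3, 2) -> 67
  // logProbs interleaves two values per node: [skip0, emit0, skip1, emit1, ...]
  int skip_at = (idx << 1) + LOG_PROBS_SKIP_IDX; // == 2 * idx       (134)
  int emit_at = (idx << 1) + LOG_PROBS_EMIT_IDX; // == 2 * idx + 1   (135)
  std::printf("node %d -> skip %d, emit %d\n", idx, skip_at, emit_at);
}
```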
torchaudio/csrc/rnnt/gpu/kernels.h (new file, mode 100644)

```cpp
#pragma once

#include <cassert>

#include <torchaudio/csrc/rnnt/gpu/kernel_utils.h>
#include <torchaudio/csrc/rnnt/gpu/math.cuh>

namespace torchaudio {
namespace rnnt {

template <typename DTYPE, typename CAST_DTYPE>
HOST_AND_DEVICE void ComputeGradientsElement(
    int bTgt,
    int t,
    int u,
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    CAST_DTYPE clamp,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    const CAST_DTYPE* denominators,
    const CAST_DTYPE* alphas,
    const CAST_DTYPE* betas,
    DTYPE* gradients,
    int H = 1,
    bool fusedLogSmax = true) {
  const int& maxT = maxSrcLen;
  const int& maxU = maxTgtLen;
  const int& D = numTargets;

  const int bSrc = bTgt / H;
  const int T = srcLengths[bSrc];
  const int U = tgtLengths[bTgt] + 1;

  if (t >= T || u >= U) { // out of boundary.
    if (gradients == logits && t < maxT && u < maxU) {
      // gradients and logits are pointing to the same memory location
      Indexer3D idxr3(maxT, maxU);
      int idx_b_t_u_zero = idxr3(bTgt, t, u);
      if (idx_b_t_u_zero != -1) {
        int start = idx_b_t_u_zero * D;
        for (int b_t_u_d = start; b_t_u_d < start + D; ++b_t_u_d) {
          gradients[b_t_u_d] = 0;
        }
      }
    }
    return;
  }

  int costIdx = bTgt * maxT * maxU;
  CAST_DTYPE cost = -(betas[costIdx]);

  Indexer2D idxr2(maxU - 1);

  int idx_b_t_u, idx_b_t_up1, idx_b_tp1_u;
  Indexer3D idxr3(maxT, maxU);
  idx_b_t_u = idxr3(bTgt, t, u);
  idx_b_t_up1 = idxr3(bTgt, t, u + 1);
  idx_b_tp1_u = idxr3(bTgt, t + 1, u);

  if (idx_b_t_u == -1) {
    return;
  }

  if (isinf(cost) || isnan(cost)) {
    for (int d = 0; d < D; ++d) {
      int b_t_u_d = idx_b_t_u * D + d;
      gradients[b_t_u_d] = 0;
    }
    return;
  }

  CAST_DTYPE c = alphas[idx_b_t_u] + cost - denominators[idx_b_t_u];
  for (int d = 0; d < D; ++d) {
    int b_t_u_d = idx_b_t_u * D + d;
    CAST_DTYPE g = CAST_DTYPE(logits[b_t_u_d]) + c;

    if (fusedLogSmax) {
      if (d == blank && t == T - 1 && u == U - 1) { // last blank transition.
        gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]) - std::exp(g);
      } else if (t < T - 1 && d == blank) {
        gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
        if (idx_b_tp1_u != -1) {
          gradients[b_t_u_d] =
              gradients[b_t_u_d] - std::exp(g + betas[idx_b_tp1_u]);
        }
      } else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) {
        gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
        if (idx_b_t_up1 != -1) {
          gradients[b_t_u_d] =
              gradients[b_t_u_d] - std::exp(g + betas[idx_b_t_up1]);
        }
      } else {
        gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
      }
    } else { // Non fused log softmax case
      CAST_DTYPE g = cost + CAST_DTYPE(logits[b_t_u_d]);
      if (d == blank && t == T - 1 && u == U - 1) {
        gradients[b_t_u_d] = g + alphas[idx_b_t_u];
      } else if (t < T - 1 && d == blank) {
        if (idx_b_tp1_u != -1) {
          gradients[b_t_u_d] = g + alphas[idx_b_t_u] + betas[idx_b_tp1_u];
        } else {
          gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY);
        }
      } else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) {
        if (idx_b_t_up1 != -1) {
          gradients[b_t_u_d] = g + alphas[idx_b_t_u] + betas[idx_b_t_up1];
        } else {
          gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY);
        }
      } else {
        gradients[b_t_u_d] = g + CAST_DTYPE(-INFINITY);
      }
      gradients[b_t_u_d] = -std::exp(gradients[b_t_u_d]);
    }

    if (clamp > 0) {
      auto g = CAST_DTYPE(gradients[b_t_u_d]);
      gradients[b_t_u_d] = math::min(g, clamp);
      gradients[b_t_u_d] = math::max(g, -clamp);
    }
  }
}

} // namespace rnnt
} // namespace torchaudio
```
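Reading the fused-log-softmax branch of `ComputeGradientsElement` above: with z the raw logits, denom the log-softmax denominator, and the loss ℓ = −β(0, 0), the value it stores is (a transcription of the branches into one formula, as a sketch):

```latex
g = z_{t,u,d} + \alpha(t,u) - \beta(0,0) - \mathrm{denom}(t,u)
\frac{\partial \ell}{\partial z_{t,u,d}} = e^{\,g + \beta(t,u)} -
\begin{cases}
  e^{\,g + \beta(t+1,\,u)} & d = \text{blank},\ t < T-1 \\
  e^{\,g + \beta(t,\,u+1)} & d = y_u,\ u < U-1 \\
  e^{\,g} & d = \text{blank},\ t = T-1,\ u = U-1 \\
  0 & \text{otherwise}
\end{cases}
```

The non-fused branch composes the same α/β terms with the raw log-probabilities and negates a single exponential instead.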
torchaudio/csrc/rnnt/gpu/math.cuh (new file, mode 100644)

```cuda
#pragma once

#ifdef USE_CUDA
#include <cmath>
#endif // USE_CUDA

#include <torchaudio/csrc/rnnt/gpu/half.cuh>

namespace torchaudio {
namespace rnnt {
namespace math {

template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE max(DTYPE x, DTYPE y) {
  if (x > y)
    return x;
  else
    return y;
}

template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE min(DTYPE x, DTYPE y) {
  if (x > y)
    return y;
  else
    return x;
}

// log_sum_exp
template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE lse(DTYPE x, DTYPE y);

template <>
FORCE_INLINE HOST_AND_DEVICE float lse(float x, float y) {
  if (y > x) {
    return y + log1pf(expf(x - y));
  } else {
    return x + log1pf(expf(y - x));
  }
}

} // namespace math
} // namespace rnnt
} // namespace torchaudio
```
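The float specialization of `lse` above is the max-shifted form of log-sum-exp, which never exponentiates a positive argument and uses `log1pf` for accuracy when the smaller term is negligible:

```latex
\operatorname{lse}(x, y) = \log\big(e^{x} + e^{y}\big) = \max(x, y) + \log\!\big(1 + e^{-|x - y|}\big)
```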