OpenDAS / Torchaudio

Commit b6c4b068, authored Dec 13, 2022 by flyingdown

rnnt for dcu

parent f5d79493
Showing 16 changed files with 1525 additions and 4 deletions.
torchaudio/csrc/CMakeLists.txt                  +25   -0
torchaudio/csrc/rnnt/dcu/compute.cpp            +151  -0
torchaudio/csrc/rnnt/dcu/compute_alphas.cpp     +73   -0
torchaudio/csrc/rnnt/dcu/compute_betas.cpp      +78   -0
torchaudio/csrc/rnnt/dcu/gpu_kernel_utils.cuh   +98   -0
torchaudio/csrc/rnnt/dcu/gpu_kernels.cuh        +409  -0
torchaudio/csrc/rnnt/dcu/gpu_transducer.h       +391  -0
torchaudio/csrc/rnnt/dcu/half.cuh               +38   -0
torchaudio/csrc/rnnt/dcu/kernel_utils.h         +66   -0
torchaudio/csrc/rnnt/dcu/kernels.h              +108  -0
torchaudio/csrc/rnnt/dcu/math.cuh               +48   -0
torchaudio/csrc/rnnt/gpu/compute_alphas.cu      +1    -1
torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh   +8    -0
torchaudio/csrc/rnnt/macros.h                   +9    -1
torchaudio/csrc/rnnt/options.h                  +8    -0
torchaudio/csrc/rnnt/workspace.h                +14   -2
torchaudio/csrc/CMakeLists.txt
...
@@ -31,6 +31,15 @@ if(BUILD_RNNT)
       rnnt/gpu/compute.cu
     )
   endif()
+  if(USE_ROCM)
+    list(
+      APPEND
+      LIBTORCHAUDIO_SOURCES
+      rnnt/dcu/compute_alphas.cpp
+      rnnt/dcu/compute_betas.cpp
+      rnnt/dcu/compute.cpp
+      )
+  endif()
 endif()
 if(BUILD_KALDI)
...
@@ -49,6 +58,7 @@ if(BUILD_SOX)
   )
 endif()
+message(STATUS "${LIBTORCHAUDIO_SOURCES}")
 add_library(
   libtorchaudio
   SHARED
...
@@ -78,6 +88,21 @@ endif()
 if(USE_CUDA)
   target_compile_definitions(libtorchaudio PRIVATE USE_CUDA)
   target_include_directories(libtorchaudio PRIVATE ${CUDA_TOOLKIT_INCLUDE} rnnt)
   target_link_libraries(libtorchaudio ${C10_CUDA_LIBRARY} ${CUDA_CUDART_LIBRARY})
 endif()
+if(USE_ROCM)
+  target_compile_definitions(libtorchaudio PRIVATE USE_ROCM)
+  target_include_directories(libtorchaudio PRIVATE
...
torchaudio/csrc/rnnt/dcu/compute.cpp
0 → 100644
#include <c10/hip/HIPStream.h>
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/dcu/gpu_transducer.h>

namespace torchaudio {
namespace rnnt {
namespace gpu {

// Entry point into RNNT Loss
std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp) {
  TORCH_CHECK(
      logits.device().type() == targets.device().type(),
      "logits and targets must be on the same device");
  TORCH_CHECK(
      logits.device().type() == logit_lengths.device().type(),
      "logits and logit_lengths must be on the same device");
  TORCH_CHECK(
      logits.device().type() == target_lengths.device().type(),
      "logits and target_lengths must be on the same device");

  TORCH_CHECK(
      logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16,
      "logits must be float32 or float16 (half) type");
  TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type");
  TORCH_CHECK(
      logit_lengths.dtype() == torch::kInt32,
      "logit_lengths must be int32 type");
  TORCH_CHECK(
      target_lengths.dtype() == torch::kInt32,
      "target_lengths must be int32 type");

  TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
  TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
  TORCH_CHECK(
      logit_lengths.is_contiguous(), "logit_lengths must be contiguous");
  TORCH_CHECK(
      target_lengths.is_contiguous(), "target_lengths must be contiguous");

  TORCH_CHECK(
      logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
  TORCH_CHECK(
      targets.dim() == 2, "targets must be 2-D (batch, max target length)");
  TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D");
  TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D");

  TORCH_CHECK(
      logit_lengths.size(0) == logits.size(0),
      "batch dimension mismatch between logits and logit_lengths");
  TORCH_CHECK(
      target_lengths.size(0) == logits.size(0),
      "batch dimension mismatch between logits and target_lengths");
  TORCH_CHECK(
      targets.size(0) == logits.size(0),
      "batch dimension mismatch between logits and targets");

  TORCH_CHECK(
      blank >= 0 && blank < logits.size(-1),
      "blank must be within [0, logits.shape[-1])");

  TORCH_CHECK(
      logits.size(1) == at::max(logit_lengths).item().toInt(),
      "input length mismatch");
  TORCH_CHECK(
      logits.size(2) == at::max(target_lengths).item().toInt() + 1,
      "output length mismatch");
  TORCH_CHECK(
      targets.size(1) == at::max(target_lengths).item().toInt(),
      "target length mismatch");

  Options options;
  options.batchSize_ = logit_lengths.size(0);
  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
  options.maxSrcLen_ = logits.size(1);
  options.maxTgtLen_ = logits.size(2);
  options.numTargets_ = logits.size(3);
  options.blank_ = blank;
  options.clamp_ = clamp;

  CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
  options.stream_ = at::hip::getCurrentHIPStream();
  hipSetDevice(logits.get_device());
  options.device_ = GPU;

  torch::Tensor costs = torch::empty(
      options.batchSize_ * options.nHypos_,
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
  c10::optional<torch::Tensor> gradients = torch::zeros_like(logits);

  torch::Tensor int_workspace = torch::empty(
      IntWorkspace::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Int));

  torch::Tensor float_workspace = torch::empty(
      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Float));

  Workspace<float> workspace(
      /*options=*/options,
      /*dtype_data=*/float_workspace.data_ptr<float>(),
      /*dtype_size=*/float_workspace.numel(),
      /*int_data=*/int_workspace.data_ptr<int>(),
      /*int_size=*/int_workspace.numel());

  switch (logits.scalar_type()) {
    case torch::ScalarType::Float: {
      Compute</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
          /*workspace=*/workspace,
          /*logits=*/logits.data_ptr<float>(),
          /*targets=*/targets.data_ptr<int>(),
          /*logit_lengths=*/logit_lengths.data_ptr<int>(),
          /*target_lengths=*/target_lengths.data_ptr<int>(),
          /*costs=*/costs.data_ptr<float>(),
          /*gradients=*/gradients->data_ptr<float>());
      break;
    }
    case torch::ScalarType::Half: {
      Compute</*DTYPE=*/c10::Half, /*CAST_DTYPE=*/float>(
          /*workspace=*/workspace,
          /*logits=*/logits.data_ptr<c10::Half>(),
          /*targets=*/targets.data_ptr<int>(),
          /*logit_lengths=*/logit_lengths.data_ptr<int>(),
          /*target_lengths=*/target_lengths.data_ptr<int>(),
          /*costs=*/costs.data_ptr<c10::Half>(),
          /*gradients=*/gradients->data_ptr<c10::Half>());
      break;
    }
    default: {
      break;
    }
  };

  return std::make_tuple(costs, gradients);
}

TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
  m.impl("rnnt_loss", &compute);
}

} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
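For orientation, here is a minimal sketch of driving compute() directly (hypothetical shapes and values, not part of the commit; assumes a ROCm/DCU build of torchaudio with this file linked in):

    // Illustrative only -- not part of this commit.
    #include <torch/torch.h>

    void example() {
      int B = 2, T = 10, U = 5, D = 20; // batch, frames, target length + 1, classes
      auto opts = torch::TensorOptions().device(torch::kCUDA, 0); // HIP devices masquerade as CUDA
      torch::Tensor logits = torch::randn({B, T, U, D}, opts.dtype(torch::kFloat32));
      auto targets = torch::randint(1, D, {B, U - 1}, opts.dtype(torch::kInt32)); // 0 reserved for blank
      auto logit_lengths = torch::full({B}, T, opts.dtype(torch::kInt32));
      auto target_lengths = torch::full({B}, U - 1, opts.dtype(torch::kInt32));
      // costs: (B,); gradients: same shape as logits
      auto [costs, gradients] = torchaudio::rnnt::gpu::compute(
          logits, targets, logit_lengths, target_lengths, /*blank=*/0, /*clamp=*/-1.0);
    }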
torchaudio/csrc/rnnt/dcu/compute_alphas.cpp
0 → 100644
#include <c10/hip/HIPStream.h>
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/dcu/gpu_transducer.h>

namespace torchaudio {
namespace rnnt {
namespace gpu {

torch::Tensor compute_alphas(
    const torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp) {
  Options options;
  options.batchSize_ = logit_lengths.size(0);
  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
  options.maxSrcLen_ = logits.size(1);
  options.maxTgtLen_ = logits.size(2);
  options.numTargets_ = logits.size(3);
  options.blank_ = blank;
  options.clamp_ = clamp;

  CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
  options.stream_ = at::hip::getCurrentHIPStream();
  hipSetDevice(logits.get_device());
  options.device_ = GPU;

  torch::Tensor alphas = torch::zeros(
      {options.batchSize_ * options.nHypos_,
       options.maxSrcLen_,
       options.maxTgtLen_},
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));

  torch::Tensor int_workspace = torch::empty(
      IntWorkspace::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Int));

  torch::Tensor float_workspace = torch::empty(
      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Float));

  Workspace<float> workspace(
      /*options=*/options,
      /*dtype_data=*/float_workspace.data_ptr<float>(),
      /*dtype_size=*/float_workspace.numel(),
      /*int_data=*/int_workspace.data_ptr<int>(),
      /*int_size=*/int_workspace.numel());

  // Only support float, this is mainly to enable easy unit-testing
  ComputeAlphas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
      /*workspace=*/workspace,
      /*logits=*/logits.data_ptr<float>(),
      /*targets=*/targets.data_ptr<int>(),
      /*logit_lengths=*/logit_lengths.data_ptr<int>(),
      /*target_lengths=*/target_lengths.data_ptr<int>(),
      /*alphas=*/alphas.data_ptr<float>());
  return alphas;
}

TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
  m.impl("rnnt_loss_alphas", &compute_alphas);
}

} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
torchaudio/csrc/rnnt/dcu/compute_betas.cpp
0 → 100644
#include <c10/hip/HIPStream.h>
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/dcu/gpu_transducer.h>

namespace torchaudio {
namespace rnnt {
namespace gpu {

torch::Tensor compute_betas(
    const torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp) {
  Options options;
  options.batchSize_ = logit_lengths.size(0);
  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
  options.maxSrcLen_ = logits.size(1);
  options.maxTgtLen_ = logits.size(2);
  options.numTargets_ = logits.size(3);
  options.blank_ = blank;
  options.clamp_ = clamp;

  CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
  options.stream_ = at::hip::getCurrentHIPStream();
  hipSetDevice(logits.get_device());
  options.device_ = GPU;

  torch::Tensor costs = torch::empty(
      target_lengths.size(0),
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));

  torch::Tensor betas = torch::zeros(
      {options.batchSize_ * options.nHypos_,
       options.maxSrcLen_,
       options.maxTgtLen_},
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));

  torch::Tensor int_workspace = torch::empty(
      IntWorkspace::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Int));

  torch::Tensor float_workspace = torch::empty(
      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Float));

  Workspace<float> workspace(
      /*options=*/options,
      /*dtype_data=*/float_workspace.data_ptr<float>(),
      /*dtype_size=*/float_workspace.numel(),
      /*int_data=*/int_workspace.data_ptr<int>(),
      /*int_size=*/int_workspace.numel());

  // Only support float, this is mainly to enable easy unit-testing
  ComputeBetas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
      /*workspace=*/workspace,
      /*logits=*/logits.data_ptr<float>(),
      /*targets=*/targets.data_ptr<int>(),
      /*logit_lengths=*/logit_lengths.data_ptr<int>(),
      /*target_lengths=*/target_lengths.data_ptr<int>(),
      /*costs=*/costs.data_ptr<float>(),
      /*betas=*/betas.data_ptr<float>());
  return betas;
}

TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
  m.impl("rnnt_loss_betas", &compute_betas);
}

} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
torchaudio/csrc/rnnt/dcu/gpu_kernel_utils.cuh
0 → 100644
#pragma once

#ifdef USE_ROCM

#include <torchaudio/csrc/rnnt/dcu/math.cuh>

namespace torchaudio {
namespace rnnt {

template <int NUM_THREADS, typename DTYPE, typename CAST_DTYPE>
__global__ void ReduceMax2D(
    int dim,
    const DTYPE* inputs, // [N, dim]
    CAST_DTYPE* outputs) {
  __shared__ CAST_DTYPE shared[NUM_THREADS];

  // each thread block reduces one matrix row
  int offset = blockIdx.x * dim; // [n, 0]
  CAST_DTYPE val = inputs[offset]; // default = inputs(n, 0)
  for (int d = threadIdx.x; d < dim; d += NUM_THREADS) {
    CAST_DTYPE next = inputs[offset + d];
    if (next > val) {
      val = next;
    }
  }
  shared[threadIdx.x] = val;
  __syncthreads();

  for (int stride = (NUM_THREADS >> 1); stride >= WARP_SIZE; stride >>= 1) {
    if (threadIdx.x < stride && threadIdx.x + stride < dim) {
      if (shared[threadIdx.x + stride] > shared[threadIdx.x]) {
        shared[threadIdx.x] = shared[threadIdx.x + stride];
        val = shared[threadIdx.x];
      }
    }
    __syncthreads();
  }

  CAST_DTYPE shf;
  for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) {
    shf = __shfl_down(val, stride);
    if (threadIdx.x < stride && threadIdx.x + stride < dim) {
      if (shf > val) {
        val = shf;
      }
    }
  }
  if (threadIdx.x == 0) {
    outputs[blockIdx.x] = val;
  }
}

template <int NUM_THREADS, typename DTYPE, typename CAST_DTYPE>
__global__ void ReduceLogSumExpGivenMax2D(
    int dim,
    const DTYPE* inputs, // [N, dim]
    CAST_DTYPE* outputs) { // in: max -> out: logsum
  __shared__ CAST_DTYPE shared[NUM_THREADS];

  CAST_DTYPE max = outputs[blockIdx.x];
  CAST_DTYPE val = 0;

  int offset = blockIdx.x * dim;
  for (int d = threadIdx.x; d < dim; d += NUM_THREADS) {
    val = val + std::exp(CAST_DTYPE(inputs[offset + d]) - max);
  }
  shared[threadIdx.x] = val;
  __syncthreads();

  for (int stride = (NUM_THREADS >> 1); stride >= WARP_SIZE; stride >>= 1) {
    if (threadIdx.x < stride && threadIdx.x + stride < dim) {
      val = shared[threadIdx.x] + shared[threadIdx.x + stride];
      shared[threadIdx.x] = val;
    }
    __syncthreads();
  }

  CAST_DTYPE shf;
  for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) {
    shf = __shfl_down(val, stride);
    if (threadIdx.x < stride && threadIdx.x + stride < dim) {
      val = val + shf;
    }
  }

  if (threadIdx.x == 0) {
    outputs[blockIdx.x] = max + std::log(val);
  }
}

} // namespace rnnt
} // namespace torchaudio

#endif // USE_ROCM
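Together the two kernels implement the numerically stable two-pass log-sum-exp over each row: ReduceMax2D first finds

$$m = \max_d x_d,$$

and ReduceLogSumExpGivenMax2D then reads that maximum back out of outputs and replaces it with

$$\log \sum_d e^{x_d} = m + \log \sum_d e^{x_d - m},$$

keeping every exponent non-positive so the sum cannot overflow.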
torchaudio/csrc/rnnt/dcu/gpu_kernels.cuh
0 → 100644
#pragma once

#ifdef USE_ROCM

#include <cassert>

#include <torchaudio/csrc/rnnt/dcu/kernel_utils.h>
#include <torchaudio/csrc/rnnt/dcu/kernels.h>
#include <torchaudio/csrc/rnnt/dcu/math.cuh>

namespace torchaudio {
namespace rnnt {

template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeLogProbs(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    const CAST_DTYPE* denominators,
    CAST_DTYPE* logProbs,
    int H = 1) {
  const int& maxT = maxSrcLen;
  const int& maxU = maxTgtLen;
  const int& D = numTargets;

  const int bTgt = blockIdx.z; // 0 <= b < B
  const int bSrc = bTgt / H;
  const int T = srcLengths[bSrc];
  const int U = tgtLengths[bTgt] + 1;

  const int t = blockIdx.x * blockDim.x + threadIdx.x;
  const int u = blockIdx.y;

  if (t >= T || u >= U) { // out of boundary.
    return;
  }

  Indexer3D indexer(maxT, maxU);

  int idx = indexer(bTgt, t, u);

  // skip: log_prob(b, t, u).skip() = logits(b, t, u, blank) - denom(b, t, u).
  logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
      CAST_DTYPE(logits[idx * D + blank]) - denominators[idx];

  if (u < U - 1) {
    // emit: log_prob(b, t, u).emit() = logits(b, t, u, tgt[u]) - denom(b, t, u).
    int target = targets[Indexer2D(maxU - 1)(bTgt, u)];
    logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
        CAST_DTYPE(logits[idx * D + target]) - denominators[idx];
  }
}
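Each (b, t, u) cell thus stores a pair of log-softmax-normalized scores, with denom(t, u) computed by the reduction kernels above:

$$\mathrm{skip}(t,u) = \mathrm{logits}(t,u,\text{blank}) - \mathrm{denom}(t,u), \qquad \mathrm{emit}(t,u) = \mathrm{logits}(t,u,\mathrm{tgt}[u]) - \mathrm{denom}(t,u),$$

where $\mathrm{denom}(t,u) = \log \sum_d e^{\mathrm{logits}(t,u,d)}$.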
template <typename DTYPE, typename CAST_DTYPE>
__device__ void ComputeAlphas(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    int* alpha_counters,
    volatile CAST_DTYPE* alphas,
    int H = 1) {
  const int& maxT = maxSrcLen;
  const int& maxU = maxTgtLen;

  const int bTgt = blockIdx.z; // 0 <= b < B
  const int bSrc = bTgt / H;
  const int T = srcLengths[bSrc];
  const int U = tgtLengths[bTgt] + 1;

  const int t = blockIdx.x * blockDim.x + threadIdx.x + 1;
  const int u = blockIdx.y + 1;

  if (t >= T || u >= U) { // out of boundary.
    return;
  }

  int* counter = alpha_counters + Indexer2D(maxU)(bTgt, blockIdx.y);

  Indexer3D idxr(maxT, maxU);

  if (t == 1 && u == 1) {
    alphas[idxr(bTgt, 0, 0)] = 0;
  }

  if (blockIdx.x > 0) {
    // wait until the previous warp (in t-axis) is ready.
    while (atomicAdd(counter, 0) < blockIdx.x) {
    }
  }

  if (blockIdx.y > 0) {
    // wait until the previous warp (in u-axis) is ready.
    while (atomicAdd(counter - 1, 0) <= blockIdx.x) {
    }
  }

  if (t == 1 && u < U) {
    // alpha(0, u) = alpha(0, u - 1) + logProbs(0, u - 1).emit().
    alphas[idxr(bTgt, 0, u)] = alphas[idxr(bTgt, 0, u - 1)] +
        logProbs[(idxr(bTgt, 0, u - 1) << 1) + LOG_PROBS_EMIT_IDX];
  }

  if (blockIdx.y == 0 && t < T) {
    CAST_DTYPE skip_prob =
        logProbs[(idxr(bTgt, t - 1, 0) << 1) + LOG_PROBS_SKIP_IDX];
    CAST_DTYPE val;

#pragma unroll
    for (int i = 1; i < warpSize; i <<= 1) {
      val = __shfl_up_sync(0xffffffff, skip_prob, i);
      if (i <= threadIdx.x) {
        skip_prob = skip_prob + val;
      }
    }

    val = alphas[idxr(bTgt, blockIdx.x * blockDim.x, 0)];
    alphas[idxr(bTgt, t, 0)] = skip_prob + val;
  }

  if (t < T && u < U) {
    CAST_DTYPE skip_prob =
        logProbs[(idxr(bTgt, t - 1, u) << 1) + LOG_PROBS_SKIP_IDX];
    CAST_DTYPE emit_prob =
        logProbs[(idxr(bTgt, t, u - 1) << 1) + LOG_PROBS_EMIT_IDX];

    CAST_DTYPE skip = alphas[idxr(bTgt, blockIdx.x * blockDim.x, u)] + skip_prob;
    CAST_DTYPE emit = alphas[idxr(bTgt, t, u - 1)] + emit_prob;

    CAST_DTYPE val = math::lse(skip, emit);
    CAST_DTYPE out = val;

    for (int i = 1; i < warpSize; ++i) {
      val = __shfl_up_sync(0xffffffff, val, 1);
      if (i == threadIdx.x) {
        val = math::lse(val + skip_prob, emit);
        out = val;
      }
    }

    alphas[idxr(bTgt, t, u)] = out;
  }

  if (threadIdx.x == 0) {
    __threadfence();
    atomicAdd(counter, 1);
  }
}
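Warp mechanics aside, the quantity this block computes is the standard RNN-T forward recursion over the (t, u) lattice,

$$\alpha(0,0) = 0, \qquad \alpha(t,u) = \operatorname{lse}\bigl(\alpha(t-1,u) + \mathrm{skip}(t-1,u),\; \alpha(t,u-1) + \mathrm{emit}(t,u-1)\bigr),$$

with the first row and column degenerating to plain prefix sums. Each cell depends only on its left and lower neighbors, so the per-(b, u-block) counters and spin-waits above are enough to order the warps into a diagonal wavefront.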
template <typename DTYPE, typename CAST_DTYPE>
__device__ void ComputeBetasCosts(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    int* betaCounters,
    volatile CAST_DTYPE* betas,
    DTYPE* costs,
    int H = 1) {
  const int& maxT = maxSrcLen;
  const int& maxU = maxTgtLen;

  const int bTgt = blockIdx.z; // 0 <= b < B
  const int bSrc = bTgt / H;
  const int T = srcLengths[bSrc];
  const int U = tgtLengths[bTgt] + 1;

  const int t = T - 2 - blockIdx.x * blockDim.x - threadIdx.x;
  const int u = U - 2 - blockIdx.y;

  if (t < 0 || u < 0) { // out of boundary.
    return;
  }

  int* counter = betaCounters + Indexer2D(maxU)(bTgt, blockIdx.y);

  Indexer3D idxr(maxT, maxU);

  if (t == T - 2 && u == U - 2) {
    betas[idxr(bTgt, T - 1, U - 1)] =
        logProbs[(idxr(bTgt, T - 1, U - 1) << 1) + LOG_PROBS_SKIP_IDX];
  }

  if (blockIdx.x > 0) {
    // wait until the previous warp (in t-axis) is ready.
    while (atomicAdd(counter, 0) < blockIdx.x) {
    }
  }

  if (blockIdx.y > 0) {
    // wait until the previous warp (in u-axis) is ready.
    while (atomicAdd(counter - 1, 0) <= blockIdx.x) {
    }
  }

  if (t == T - 2 && u >= 0) {
    betas[idxr(bTgt, T - 1, u)] = betas[idxr(bTgt, T - 1, u + 1)] +
        logProbs[(idxr(bTgt, T - 1, u) << 1) + LOG_PROBS_EMIT_IDX];
  }

  if (blockIdx.y == 0 && t >= 0) {
    CAST_DTYPE skip_prob =
        logProbs[(idxr(bTgt, t, U - 1) << 1) + LOG_PROBS_SKIP_IDX];
    CAST_DTYPE val;

#pragma unroll
    for (int i = 1; i < warpSize; i <<= 1) {
      val = __shfl_up_sync(0xffffffff, skip_prob, i);
      if (i <= threadIdx.x) {
        skip_prob = skip_prob + val;
      }
    }

    betas[idxr(bTgt, t, U - 1)] =
        betas[idxr(bTgt, T - 1 - blockIdx.x * blockDim.x, U - 1)] + skip_prob;
  }

  if (t >= 0 && u >= 0) {
    CAST_DTYPE skip_prob =
        logProbs[(idxr(bTgt, t, u) << 1) + LOG_PROBS_SKIP_IDX];
    CAST_DTYPE emit_prob =
        logProbs[(idxr(bTgt, t, u) << 1) + LOG_PROBS_EMIT_IDX];

    CAST_DTYPE skip = betas[idxr(bTgt, t + threadIdx.x + 1, u)] + skip_prob;
    CAST_DTYPE emit = betas[idxr(bTgt, t, u + 1)] + emit_prob;

    CAST_DTYPE val = math::lse(skip, emit);
    CAST_DTYPE out = val;

    for (int i = 1; i < warpSize; ++i) {
      val = __shfl_up_sync(0xffffffff, val, 1);
      if (i == threadIdx.x) {
        val = math::lse(val + skip_prob, emit);
        out = val;
      }
    }

    betas[idxr(bTgt, t, u)] = out;

    if (t == 0 && u == 0) { // use -beta(0, 0) as cost.
      costs[bTgt] = DTYPE(-out);
    }
  }

  if (threadIdx.x == 0) {
    __threadfence();
    atomicAdd(counter, 1);
  }
}
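This is the mirror-image backward recursion, filled from the top-right corner of the lattice,

$$\beta(T{-}1,U{-}1) = \mathrm{skip}(T{-}1,U{-}1), \qquad \beta(t,u) = \operatorname{lse}\bigl(\beta(t+1,u) + \mathrm{skip}(t,u),\; \beta(t,u+1) + \mathrm{emit}(t,u)\bigr),$$

and the per-utterance loss is read off as $\mathrm{cost} = -\beta(0,0)$, the negative log-likelihood summed over all alignments.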
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeAlphasBetasCosts(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    int* alpha_counters,
    volatile CAST_DTYPE* alphas,
    int* betaCounters,
    volatile CAST_DTYPE* betas,
    DTYPE* costs,
    int warpSize = 0,
    int numWarps = 0,
    int H = 1) {
  assert(threadIdx.y == 0 || threadIdx.y == 1);

  if (threadIdx.y == 0) {
    ComputeAlphas<DTYPE, CAST_DTYPE>(
        /*maxSrcLen=*/maxSrcLen,
        /*maxTgtLen=*/maxTgtLen,
        /*numTargets=*/numTargets,
        /*blank=*/blank,
        /*logProbs=*/logProbs,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*alpha_counters=*/alpha_counters,
        /*alphas=*/alphas,
        H);
  } else { // threadIdx.y == 1
    ComputeBetasCosts<DTYPE, CAST_DTYPE>(
        /*maxSrcLen=*/maxSrcLen,
        /*maxTgtLen=*/maxTgtLen,
        /*numTargets=*/numTargets,
        /*blank=*/blank,
        /*logProbs=*/logProbs,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*betaCounters=*/betaCounters,
        /*betas=*/betas,
        /*costs=*/costs,
        H);
  }
}

template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeGradients(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    CAST_DTYPE clamp,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    const CAST_DTYPE* denominators,
    const CAST_DTYPE* alphas,
    const CAST_DTYPE* betas,
    DTYPE* gradients,
    int H = 1) {
  const int bTgt = blockIdx.z; // 0 <= b < B
  const int t = blockIdx.x * blockDim.x + threadIdx.x;
  const int u = blockIdx.y;

  ComputeGradientsElement(
      bTgt,
      t,
      u,
      maxSrcLen,
      maxTgtLen,
      numTargets,
      blank,
      clamp,
      logits,
      targets,
      srcLengths,
      tgtLengths,
      denominators,
      alphas,
      betas,
      gradients,
      H);
}

// This is a __global__ wrapper around the ComputeAlphas
// device kernel to enable unit testing
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeAlphasWrapper(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    int* alpha_counters,
    volatile CAST_DTYPE* alphas,
    int H = 1) {
  ComputeAlphas<DTYPE, CAST_DTYPE>(
      maxSrcLen,
      maxTgtLen,
      numTargets,
      blank,
      logProbs,
      srcLengths,
      tgtLengths,
      alpha_counters,
      alphas,
      H);
}

// This is a __global__ wrapper around the ComputeBetas
// device kernel to enable unit testing
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeBetasWrapper(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    int* betaCounters,
    volatile CAST_DTYPE* betas,
    DTYPE* costs,
    int H = 1) {
  ComputeBetasCosts<DTYPE, CAST_DTYPE>(
      maxSrcLen,
      maxTgtLen,
      numTargets,
      blank,
      logProbs,
      srcLengths,
      tgtLengths,
      betaCounters,
      betas,
      costs,
      H);
}

// #undef LOG_PROBS_SKIP_IDX
// #undef LOG_PROBS_EMIT_IDX

} // namespace rnnt
} // namespace torchaudio

#endif // USE_ROCM
torchaudio/csrc/rnnt/dcu/gpu_transducer.h
0 → 100644
#pragma once

#ifdef USE_ROCM

#include <torchaudio/csrc/rnnt/workspace.h>
#include <torchaudio/csrc/rnnt/dcu/gpu_kernel_utils.cuh>
#include <torchaudio/csrc/rnnt/dcu/gpu_kernels.cuh>

namespace torchaudio {
namespace rnnt {
namespace gpu {

#define gpuErrchk(ans) \
  { gpuAssert((ans), __FILE__, __LINE__); }

inline void gpuAssert(
    hipError_t code,
    const char* file,
    int line,
    bool abort = true) {
  if (code != hipSuccess) {
    fprintf(
        stderr,
        "\nGPUassert: %s %s %d\n",
        hipGetErrorString(code),
        file,
        line);
    if (abort)
      exit(code);
  }
}
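gpuErrchk is the usual assert-style wrapper for HIP API calls; a quick usage sketch (buffer and size are hypothetical, not from this commit):

    // Illustrative only -- not part of this commit.
    float* d_buf = nullptr;
    gpuErrchk(hipMalloc(&d_buf, 1024 * sizeof(float))); // prints file/line and exits on failure
    gpuErrchk(hipMemset(d_buf, 0, 1024 * sizeof(float)));
    gpuErrchk(hipFree(d_buf));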
template <typename DTYPE, typename CAST_DTYPE>
status_t LogSumExp2D(
    hipStream_t stream,
    int N,
    int D,
    const DTYPE* logits, // [N, D]
    CAST_DTYPE* outputs) {
  { // compute max among D.
    dim3 block_dims(N);
    dim3 thread_dims(REDUCE_THREADS);

    ReduceMax2D<REDUCE_THREADS, DTYPE, CAST_DTYPE>
        <<<block_dims, thread_dims, 0, stream>>>(
            /*dim=*/D,
            /*inputs=*/logits,
            /*outputs=*/outputs);

    // BUGBUG: These error codes are only accurate when launching with
    // blocking. Otherwise they usually reflect earlier errors.
    if (hipGetLastError() != hipSuccess) {
      return COMPUTE_DENOMINATOR_REDUCE_MAX_FAILED;
    }
  }

  { // compute log(sum(exp(d_i - max)))
    dim3 block_dims(N);
    dim3 thread_dims(REDUCE_THREADS);

    ReduceLogSumExpGivenMax2D<REDUCE_THREADS, DTYPE, CAST_DTYPE>
        <<<block_dims, thread_dims, 0, stream>>>(
            /*dim=*/D,
            /*inputs=*/logits,
            /*outputs=*/outputs);

    if (hipGetLastError() != hipSuccess) {
      return COMPUTE_DENOMINATOR_REDUCE_SUM_FAILED;
    }
  }
  return SUCCESS;
}
// Inputs:
//   workspace: workspace.
//   logits: pointer to (B, max_T, max_U, D) logits.
//   targets: pointer to (B, max_U - 1) targets in the batch.
//   srcLengths: pointer to (B, ) source lengths in the batch.
//   tgtLengths: pointer to (B, ) target lengths in the batch.
//
// Outputs:
//   costs: pointer to (B, ) costs in the batch.
//   gradients: pointer to (B, max_T, max_U, D) gradients in the batch.
template <typename DTYPE, typename CAST_DTYPE>
status_t Compute(
    const Workspace<CAST_DTYPE>& workspace,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    DTYPE* costs,
    DTYPE* gradients = nullptr) {
  const Options& options = workspace.GetOptions();

  const hipStream_t& stream = options.stream_;
  const int& B = options.batchSize_;
  const int& H = options.nHypos_;
  const int& max_T = options.maxSrcLen_;
  const int& max_U = options.maxTgtLen_;
  const int& D = options.numTargets_;
  const int& blank = options.blank_;
  const CAST_DTYPE clamp = options.clamp_;

  { // compute denominators.
    status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
        /*stream=*/stream,
        /*N=*/B * H * max_T * max_U,
        /*D=*/D,
        /*logits=*/logits,
        /*denominators=*/workspace.GetPointerToDenominators());
    if (status != SUCCESS) {
      return status;
    }
  }

  { // compute log probability pairs (blank and target).
    int num_segments =
        (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
    dim3 block_dims(num_segments, max_U, B * H);
    dim3 thread_dims(MAX_THREADS_PER_BLOCK);

    ComputeLogProbs<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
        /*max_src_len=*/max_T,
        /*max_tgt_len=*/max_U,
        /*num_targets=*/D,
        /*blank=*/blank,
        /*logits=*/logits,
        /*srcLengths=*/srcLengths,
        /*targets=*/targets,
        /*tgtLengths=*/tgtLengths,
        /*denominators=*/workspace.GetPointerToDenominators(),
        /*log_probs=*/workspace.GetPointerToLogProbs(),
        H);

    if (hipGetLastError() != hipSuccess) {
      return COMPUTE_LOG_PROBS_FAILED;
    }
  }

  { // compute alphas, betas and costs.
    // a warp is usually a group of threads (32 on CUDA, 64 on ROCm)
    int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE;
    // each block is identified by a 3-d tuple:
    // we are using num_warps * max_U * (B * H) blocks,
    // where num_warps is the division along the time axis.
    dim3 block_dims(num_warps, max_U, B * H);
    // each thread is identified by a 2-d tuple;
    // the 2nd dim is 2: one lane for alpha, one for beta.
    dim3 thread_dims(WARP_SIZE, 2);

    ComputeAlphasBetasCosts<DTYPE, CAST_DTYPE>
        <<<block_dims, thread_dims, 0, stream>>>(
            /*max_src_len=*/max_T,
            /*max_tgt_len=*/max_U,
            /*num_targets=*/D,
            /*blank=*/blank,
            /*log_probs=*/workspace.GetPointerToLogProbs(),
            /*srcLengths=*/srcLengths,
            /*tgtLengths=*/tgtLengths,
            /*alpha_counters=*/workspace.GetPointerToAlphaCounters(),
            /*alphas=*/workspace.GetPointerToAlphas(),
            /*beta_counters=*/workspace.GetPointerToBetaCounters(),
            /*betas=*/workspace.GetPointerToBetas(),
            /*costs=*/costs,
            /*warp_size=*/WARP_SIZE,
            /*num_warps=*/num_warps,
            H);

    if (hipGetLastError() != hipSuccess) {
      return COMPUTE_ALPHAS_BETAS_COSTS_FAILED;
    }
  }

  if (gradients != nullptr) { // compute gradients.
    // don't set gradients to zero here, as gradients might reuse memory from
    // logits
    int num_blocks =
        (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
    dim3 block_dims(num_blocks, max_U, B * H);
    dim3 thread_dims(MAX_THREADS_PER_BLOCK);

    ComputeGradients<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
        /*max_src_len=*/max_T,
        /*max_tgt_len=*/max_U,
        /*num_targets=*/D,
        /*blank=*/blank,
        /*clamp=*/clamp,
        /*logits=*/logits,
        /*targets=*/targets,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*denominators=*/workspace.GetPointerToDenominators(),
        /*alphas=*/workspace.GetPointerToAlphas(),
        /*betas=*/workspace.GetPointerToBetas(),
        /*gradients=*/gradients,
        H);

    if (hipGetLastError() != hipSuccess) {
      return COMPUTE_GRADIENTS_FAILED;
    }
  }

  return SUCCESS;
}
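Concretely, the alpha/beta launch tiles the lattice with one warp-sized strip per time segment (WARP_SIZE is 64 under ROCm, per the macros.h change below):

$$n_{\mathrm{warps}} = \left\lceil \frac{\mathrm{max\_T}}{\mathrm{WARP\_SIZE}} \right\rceil, \qquad \mathrm{grid} = (n_{\mathrm{warps}},\; \mathrm{max\_U},\; B \cdot H), \qquad \mathrm{block} = (\mathrm{WARP\_SIZE},\; 2),$$

so every block runs one alpha lane (threadIdx.y == 0) and one beta lane (threadIdx.y == 1) over the same strip of frames.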
template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeAlphas(
    const Workspace<CAST_DTYPE>& workspace,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    DTYPE* alphas) {
  const Options& options = workspace.GetOptions();

  const hipStream_t& stream = options.stream_;
  const int& B = options.batchSize_;
  const int& H = options.nHypos_;
  const int& max_T = options.maxSrcLen_;
  const int& max_U = options.maxTgtLen_;
  const int& D = options.numTargets_;
  const int& blank = options.blank_;

  { // compute denominators.
    status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
        /*stream=*/stream,
        /*N=*/B * H * max_T * max_U,
        /*D=*/D,
        /*logits=*/logits,
        /*denominators=*/workspace.GetPointerToDenominators());
    if (status != SUCCESS) {
      return status;
    }
  }

  { // compute log probability pairs (blank and target).
    int num_segments =
        (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
    dim3 block_dims(num_segments, max_U, B * H);
    dim3 thread_dims(MAX_THREADS_PER_BLOCK);

    ComputeLogProbs<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
        /*max_src_len=*/max_T,
        /*max_tgt_len=*/max_U,
        /*num_targets=*/D,
        /*blank=*/blank,
        /*logits=*/logits,
        /*targets=*/targets,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*denominators=*/workspace.GetPointerToDenominators(),
        /*log_probs=*/workspace.GetPointerToLogProbs(),
        H);

    if (hipGetLastError() != hipSuccess) {
      return COMPUTE_LOG_PROBS_FAILED;
    }
  }

  { // compute alphas.
    // a warp is usually a group of threads (32 on CUDA, 64 on ROCm)
    int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE;
    // each block is identified by a 3-d tuple:
    // we are using num_warps * max_U * (B * H) blocks,
    // where num_warps is the division along the time axis.
    dim3 block_dims(num_warps, max_U, B * H);
    // each thread is identified by a 2-d tuple;
    // the 2nd dim is 1, for alpha only.
    dim3 thread_dims(WARP_SIZE, 1);

    ComputeAlphasWrapper<DTYPE, CAST_DTYPE>
        <<<block_dims, thread_dims, 0, stream>>>(
            /*max_src_len=*/max_T,
            /*max_tgt_len=*/max_U,
            /*num_targets=*/D,
            /*blank=*/blank,
            /*log_probs=*/workspace.GetPointerToLogProbs(),
            /*srcLengths=*/srcLengths,
            /*tgtLengths=*/tgtLengths,
            /*alpha_counters=*/workspace.GetPointerToAlphaCounters(),
            /*alphas=*/(volatile DTYPE*)alphas,
            H);

    if (hipGetLastError() != hipSuccess) {
      return COMPUTE_ALPHAS_BETAS_COSTS_FAILED;
    }
  }
  return SUCCESS;
}
template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeBetas(
    const Workspace<CAST_DTYPE>& workspace,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    DTYPE* costs,
    DTYPE* betas) {
  const Options& options = workspace.GetOptions();

  const hipStream_t& stream = options.stream_;
  const int& B = options.batchSize_;
  const int& H = options.nHypos_;
  const int& max_T = options.maxSrcLen_;
  const int& max_U = options.maxTgtLen_;
  const int& D = options.numTargets_;
  const int& blank = options.blank_;

  { // compute denominators.
    status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
        /*stream=*/stream,
        /*N=*/B * H * max_T * max_U,
        /*D=*/D,
        /*logits=*/logits,
        /*denominators=*/workspace.GetPointerToDenominators());
    if (status != SUCCESS) {
      return status;
    }
  }

  { // compute log probability pairs (blank and target).
    int num_segments =
        (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
    dim3 block_dims(num_segments, max_U, B * H);
    dim3 thread_dims(MAX_THREADS_PER_BLOCK);

    ComputeLogProbs<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
        /*max_src_len=*/max_T,
        /*max_tgt_len=*/max_U,
        /*num_targets=*/D,
        /*blank=*/blank,
        /*logits=*/logits,
        /*targets=*/targets,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*denominators=*/workspace.GetPointerToDenominators(),
        /*log_probs=*/workspace.GetPointerToLogProbs(),
        H);

    if (hipGetLastError() != hipSuccess) {
      return COMPUTE_LOG_PROBS_FAILED;
    }
  }

  { // compute betas.
    // a warp is usually a group of threads (32 on CUDA, 64 on ROCm)
    int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE;
    // each block is identified by a 3-d tuple:
    // we are using num_warps * max_U * (B * H) blocks,
    // where num_warps is the division along the time axis.
    dim3 block_dims(num_warps, max_U, B * H);
    // each thread is identified by a 2-d tuple;
    // the 2nd dim is 1, for betas only.
    dim3 thread_dims(WARP_SIZE, 1);

    ComputeBetasWrapper<DTYPE, CAST_DTYPE>
        <<<block_dims, thread_dims, 0, stream>>>(
            /*max_src_len=*/max_T,
            /*max_tgt_len=*/max_U,
            /*num_targets=*/D,
            /*blank=*/blank,
            /*log_probs=*/workspace.GetPointerToLogProbs(),
            /*srcLengths=*/srcLengths,
            /*tgtLengths=*/tgtLengths,
            /*beta_counters=*/workspace.GetPointerToBetaCounters(),
            /*betas=*/(volatile DTYPE*)betas,
            costs,
            H);

    if (hipGetLastError() != hipSuccess) {
      return COMPUTE_ALPHAS_BETAS_COSTS_FAILED;
    }
  }
  return SUCCESS;
}

} // namespace gpu
} // namespace rnnt
} // namespace torchaudio

#endif // USE_ROCM
torchaudio/csrc/rnnt/dcu/half.cuh
0 → 100644
#pragma once

#ifdef USE_C10_HALF
#include "c10/util/Half.h"
#endif // USE_C10_HALF

#include <torchaudio/csrc/rnnt/macros.h>

namespace torchaudio {
namespace rnnt {

struct alignas(sizeof(__half)) Half {
  __half x;

  HOST_AND_DEVICE Half() = default;

  FORCE_INLINE HOST_AND_DEVICE Half(float f) {
    x = __float2half_rn(f);
    if (isinf(__half2float(x))) {
      x = __float2half_rz(f); // round toward 0.
    }
  }

  FORCE_INLINE HOST_AND_DEVICE operator float() const {
    return __half2float(x);
  }

  FORCE_INLINE HOST_AND_DEVICE Half(__half f) {
    x = f;
  }

  FORCE_INLINE HOST_AND_DEVICE operator __half() const {
    return x;
  }
};

} // namespace rnnt
} // namespace torchaudio
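The float constructor falls back to round-toward-zero only when round-to-nearest overflows the half range, which keeps large-but-finite inputs finite. A hedged illustration (device code; the values come from the IEEE binary16 limits, not from this commit):

    // Illustrative only -- not part of this commit.
    // half's largest finite value is 65504; round-to-nearest carries 65520 up to +inf,
    // while round-toward-zero truncates it back to 65504.
    __half a = __float2half_rn(65520.0f); // +inf
    __half b = __float2half_rz(65520.0f); // 65504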
torchaudio/csrc/rnnt/dcu/kernel_utils.h
0 → 100644
#pragma once

#include <cassert>

#include <torchaudio/csrc/rnnt/dcu/math.cuh>

namespace torchaudio {
namespace rnnt {

inline HOST_AND_DEVICE bool in_range(
    int start,
    int end, // inclusive
    int val) {
  return start <= val && val <= end;
}

#define LOG_PROBS_SKIP_IDX 0
#define LOG_PROBS_EMIT_IDX 1

struct Indexer2D {
  const int& size2_;

  FORCE_INLINE HOST_AND_DEVICE Indexer2D(const int& size2) : size2_(size2) {}

  FORCE_INLINE HOST_AND_DEVICE int operator()(int index1, int index2) {
    return index1 * size2_ + index2;
  }
};

struct Indexer3D {
  const int& size2_;
  const int& size3_;

  FORCE_INLINE HOST_AND_DEVICE Indexer3D(const int& size2, const int& size3)
      : size2_(size2), size3_(size3) {}

  FORCE_INLINE HOST_AND_DEVICE int operator()(
      int index1,
      int index2,
      int index3) {
    return (index1 * size2_ + index2) * size3_ + index3;
  }
};

struct Indexer4D {
  const int& size2_;
  const int& size3_;
  const int& size4_;

  HOST_AND_DEVICE Indexer4D(
      const int& size2,
      const int& size3,
      const int& size4)
      : size2_(size2), size3_(size3), size4_(size4) {}

  HOST_AND_DEVICE int operator()(
      int index1,
      int index2,
      int index3,
      int index4) {
    return ((index1 * size2_ + index2) * size3_ + index3) * size4_ + index4;
  }
};

} // namespace rnnt
} // namespace torchaudio
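These are plain row-major flatteners; a short sketch of the convention the kernels rely on (the concrete values are illustrative):

    // Illustrative only -- not part of this commit.
    int maxT = 10, maxU = 5;
    Indexer3D idx(maxT, maxU);
    // (b, t, u) -> (b * maxT + t) * maxU + u
    int offset = idx(/*b=*/2, /*t=*/3, /*u=*/4); // (2 * 10 + 3) * 5 + 4 = 119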
torchaudio/csrc/rnnt/dcu/kernels.h
0 → 100644
#pragma once

#include <cassert>

#include <torchaudio/csrc/rnnt/dcu/kernel_utils.h>
#include <torchaudio/csrc/rnnt/dcu/math.cuh>

namespace torchaudio {
namespace rnnt {

template <typename DTYPE, typename CAST_DTYPE>
HOST_AND_DEVICE void ComputeGradientsElement(
    int bTgt,
    int t,
    int u,
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    CAST_DTYPE clamp,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    const CAST_DTYPE* denominators,
    const CAST_DTYPE* alphas,
    const CAST_DTYPE* betas,
    DTYPE* gradients,
    int H = 1) {
  const int& maxT = maxSrcLen;
  const int& maxU = maxTgtLen;
  const int& D = numTargets;

  const int bSrc = bTgt / H;
  const int T = srcLengths[bSrc];
  const int U = tgtLengths[bTgt] + 1;

  if (t >= T || u >= U) { // out of boundary.
    if (gradients == logits && t < maxT && u < maxU) {
      // gradients and logits are pointing to the same memory location
      Indexer3D idxr3(maxT, maxU);
      int idx_b_t_u_zero = idxr3(bTgt, t, u);
      if (idx_b_t_u_zero != -1) {
        int start = idx_b_t_u_zero * D;
        for (int b_t_u_d = start; b_t_u_d < start + D; ++b_t_u_d) {
          gradients[b_t_u_d] = 0;
        }
      }
    }
    return;
  }

  int costIdx = bTgt * maxT * maxU;
  CAST_DTYPE cost = -(betas[costIdx]);

  Indexer2D idxr2(maxU - 1);

  int idx_b_t_u, idx_b_t_up1, idx_b_tp1_u;
  Indexer3D idxr3(maxT, maxU);
  idx_b_t_u = idxr3(bTgt, t, u);
  idx_b_t_up1 = idxr3(bTgt, t, u + 1);
  idx_b_tp1_u = idxr3(bTgt, t + 1, u);

  if (idx_b_t_u == -1) {
    return;
  }

  if (isinf(cost) || isnan(cost)) {
    for (int d = 0; d < D; ++d) {
      int b_t_u_d = idx_b_t_u * D + d;
      gradients[b_t_u_d] = 0;
    }
    return;
  }

  CAST_DTYPE c = alphas[idx_b_t_u] + cost - denominators[idx_b_t_u];
  for (int d = 0; d < D; ++d) {
    int b_t_u_d = idx_b_t_u * D + d;
    CAST_DTYPE g = CAST_DTYPE(logits[b_t_u_d]) + c;

    if (d == blank && t == T - 1 && u == U - 1) { // last blank transition.
      gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]) - std::exp(g);
    } else if (t < T - 1 && d == blank) {
      gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
      if (idx_b_tp1_u != -1) {
        gradients[b_t_u_d] =
            gradients[b_t_u_d] - std::exp(g + betas[idx_b_tp1_u]);
      }
    } else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) {
      gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
      if (idx_b_t_up1 != -1) {
        gradients[b_t_u_d] =
            gradients[b_t_u_d] - std::exp(g + betas[idx_b_t_up1]);
      }
    } else {
      gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
    }

    if (clamp > 0) {
      auto g = CAST_DTYPE(gradients[b_t_u_d]);
      // clamp the gradient into [-clamp, clamp]
      gradients[b_t_u_d] = math::max(math::min(g, clamp), -clamp);
    }
  }
}

} // namespace rnnt
} // namespace torchaudio
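Reading the element computation off the code: with $g_d = \mathrm{logits}(t,u,d) + \alpha(t,u) - \beta(0,0) - \mathrm{denom}(t,u)$, the gradient is the cell's posterior occupancy minus the mass of the transition actually taken there,

$$\frac{\partial \ell}{\partial\, \mathrm{logits}(t,u,d)} = e^{g_d + \beta(t,u)} - \begin{cases} e^{g_d + \beta(t+1,u)} & d = \text{blank},\; t < T{-}1 \\ e^{g_d + \beta(t,u+1)} & d = \mathrm{tgt}[u],\; u < U{-}1 \\ e^{g_d} & d = \text{blank},\; t = T{-}1,\; u = U{-}1 \\ 0 & \text{otherwise}, \end{cases}$$

matching the four branches above ($\beta(0,0)$ enters through cost = -betas[costIdx]).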
torchaudio/csrc/rnnt/dcu/math.cuh
0 → 100644
#pragma once

#ifdef USE_ROCM
#include <cmath>
#endif // USE_ROCM

#include <torchaudio/csrc/rnnt/dcu/half.cuh>

namespace torchaudio {
namespace rnnt {
namespace math {

template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE max(DTYPE x, DTYPE y) {
  if (x > y)
    return x;
  else
    return y;
}

template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE min(DTYPE x, DTYPE y) {
  if (x > y)
    return y;
  else
    return x;
}

// log_sum_exp
template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE lse(DTYPE x, DTYPE y);

template <>
FORCE_INLINE HOST_AND_DEVICE float lse(float x, float y) {
  if (y > x) {
    return y + log1pf(expf(x - y));
  } else {
    return x + log1pf(expf(y - x));
  }
}

} // namespace math
} // namespace rnnt
} // namespace torchaudio
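The specialization is the standard stable pairwise log-sum-exp,

$$\operatorname{lse}(x,y) = \log(e^x + e^y) = \max(x,y) + \operatorname{log1p}\bigl(e^{-|x-y|}\bigr),$$

which keeps the exponent non-positive and uses log1p for precision when the two operands are far apart.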
torchaudio/csrc/rnnt/gpu/compute_alphas.cu
-#include <c10/cuda/CUDAStream.h>
+#include <ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h>
 #include <torch/script.h>
 #include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h>
...
torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh
...
@@ -39,7 +39,11 @@ __global__ void ReduceMax2D(
   CAST_DTYPE shf;
   for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) {
+#if defined(USE_ROCM)
+    shf = __shfl_down(val, stride);
+#else
     shf = __shfl_down_sync(0xFFFFFFFF, val, stride);
+#endif
     if (threadIdx.x < stride && threadIdx.x + stride < dim) {
       if (shf > val) {
         val = shf;
...
@@ -81,7 +85,11 @@ __global__ void ReduceLogSumExpGivenMax2D(
   CAST_DTYPE shf;
   for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) {
+#if defined(USE_ROCM)
+    shf = __shfl_down(val, stride);
+#else
     shf = __shfl_down_sync(0xFFFFFFFF, val, stride);
+#endif
     if (threadIdx.x < stride && threadIdx.x + stride < dim) {
       val = val + shf;
     }
...
torchaudio/csrc/rnnt/macros.h
 #pragma once
-#ifdef USE_CUDA
+#if defined(USE_CUDA)
 #define WARP_SIZE 32
 #define MAX_THREADS_PER_BLOCK 1024
 #define REDUCE_THREADS 256
...
@@ -8,6 +8,14 @@
 #define FORCE_INLINE __forceinline__
 #include <cuda_fp16.h>
 #include <cuda_runtime.h>
+#elif defined(USE_ROCM)
+#define WARP_SIZE 64
+#define MAX_THREADS_PER_BLOCK 1024
+#define REDUCE_THREADS 256
+#define HOST_AND_DEVICE __host__ __device__
+#define FORCE_INLINE __forceinline__
+#include <hip/hip_fp16.h>
+#include <hip/hip_runtime.h>
 #else
 #define HOST_AND_DEVICE
 #define FORCE_INLINE inline
...
torchaudio/csrc/rnnt/options.h
...
@@ -6,6 +6,10 @@
 #include <cuda_runtime.h>
 #endif // USE_CUDA
+#ifdef USE_ROCM
+#include <hip/hip_runtime.h>
+#endif // USE_ROCM
 #include <torchaudio/csrc/rnnt/macros.h>
 #include <torchaudio/csrc/rnnt/types.h>
...
@@ -18,6 +22,10 @@ typedef struct Options {
 #ifdef USE_CUDA
   // the stream to launch kernels in when using GPU.
   cudaStream_t stream_;
 #endif
+#ifdef USE_ROCM
+  // the stream to launch kernels in when using GPU.
+  hipStream_t stream_;
+#endif
   // The maximum number of threads that can be used.
   int numThreads_;
...
torchaudio/csrc/rnnt/workspace.h
...
@@ -131,10 +131,22 @@ class IntWorkspace {
           ComputeSizeForBetaCounters(options_) * sizeof(int));
     }
 #endif // USE_CUDA
+#ifdef USE_ROCM
+    if (data_ != nullptr && options_.device_ == GPU) {
+      hipMemset(
+          GetPointerToAlphaCounters(),
+          0,
+          ComputeSizeForAlphaCounters(options_) * sizeof(int));
+      hipMemset(
+          GetPointerToBetaCounters(),
+          0,
+          ComputeSizeForBetaCounters(options_) * sizeof(int));
+    }
+#endif // USE_ROCM
   }

   static int ComputeSizeForAlphaCounters(const Options& options) {
     // B * U
-#ifdef USE_CUDA
+#if defined(USE_CUDA) || defined(USE_ROCM)
     if (options.device_ == GPU) {
       return options.BU();
     } else {
...
@@ -145,7 +157,7 @@ class IntWorkspace {
 #endif // USE_CUDA
   }
   static int ComputeSizeForBetaCounters(const Options& options) {
     // B * U
-#ifdef USE_CUDA
+#if defined(USE_CUDA) || defined(USE_ROCM)
     if (options.device_ == GPU) {
       return options.BU();
     } else {
...