OpenDAS / Torchaudio / Commits

Commit d946a7c8, authored May 05, 2023 by flyingdown

    rnnt

parent b90d7988

Showing 14 changed files with 1525 additions and 8 deletions (+1525 −8)
- torchaudio/csrc/CMakeLists.txt (+29 −0)
- torchaudio/csrc/rnnt/dcu/compute.cpp (+151 −0)
- torchaudio/csrc/rnnt/dcu/compute_alphas.cpp (+73 −0)
- torchaudio/csrc/rnnt/dcu/compute_betas.cpp (+78 −0)
- torchaudio/csrc/rnnt/dcu/gpu_kernel_utils.cuh (+98 −0)
- torchaudio/csrc/rnnt/dcu/gpu_kernels.cuh (+409 −0)
- torchaudio/csrc/rnnt/dcu/gpu_transducer.h (+391 −0)
- torchaudio/csrc/rnnt/dcu/half.cuh (+38 −0)
- torchaudio/csrc/rnnt/dcu/kernel_utils.h (+66 −0)
- torchaudio/csrc/rnnt/dcu/kernels.h (+108 −0)
- torchaudio/csrc/rnnt/dcu/math.cuh (+48 −0)
- torchaudio/csrc/rnnt/macros.h (+9 −1)
- torchaudio/csrc/rnnt/options.h (+8 −0)
- torchaudio/csrc/rnnt/workspace.h (+19 −7)
torchaudio/csrc/CMakeLists.txt

```diff
@@ -51,6 +51,16 @@ if(BUILD_RNNT)
       rnnt/gpu/compute.cu
     )
   endif()
+  if(USE_ROCM)
+    list(
+      APPEND
+      LIBTORCHAUDIO_SOURCES
+      rnnt/dcu/compute_alphas.cpp
+      rnnt/dcu/compute_betas.cpp
+      rnnt/dcu/compute.cpp
+    )
+  endif()
 endif()

 if(USE_CUDA)
@@ -72,6 +82,25 @@ if(USE_CUDA)
     )
 endif()
+if(USE_ROCM)
+  list(
+    APPEND
+    LIBTORCHAUDIO_INCLUDE_DIRS
+    ${CUDA_TOOLKIT_INCLUDE}
+  )
+  list(
+    APPEND
+    LIBTORCHAUDIO_LINK_LIBRARIES
+    ${C10_CUDA_LIBRARY}
+    ${CUDA_CUDART_LIBRARY}
+  )
+  list(
+    APPEND
+    LIBTORCHAUDIO_COMPILE_DEFINITIONS
+    USE_ROCM
+  )
+endif()
 if(BUILD_KALDI)
   list(APPEND LIBTORCHAUDIO_LINK_LIBRARIES kaldi)
   list(APPEND LIBTORCHAUDIO_SOURCES kaldi.cpp)
```
torchaudio/csrc/rnnt/dcu/compute.cpp (new file, mode 100644)

```cpp
#include <c10/hip/HIPStream.h>
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/dcu/gpu_transducer.h>

namespace torchaudio {
namespace rnnt {
namespace gpu {

// Entry point into RNNT Loss
std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp) {
  TORCH_CHECK(
      logits.device().type() == targets.device().type(),
      "logits and targets must be on the same device");
  TORCH_CHECK(
      logits.device().type() == logit_lengths.device().type(),
      "logits and logit_lengths must be on the same device");
  TORCH_CHECK(
      logits.device().type() == target_lengths.device().type(),
      "logits and target_lengths must be on the same device");

  TORCH_CHECK(
      logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16,
      "logits must be float32 or float16 (half) type");
  TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type");
  TORCH_CHECK(
      logit_lengths.dtype() == torch::kInt32,
      "logit_lengths must be int32 type");
  TORCH_CHECK(
      target_lengths.dtype() == torch::kInt32,
      "target_lengths must be int32 type");

  TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
  TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
  TORCH_CHECK(
      logit_lengths.is_contiguous(), "logit_lengths must be contiguous");
  TORCH_CHECK(
      target_lengths.is_contiguous(), "target_lengths must be contiguous");

  TORCH_CHECK(
      logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
  TORCH_CHECK(
      targets.dim() == 2, "targets must be 2-D (batch, max target length)");
  TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D");
  TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D");

  TORCH_CHECK(
      logit_lengths.size(0) == logits.size(0),
      "batch dimension mismatch between logits and logit_lengths");
  TORCH_CHECK(
      target_lengths.size(0) == logits.size(0),
      "batch dimension mismatch between logits and target_lengths");
  TORCH_CHECK(
      targets.size(0) == logits.size(0),
      "batch dimension mismatch between logits and targets");

  TORCH_CHECK(
      blank >= 0 && blank < logits.size(-1),
      "blank must be within [0, logits.shape[-1])");

  TORCH_CHECK(
      logits.size(1) == at::max(logit_lengths).item().toInt(),
      "input length mismatch");
  TORCH_CHECK(
      logits.size(2) == at::max(target_lengths).item().toInt() + 1,
      "output length mismatch");
  TORCH_CHECK(
      targets.size(1) == at::max(target_lengths).item().toInt(),
      "target length mismatch");

  Options options;
  options.batchSize_ = logit_lengths.size(0);
  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
  options.maxSrcLen_ = logits.size(1);
  options.maxTgtLen_ = logits.size(2);
  options.numTargets_ = logits.size(3);
  options.blank_ = blank;
  options.clamp_ = clamp;

  CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
  options.stream_ = at::hip::getCurrentHIPStream();
  hipSetDevice(logits.get_device());
  options.device_ = GPU;

  torch::Tensor costs = torch::empty(
      options.batchSize_ * options.nHypos_,
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
  c10::optional<torch::Tensor> gradients = torch::zeros_like(logits);

  torch::Tensor int_workspace = torch::empty(
      IntWorkspace::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Int));

  torch::Tensor float_workspace = torch::empty(
      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Float));

  Workspace<float> workspace(
      /*options=*/options,
      /*dtype_data=*/float_workspace.data_ptr<float>(),
      /*dtype_size=*/float_workspace.numel(),
      /*int_data=*/int_workspace.data_ptr<int>(),
      /*int_size=*/int_workspace.numel());

  switch (logits.scalar_type()) {
    case torch::ScalarType::Float: {
      Compute</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
          /*workspace=*/workspace,
          /*logits=*/logits.data_ptr<float>(),
          /*targets=*/targets.data_ptr<int>(),
          /*logit_lengths=*/logit_lengths.data_ptr<int>(),
          /*target_lengths=*/target_lengths.data_ptr<int>(),
          /*costs=*/costs.data_ptr<float>(),
          /*gradients=*/gradients->data_ptr<float>());
      break;
    }
    case torch::ScalarType::Half: {
      Compute</*DTYPE=*/c10::Half, /*CAST_DTYPE=*/float>(
          /*workspace=*/workspace,
          /*logits=*/logits.data_ptr<c10::Half>(),
          /*targets=*/targets.data_ptr<int>(),
          /*logit_lengths=*/logit_lengths.data_ptr<int>(),
          /*target_lengths=*/target_lengths.data_ptr<int>(),
          /*costs=*/costs.data_ptr<c10::Half>(),
          /*gradients=*/gradients->data_ptr<c10::Half>());
      break;
    }
    default: {
      break;
    }
  };

  return std::make_tuple(costs, gradients);
}

TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
  m.impl("rnnt_loss", &compute);
}

} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
```
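The checks above pin down the calling contract: `logits` is a contiguous (B, max_T, max_U, D) float32/float16 tensor, `targets` is (B, max_U − 1) int32, and both length tensors are (B,) int32, with max_U equal to the longest target length plus one. A minimal libtorch sketch of inputs satisfying that contract follows; the sizes are illustrative assumptions, not values from this commit, and on a ROCm build the `kCUDA` device maps to the AMD GPU:

```cpp
#include <torch/torch.h>

int main() {
  // Hypothetical sizes: batch 2, 10 frames, max target length 5 (so U = 6), 40 classes.
  const int64_t B = 2, T = 10, U = 6, D = 40;
  torch::Device dev(torch::kCUDA, 0); // HIP device under ROCm

  // (B, max_T, max_U, D) joiner outputs, contiguous float32.
  auto logits = torch::randn(
      {B, T, U, D}, torch::dtype(torch::kFloat32).device(dev));
  // (B, max_U - 1) int32 label sequences (no blanks).
  auto targets = torch::randint(
      1, D, {B, U - 1}, torch::dtype(torch::kInt32).device(dev));
  // (B,) int32 lengths; logits.size(1) must equal max(logit_lengths)
  // and logits.size(2) must equal max(target_lengths) + 1.
  auto logit_lengths =
      torch::full({B}, T, torch::dtype(torch::kInt32).device(dev));
  auto target_lengths =
      torch::full({B}, U - 1, torch::dtype(torch::kInt32).device(dev));

  // With the registration above, Python callers reach this kernel as
  // torch.ops.torchaudio.rnnt_loss(logits, targets, logit_lengths,
  //                                target_lengths, blank, clamp).
  return 0;
}
```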
torchaudio/csrc/rnnt/dcu/compute_alphas.cpp (new file, mode 100644)

```cpp
#include <c10/hip/HIPStream.h>
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/dcu/gpu_transducer.h>

namespace torchaudio {
namespace rnnt {
namespace gpu {

torch::Tensor compute_alphas(
    const torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp) {
  Options options;
  options.batchSize_ = logit_lengths.size(0);
  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
  options.maxSrcLen_ = logits.size(1);
  options.maxTgtLen_ = logits.size(2);
  options.numTargets_ = logits.size(3);
  options.blank_ = blank;
  options.clamp_ = clamp;

  CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
  options.stream_ = at::hip::getCurrentHIPStream();
  hipSetDevice(logits.get_device());
  options.device_ = GPU;

  torch::Tensor alphas = torch::zeros(
      {options.batchSize_ * options.nHypos_,
       options.maxSrcLen_,
       options.maxTgtLen_},
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));

  torch::Tensor int_workspace = torch::empty(
      IntWorkspace::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Int));

  torch::Tensor float_workspace = torch::empty(
      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Float));

  Workspace<float> workspace(
      /*options=*/options,
      /*dtype_data=*/float_workspace.data_ptr<float>(),
      /*dtype_size=*/float_workspace.numel(),
      /*int_data=*/int_workspace.data_ptr<int>(),
      /*int_size=*/int_workspace.numel());

  // Only support float, this is mainly to enable easy
  // unit-testing
  ComputeAlphas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
      /*workspace=*/workspace,
      /*logits=*/logits.data_ptr<float>(),
      /*targets=*/targets.data_ptr<int>(),
      /*logit_lengths=*/logit_lengths.data_ptr<int>(),
      /*target_lengths=*/target_lengths.data_ptr<int>(),
      /*alphas=*/alphas.data_ptr<float>());
  return alphas;
}

TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
  m.impl("rnnt_loss_alphas", &compute_alphas);
}

} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
```
torchaudio/csrc/rnnt/dcu/compute_betas.cpp (new file, mode 100644)

```cpp
#include <c10/hip/HIPStream.h>
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/dcu/gpu_transducer.h>

namespace torchaudio {
namespace rnnt {
namespace gpu {

torch::Tensor compute_betas(
    const torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp) {
  Options options;
  options.batchSize_ = logit_lengths.size(0);
  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
  options.maxSrcLen_ = logits.size(1);
  options.maxTgtLen_ = logits.size(2);
  options.numTargets_ = logits.size(3);
  options.blank_ = blank;
  options.clamp_ = clamp;

  CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
  options.stream_ = at::hip::getCurrentHIPStream();
  hipSetDevice(logits.get_device());
  options.device_ = GPU;

  torch::Tensor costs = torch::empty(
      target_lengths.size(0),
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));

  torch::Tensor betas = torch::zeros(
      {options.batchSize_ * options.nHypos_,
       options.maxSrcLen_,
       options.maxTgtLen_},
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));

  torch::Tensor int_workspace = torch::empty(
      IntWorkspace::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Int));

  torch::Tensor float_workspace = torch::empty(
      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Float));

  Workspace<float> workspace(
      /*options=*/options,
      /*dtype_data=*/float_workspace.data_ptr<float>(),
      /*dtype_size=*/float_workspace.numel(),
      /*int_data=*/int_workspace.data_ptr<int>(),
      /*int_size=*/int_workspace.numel());

  // Only support float, this is mainly to enable easy
  // unit-testing
  ComputeBetas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
      /*workspace=*/workspace,
      /*logits=*/logits.data_ptr<float>(),
      /*targets=*/targets.data_ptr<int>(),
      /*logit_lengths=*/logit_lengths.data_ptr<int>(),
      /*target_lengths=*/target_lengths.data_ptr<int>(),
      /*costs=*/costs.data_ptr<float>(),
      /*betas=*/betas.data_ptr<float>());
  return betas;
}

TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
  m.impl("rnnt_loss_betas", &compute_betas);
}

} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
```
torchaudio/csrc/rnnt/dcu/gpu_kernel_utils.cuh (new file, mode 100644)

```cpp
#pragma once

#ifdef USE_ROCM

#include <torchaudio/csrc/rnnt/dcu/math.cuh>

namespace torchaudio {
namespace rnnt {

template <int NUM_THREADS, typename DTYPE, typename CAST_DTYPE>
__global__ void ReduceMax2D(
    int dim,
    const DTYPE* inputs, // [N, dim]
    CAST_DTYPE* outputs) {
  __shared__ CAST_DTYPE shared[NUM_THREADS];

  // each thread reduces one matrix row
  int offset = blockIdx.x * dim; // [n, 0]
  CAST_DTYPE val = inputs[offset]; // default = inputs(n, 0)
  for (int d = threadIdx.x; d < dim; d += NUM_THREADS) {
    CAST_DTYPE next = inputs[offset + d];
    if (next > val) {
      val = next;
    }
  }
  shared[threadIdx.x] = val;
  __syncthreads();

  for (int stride = (NUM_THREADS >> 1); stride >= WARP_SIZE; stride >>= 1) {
    if (threadIdx.x < stride && threadIdx.x + stride < dim) {
      if (shared[threadIdx.x + stride] > shared[threadIdx.x]) {
        shared[threadIdx.x] = shared[threadIdx.x + stride];
        val = shared[threadIdx.x];
      }
    }
    __syncthreads();
  }

  CAST_DTYPE shf;
  for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) {
    shf = __shfl_down(val, stride);
    if (threadIdx.x < stride && threadIdx.x + stride < dim) {
      if (shf > val) {
        val = shf;
      }
    }
  }

  if (threadIdx.x == 0) {
    outputs[blockIdx.x] = val;
  }
}

template <int NUM_THREADS, typename DTYPE, typename CAST_DTYPE>
__global__ void ReduceLogSumExpGivenMax2D(
    int dim,
    const DTYPE* inputs, // [N, dim]
    CAST_DTYPE* outputs) { // in: max -> out: logsum
  __shared__ CAST_DTYPE shared[NUM_THREADS];

  CAST_DTYPE max = outputs[blockIdx.x];
  CAST_DTYPE val = 0;

  int offset = blockIdx.x * dim;
  for (int d = threadIdx.x; d < dim; d += NUM_THREADS) {
    val = val + std::exp(CAST_DTYPE(inputs[offset + d]) - max);
  }
  shared[threadIdx.x] = val;
  __syncthreads();

  for (int stride = (NUM_THREADS >> 1); stride >= WARP_SIZE; stride >>= 1) {
    if (threadIdx.x < stride && threadIdx.x + stride < dim) {
      val = shared[threadIdx.x] + shared[threadIdx.x + stride];
      shared[threadIdx.x] = val;
    }
    __syncthreads();
  }

  CAST_DTYPE shf;
  for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) {
    shf = __shfl_down(val, stride);
    if (threadIdx.x < stride && threadIdx.x + stride < dim) {
      val = val + shf;
    }
  }

  if (threadIdx.x == 0) {
    outputs[blockIdx.x] = max + std::log(val);
  }
}

} // namespace rnnt
} // namespace torchaudio

#endif // USE_ROCM
```
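Taken together, the two kernels implement the usual numerically stable two-pass log-sum-exp over each row: first reduce the row to its maximum, then sum exp(x − max) and add the log back. A plain-C++ CPU reference of the same per-row computation, for intuition only (not part of the commit):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// CPU reference for ReduceMax2D + ReduceLogSumExpGivenMax2D on one row:
// logsumexp(row) = m + log(sum_i exp(x_i - m)), with m = max_i x_i.
// Subtracting the row max before exp() prevents overflow, which is the
// whole point of the two-pass scheme.
float log_sum_exp_row(const std::vector<float>& row) {
  float m = *std::max_element(row.begin(), row.end()); // pass 1: ReduceMax2D
  float s = 0.f;
  for (float x : row) { // pass 2: ReduceLogSumExpGivenMax2D
    s += std::exp(x - m);
  }
  return m + std::log(s);
}
```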
torchaudio/csrc/rnnt/dcu/gpu_kernels.cuh (new file, mode 100644)

```cpp
#pragma once

#ifdef USE_ROCM

#include <cassert>

#include <torchaudio/csrc/rnnt/dcu/kernel_utils.h>
#include <torchaudio/csrc/rnnt/dcu/kernels.h>
#include <torchaudio/csrc/rnnt/dcu/math.cuh>

namespace torchaudio {
namespace rnnt {

template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeLogProbs(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    const CAST_DTYPE* denominators,
    CAST_DTYPE* logProbs,
    int H = 1) {
  const int& maxT = maxSrcLen;
  const int& maxU = maxTgtLen;
  const int& D = numTargets;

  const int bTgt = blockIdx.z; // 0 <= b < B
  const int bSrc = bTgt / H;
  const int T = srcLengths[bSrc];
  const int U = tgtLengths[bTgt] + 1;

  const int t = blockIdx.x * blockDim.x + threadIdx.x;
  const int u = blockIdx.y;

  if (t >= T || u >= U) { // out of boundary.
    return;
  }

  Indexer3D indexer(maxT, maxU);

  int idx = indexer(bTgt, t, u);

  // skip: log_prob(b, t, u).skip() = logits(b, t, u, blank) - denom(b, t, u).
  logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
      CAST_DTYPE(logits[idx * D + blank]) - denominators[idx];

  if (u < U - 1) {
    // emit: log_prob(b, t, u).emit() = logits(b, t, u, tgt[u]) - denom(b, t,
    // u).
    int target = targets[Indexer2D(maxU - 1)(bTgt, u)];
    logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
        CAST_DTYPE(logits[idx * D + target]) - denominators[idx];
  }
}

template <typename DTYPE, typename CAST_DTYPE>
__device__ void ComputeAlphas(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    int* alpha_counters,
    volatile CAST_DTYPE* alphas,
    int H = 1) {
  const int& maxT = maxSrcLen;
  const int& maxU = maxTgtLen;

  const int bTgt = blockIdx.z; // 0 <= b < B
  const int bSrc = bTgt / H;
  const int T = srcLengths[bSrc];
  const int U = tgtLengths[bTgt] + 1;

  const int t = blockIdx.x * blockDim.x + threadIdx.x + 1;
  const int u = blockIdx.y + 1;

  if (t >= T || u >= U) { // out of boundary.
    return;
  }

  int* counter = alpha_counters + Indexer2D(maxU)(bTgt, blockIdx.y);
  Indexer3D idxr(maxT, maxU);

  if (t == 1 && u == 1) {
    alphas[idxr(bTgt, 0, 0)] = 0;
  }

  if (blockIdx.x > 0) {
    // wait for previous warp (in t-axis) is ready.
    while (atomicAdd(counter, 0) < blockIdx.x) {
    }
  }

  if (blockIdx.y > 0) {
    // wait for previous warp (in u-axis) is ready.
    while (atomicAdd(counter - 1, 0) <= blockIdx.x) {
    }
  }

  if (t == 1 && u < U) {
    // alpha(0, u) = alpha(0, u - 1) + logProbs(0, u - 1).emit().
    alphas[idxr(bTgt, 0, u)] = alphas[idxr(bTgt, 0, u - 1)] +
        logProbs[(idxr(bTgt, 0, u - 1) << 1) + LOG_PROBS_EMIT_IDX];
  }

  if (blockIdx.y == 0 && t < T) {
    CAST_DTYPE skip_prob =
        logProbs[(idxr(bTgt, t - 1, 0) << 1) + LOG_PROBS_SKIP_IDX];
    CAST_DTYPE val;

#pragma unroll
    for (int i = 1; i < warpSize; i <<= 1) {
      val = __shfl_up(skip_prob, i);
      if (i <= threadIdx.x) {
        skip_prob = skip_prob + val;
      }
    }

    val = alphas[idxr(bTgt, blockIdx.x * blockDim.x, 0)];
    alphas[idxr(bTgt, t, 0)] = skip_prob + val;
  }

  if (t < T && u < U) {
    CAST_DTYPE skip_prob =
        logProbs[(idxr(bTgt, t - 1, u) << 1) + LOG_PROBS_SKIP_IDX];
    CAST_DTYPE emit_prob =
        logProbs[(idxr(bTgt, t, u - 1) << 1) + LOG_PROBS_EMIT_IDX];

    CAST_DTYPE skip =
        alphas[idxr(bTgt, blockIdx.x * blockDim.x, u)] + skip_prob;
    CAST_DTYPE emit = alphas[idxr(bTgt, t, u - 1)] + emit_prob;

    CAST_DTYPE val = math::lse(skip, emit);
    CAST_DTYPE out = val;

    for (int i = 1; i < warpSize; ++i) {
      val = __shfl_up(val, 1);
      if (i == threadIdx.x) {
        val = math::lse(val + skip_prob, emit);
        out = val;
      }
    }

    alphas[idxr(bTgt, t, u)] = out;
  }

  if (threadIdx.x == 0) {
    __threadfence();
    atomicAdd(counter, 1);
  }
}

template <typename DTYPE, typename CAST_DTYPE>
__device__ void ComputeBetasCosts(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    int* betaCounters,
    volatile CAST_DTYPE* betas,
    DTYPE* costs,
    int H = 1) {
  const int& maxT = maxSrcLen;
  const int& maxU = maxTgtLen;

  const int bTgt = blockIdx.z; // 0 <= b < B
  const int bSrc = bTgt / H;
  const int T = srcLengths[bSrc];
  const int U = tgtLengths[bTgt] + 1;

  const int t = T - 2 - blockIdx.x * blockDim.x - threadIdx.x;
  const int u = U - 2 - blockIdx.y;

  if (t < 0 || u < 0) { // out of boundary.
    return;
  }

  int* counter = betaCounters + Indexer2D(maxU)(bTgt, blockIdx.y);
  Indexer3D idxr(maxT, maxU);

  if (t == T - 2 && u == U - 2) {
    betas[idxr(bTgt, T - 1, U - 1)] =
        logProbs[(idxr(bTgt, T - 1, U - 1) << 1) + LOG_PROBS_SKIP_IDX];
  }

  if (blockIdx.x > 0) {
    // wait for previous warp (in t-axis) is ready.
    while (atomicAdd(counter, 0) < blockIdx.x) {
    }
  }

  if (blockIdx.y > 0) {
    // wait for previous warp (in u-axis) is ready.
    while (atomicAdd(counter - 1, 0) <= blockIdx.x) {
    }
  }

  if (t == T - 2 && u >= 0) {
    betas[idxr(bTgt, T - 1, u)] = betas[idxr(bTgt, T - 1, u + 1)] +
        logProbs[(idxr(bTgt, T - 1, u) << 1) + LOG_PROBS_EMIT_IDX];
  }

  if (blockIdx.y == 0 && t >= 0) {
    CAST_DTYPE skip_prob =
        logProbs[(idxr(bTgt, t, U - 1) << 1) + LOG_PROBS_SKIP_IDX];
    CAST_DTYPE val;

#pragma unroll
    for (int i = 1; i < warpSize; i <<= 1) {
      val = __shfl_up(skip_prob, i);
      if (i <= threadIdx.x) {
        skip_prob = skip_prob + val;
      }
    }

    betas[idxr(bTgt, t, U - 1)] =
        betas[idxr(bTgt, T - 1 - blockIdx.x * blockDim.x, U - 1)] + skip_prob;
  }

  if (t >= 0 && u >= 0) {
    CAST_DTYPE skip_prob =
        logProbs[(idxr(bTgt, t, u) << 1) + LOG_PROBS_SKIP_IDX];
    CAST_DTYPE emit_prob =
        logProbs[(idxr(bTgt, t, u) << 1) + LOG_PROBS_EMIT_IDX];

    CAST_DTYPE skip = betas[idxr(bTgt, t + threadIdx.x + 1, u)] + skip_prob;
    CAST_DTYPE emit = betas[idxr(bTgt, t, u + 1)] + emit_prob;

    CAST_DTYPE val = math::lse(skip, emit);
    CAST_DTYPE out = val;

    for (int i = 1; i < warpSize; ++i) {
      val = __shfl_up(val, 1);
      if (i == threadIdx.x) {
        val = math::lse(val + skip_prob, emit);
        out = val;
      }
    }

    betas[idxr(bTgt, t, u)] = out;

    if (t == 0 && u == 0) { // use -beta(0, 0) as cost.
      costs[bTgt] = DTYPE(-out);
    }
  }

  if (threadIdx.x == 0) {
    __threadfence();
    atomicAdd(counter, 1);
  }
}

template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeAlphasBetasCosts(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    int* alpha_counters,
    volatile CAST_DTYPE* alphas,
    int* betaCounters,
    volatile CAST_DTYPE* betas,
    DTYPE* costs,
    int warpSize = 0,
    int numWarps = 0,
    int H = 1) {
  assert(threadIdx.y == 0 || threadIdx.y == 1);

  if (threadIdx.y == 0) {
    ComputeAlphas<DTYPE, CAST_DTYPE>(
        /*maxSrcLen=*/maxSrcLen,
        /*maxTgtLen=*/maxTgtLen,
        /*numTargets=*/numTargets,
        /*blank=*/blank,
        /*logProbs=*/logProbs,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*alpha_counters=*/alpha_counters,
        /*alphas=*/alphas,
        H);
  } else { // threadIdx.y == 1
    ComputeBetasCosts<DTYPE, CAST_DTYPE>(
        /*maxSrcLen=*/maxSrcLen,
        /*maxTgtLen=*/maxTgtLen,
        /*numTargets=*/numTargets,
        /*blank=*/blank,
        /*logProbs=*/logProbs,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*betaCounters=*/betaCounters,
        /*beta=*/betas,
        /*costs=*/costs,
        H);
  }
}

template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeGradients(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    CAST_DTYPE clamp,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    const CAST_DTYPE* denominators,
    const CAST_DTYPE* alphas,
    const CAST_DTYPE* betas,
    DTYPE* gradients,
    int H = 1) {
  const int bTgt = blockIdx.z; // 0 <= b < B
  const int t = blockIdx.x * blockDim.x + threadIdx.x;
  const int u = blockIdx.y;

  ComputeGradientsElement(
      bTgt,
      t,
      u,
      maxSrcLen,
      maxTgtLen,
      numTargets,
      blank,
      clamp,
      logits,
      targets,
      srcLengths,
      tgtLengths,
      denominators,
      alphas,
      betas,
      gradients,
      H);
}

// This is a __global__ wrapper around ComputeAlphas
// device kernel to enable unit testing
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeAlphasWrapper(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    int* alpha_counters,
    volatile CAST_DTYPE* alphas,
    int H = 1) {
  ComputeAlphas<DTYPE, CAST_DTYPE>(
      maxSrcLen,
      maxTgtLen,
      numTargets,
      blank,
      logProbs,
      srcLengths,
      tgtLengths,
      alpha_counters,
      alphas,
      H);
}

// This is a __global__ wrapper around ComputeBetas
// device kernel to enable unit testing
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeBetasWrapper(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    int* betaCounters,
    volatile CAST_DTYPE* betas,
    DTYPE* costs,
    int H = 1) {
  ComputeBetasCosts<DTYPE, CAST_DTYPE>(
      maxSrcLen,
      maxTgtLen,
      numTargets,
      blank,
      logProbs,
      srcLengths,
      tgtLengths,
      betaCounters,
      betas,
      costs,
      H);
}

// #undef LOG_PROBS_SKIP_IDX
// #undef LOG_PROBS_EMIT_IDX

} // namespace rnnt
} // namespace torchaudio

#endif // USE_ROCM
```
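The forward pass that `ComputeAlphas` parallelizes is the standard RNN-T recursion α(t, u) = lse(α(t−1, u) + skip(t−1, u), α(t, u−1) + emit(t, u−1)), with α(0, 0) = 0 and the first row and column reduced as prefix sums; the spin-wait atomic counters only enforce that the (t−1, ·) and (·, u−1) wavefronts are written before a warp reads them. A sequential CPU reference of the same recursion (illustrative sketch, not part of the commit):

```cpp
#include <cmath>
#include <vector>

// log(exp(a) + exp(b)) computed stably, as in math::lse.
float lse(float a, float b) {
  return std::max(a, b) + std::log1p(std::exp(-std::fabs(a - b)));
}

// Sequential reference for the alpha lattice of one (batch, hypothesis) pair.
// skip(t, u) is the blank log-prob and emit(t, u) the target log-prob at
// lattice node (t, u), both row-major T*U arrays.
std::vector<float> alphas_reference(
    int T, int U,
    const std::vector<float>& skip,
    const std::vector<float>& emit) {
  std::vector<float> alpha(T * U);
  auto at = [U](int t, int u) { return t * U + u; };
  alpha[at(0, 0)] = 0.f;
  for (int t = 1; t < T; ++t) // first column: blanks only
    alpha[at(t, 0)] = alpha[at(t - 1, 0)] + skip[at(t - 1, 0)];
  for (int u = 1; u < U; ++u) // first row: emits only
    alpha[at(0, u)] = alpha[at(0, u - 1)] + emit[at(0, u - 1)];
  for (int t = 1; t < T; ++t)
    for (int u = 1; u < U; ++u)
      alpha[at(t, u)] = lse(
          alpha[at(t - 1, u)] + skip[at(t - 1, u)],   // arrive via blank
          alpha[at(t, u - 1)] + emit[at(t, u - 1)]);  // arrive via emit
  return alpha;
}
```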
torchaudio/csrc/rnnt/dcu/gpu_transducer.h (new file, mode 100644)

```cpp
#pragma once

#ifdef USE_ROCM

#include <torchaudio/csrc/rnnt/workspace.h>
#include <torchaudio/csrc/rnnt/dcu/gpu_kernel_utils.cuh>
#include <torchaudio/csrc/rnnt/dcu/gpu_kernels.cuh>

namespace torchaudio {
namespace rnnt {
namespace gpu {

#define gpuErrchk(ans) \
  { gpuAssert((ans), __FILE__, __LINE__); }

inline void gpuAssert(
    hipError_t code,
    const char* file,
    int line,
    bool abort = true) {
  if (code != hipSuccess) {
    fprintf(
        stderr,
        "\nGPUassert: %s %s %d\n",
        hipGetErrorString(code),
        file,
        line);
    if (abort)
      exit(code);
  }
}

template <typename DTYPE, typename CAST_DTYPE>
status_t LogSumExp2D(
    hipStream_t stream,
    int N,
    int D,
    const DTYPE* logits, // [N, D]
    CAST_DTYPE* outputs) {
  { // compute max among D.
    dim3 block_dims(N);
    dim3 thread_dims(REDUCE_THREADS);

    ReduceMax2D<REDUCE_THREADS, DTYPE, CAST_DTYPE>
        <<<block_dims, thread_dims, 0, stream>>>(
            /*dim=*/D,
            /*inputs=*/logits,
            /*outputs=*/outputs);

    // BUGBUG: These error codes are only accurate when launching with
    // blocking. Otherwise they usually reflect earlier errors.
    if (hipGetLastError() != hipSuccess) {
      return COMPUTE_DENOMINATOR_REDUCE_MAX_FAILED;
    }
  }

  { // compute log(sum(exp(d_i - max)))
    dim3 block_dims(N);
    dim3 thread_dims(REDUCE_THREADS);

    ReduceLogSumExpGivenMax2D<REDUCE_THREADS, DTYPE, CAST_DTYPE>
        <<<block_dims, thread_dims, 0, stream>>>(
            /*dim=*/D,
            /*inputs=*/logits,
            /*outputs=*/outputs);

    if (hipGetLastError() != hipSuccess) {
      return COMPUTE_DENOMINATOR_REDUCE_SUM_FAILED;
    }
  }

  return SUCCESS;
}

// Inputs:
//   workspace: workspace.
//   logits: pointer to (B, max_T, max_U, D) logits.
//   targets: pointer to (B, max_U - 1) targets in the batch.
//   srcLengths: pointer to (B, ) source lengths in the batch.
//   tgtLengths: pointer to (B, ) target lengths in the batch.
//
// Outputs:
//   costs: pointer to (B, ) costs in the batch.
//   gradients: pointer to (B, max_T, max_U, D) gradients in the batch.
template <typename DTYPE, typename CAST_DTYPE>
status_t Compute(
    const Workspace<CAST_DTYPE>& workspace,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    DTYPE* costs,
    DTYPE* gradients = nullptr) {
  const Options& options = workspace.GetOptions();

  const hipStream_t& stream = options.stream_;
  const int& B = options.batchSize_;
  const int& H = options.nHypos_;
  const int& max_T = options.maxSrcLen_;
  const int& max_U = options.maxTgtLen_;
  const int& D = options.numTargets_;
  const int& blank = options.blank_;
  const CAST_DTYPE clamp = options.clamp_;

  { // compute denominators.
    status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
        /*stream=*/stream,
        /*N=*/B * H * max_T * max_U,
        /*D=*/D,
        /*logits=*/logits,
        /*denominators=*/workspace.GetPointerToDenominators());

    if (status != SUCCESS) {
      return status;
    }
  }

  { // compute log probability pairs (blank and target).
    int num_segments =
        (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
    dim3 block_dims(num_segments, max_U, B * H);
    dim3 thread_dims(MAX_THREADS_PER_BLOCK);

    ComputeLogProbs<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
        /*max_src_len=*/max_T,
        /*max_tgt_len=*/max_U,
        /*num_targets=*/D,
        /*blank=*/blank,
        /*logits=*/logits,
        /*targets=*/targets,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*denominators=*/workspace.GetPointerToDenominators(),
        /*log_probs=*/workspace.GetPointerToLogProbs(),
        H);

    if (hipGetLastError() != hipSuccess) {
      return COMPUTE_LOG_PROBS_FAILED;
    }
  }

  { // compute alphas, betas and costs.
    // warp is usually a group of threads (32)
    int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE;
    // each block is identified by 3 d tuple.
    // we are using num_warp * max_U * B * H blocks
    // where num_warp is division among Time axis
    dim3 block_dims(num_warps, max_U, B * H);
    // each thread is identified by a 2 d tuple
    // 2nd dim is 2. 1 for alpha, 1 for beta
    dim3 thread_dims(WARP_SIZE, 2);

    ComputeAlphasBetasCosts<DTYPE, CAST_DTYPE>
        <<<block_dims, thread_dims, 0, stream>>>(
            /*max_src_len=*/max_T,
            /*max_tgt_len=*/max_U,
            /*num_targets=*/D,
            /*blank=*/blank,
            /*log_probs=*/workspace.GetPointerToLogProbs(),
            /*srcLengths=*/srcLengths,
            /*tgtLengths=*/tgtLengths,
            /*alpha_counters=*/workspace.GetPointerToAlphaCounters(),
            /*alphas=*/workspace.GetPointerToAlphas(),
            /*beta_counters=*/workspace.GetPointerToBetaCounters(),
            /*betas=*/workspace.GetPointerToBetas(),
            /*costs=*/costs,
            /*warp_size=*/WARP_SIZE,
            /*num_warps=*/num_warps,
            H);

    if (hipGetLastError() != hipSuccess) {
      return COMPUTE_ALPHAS_BETAS_COSTS_FAILED;
    }
  }

  if (gradients != nullptr) { // compute gradients.
    // don't set gradients to zero to here as gradients might reuse memory from
    // logits
    int num_blocks =
        (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
    dim3 block_dims(num_blocks, max_U, B * H);
    dim3 thread_dims(MAX_THREADS_PER_BLOCK);

    ComputeGradients<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
        /*max_src_len=*/max_T,
        /*max_tgt_len=*/max_U,
        /*num_targets=*/D,
        /*blank=*/blank,
        /*clamp=*/clamp,
        /*logits=*/logits,
        /*targets=*/targets,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*denominators=*/workspace.GetPointerToDenominators(),
        /*alphas=*/workspace.GetPointerToAlphas(),
        /*betas=*/workspace.GetPointerToBetas(),
        /*gradients=*/gradients,
        H);

    if (hipGetLastError() != hipSuccess) {
      return COMPUTE_GRADIENTS_FAILED;
    }
  }

  return SUCCESS;
}

template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeAlphas(
    const Workspace<CAST_DTYPE>& workspace,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    DTYPE* alphas) {
  const Options& options = workspace.GetOptions();

  const hipStream_t& stream = options.stream_;
  const int& B = options.batchSize_;
  const int& H = options.nHypos_;
  const int& max_T = options.maxSrcLen_;
  const int& max_U = options.maxTgtLen_;
  const int& D = options.numTargets_;
  const int& blank = options.blank_;

  { // compute denominators.
    status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
        /*stream=*/stream,
        /*N=*/B * H * max_T * max_U,
        /*D=*/D,
        /*logits=*/logits,
        /*denominators=*/workspace.GetPointerToDenominators());

    if (status != SUCCESS) {
      return status;
    }
  }

  { // compute log probability pairs (blank and target).
    int num_segments =
        (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
    dim3 block_dims(num_segments, max_U, B * H);
    dim3 thread_dims(MAX_THREADS_PER_BLOCK);

    ComputeLogProbs<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
        /*max_src_len=*/max_T,
        /*max_tgt_len=*/max_U,
        /*num_targets=*/D,
        /*blank=*/blank,
        /*logits=*/logits,
        /*targets=*/targets,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*denominators=*/workspace.GetPointerToDenominators(),
        /*log_probs=*/workspace.GetPointerToLogProbs(),
        H);

    if (hipGetLastError() != hipSuccess) {
      return COMPUTE_LOG_PROBS_FAILED;
    }
  }

  { // compute alphas
    // warp is usually a group of threads (32)
    int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE;
    // each block is identified by 3 d tuple.
    // we are using num_warp * max_U * B blocks
    // where num_warp is division among Time axis
    dim3 block_dims(num_warps, max_U, B * H);
    // each thread is identified by a 2 d tuple
    // 2nd dim is 1 for alpha only
    dim3 thread_dims(WARP_SIZE, 1);

    ComputeAlphasWrapper<DTYPE, CAST_DTYPE>
        <<<block_dims, thread_dims, 0, stream>>>(
            /*max_src_len=*/max_T,
            /*max_tgt_len=*/max_U,
            /*num_targets=*/D,
            /*blank=*/blank,
            /*log_probs=*/workspace.GetPointerToLogProbs(),
            /*srcLengths=*/srcLengths,
            /*tgtLengths=*/tgtLengths,
            /*alpha_counters=*/workspace.GetPointerToAlphaCounters(),
            /*alphas=*/(volatile DTYPE*)alphas,
            H);

    if (hipGetLastError() != hipSuccess) {
      return COMPUTE_ALPHAS_BETAS_COSTS_FAILED;
    }
  }

  return SUCCESS;
}

template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeBetas(
    const Workspace<CAST_DTYPE>& workspace,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    DTYPE* costs,
    DTYPE* betas) {
  const Options& options = workspace.GetOptions();

  const hipStream_t& stream = options.stream_;
  const int& B = options.batchSize_;
  const int& H = options.nHypos_;
  const int& max_T = options.maxSrcLen_;
  const int& max_U = options.maxTgtLen_;
  const int& D = options.numTargets_;
  const int& blank = options.blank_;

  { // compute denominators.
    status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
        /*stream=*/stream,
        /*N=*/B * H * max_T * max_U,
        /*D=*/D,
        /*logits=*/logits,
        /*denominators=*/workspace.GetPointerToDenominators());

    if (status != SUCCESS) {
      return status;
    }
  }

  { // compute log probability pairs (blank and target).
    int num_segments =
        (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
    dim3 block_dims(num_segments, max_U, B * H);
    dim3 thread_dims(MAX_THREADS_PER_BLOCK);

    ComputeLogProbs<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
        /*max_src_len=*/max_T,
        /*max_tgt_len=*/max_U,
        /*num_targets=*/D,
        /*blank=*/blank,
        /*logits=*/logits,
        /*targets=*/targets,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*denominators=*/workspace.GetPointerToDenominators(),
        /*log_probs=*/workspace.GetPointerToLogProbs(),
        H);

    if (hipGetLastError() != hipSuccess) {
      return COMPUTE_LOG_PROBS_FAILED;
    }
  }

  { // compute betas
    // warp is usually a group of threads (32)
    int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE;
    // each block is identified by 3 d tuple.
    // we are using num_warp * max_U * B blocks
    // where num_warp is division among Time axis
    dim3 block_dims(num_warps, max_U, B * H);
    // each thread is identified by a 2 d tuple
    // 2nd dim is 1 for betas only
    dim3 thread_dims(WARP_SIZE, 1);

    ComputeBetasWrapper<DTYPE, CAST_DTYPE>
        <<<block_dims, thread_dims, 0, stream>>>(
            /*max_src_len=*/max_T,
            /*max_tgt_len=*/max_U,
            /*num_targets=*/D,
            /*blank=*/blank,
            /*log_probs=*/workspace.GetPointerToLogProbs(),
            /*srcLengths=*/srcLengths,
            /*tgtLengths=*/tgtLengths,
            /*alpha_counters=*/workspace.GetPointerToBetaCounters(),
            /*alphas=*/(volatile DTYPE*)betas,
            costs,
            H);

    if (hipGetLastError() != hipSuccess) {
      return COMPUTE_ALPHAS_BETAS_COSTS_FAILED;
    }
  }

  return SUCCESS;
}

} // namespace gpu
} // namespace rnnt
} // namespace torchaudio

#endif // USE_ROCM
```
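As the BUGBUG comment inside `LogSumExp2D` notes, `hipGetLastError()` after an asynchronous launch only reports launch-configuration failures; errors raised during kernel execution surface later. A hypothetical debug-only helper that trades throughput for accurate error attribution might look like the following (`checkKernelBlocking` is an assumed name, not part of the commit):

```cpp
#include <hip/hip_runtime.h>

// Debug aid: returns launch errors immediately, then blocks on the stream so
// execution errors are attributed to the kernel that was just launched.
// Synchronizing serializes the pipeline, so this is for debugging only.
inline hipError_t checkKernelBlocking(hipStream_t stream) {
  hipError_t err = hipGetLastError(); // launch-time errors
  if (err != hipSuccess) {
    return err;
  }
  return hipStreamSynchronize(stream); // execution-time errors, blocking
}
```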
torchaudio/csrc/rnnt/dcu/half.cuh (new file, mode 100644)

```cpp
#pragma once

#ifdef USE_C10_HALF
#include "c10/util/Half.h"
#endif // USE_C10_HALF

#include <torchaudio/csrc/rnnt/macros.h>

namespace torchaudio {
namespace rnnt {

struct alignas(sizeof(__half)) Half {
  __half x;

  HOST_AND_DEVICE Half() = default;

  FORCE_INLINE HOST_AND_DEVICE Half(float f) {
    x = __float2half_rn(f);
    if (isinf(__half2float(x))) {
      x = __float2half_rz(f); // round toward 0.
    }
  }

  FORCE_INLINE HOST_AND_DEVICE operator float() const {
    return __half2float(x);
  }

  FORCE_INLINE HOST_AND_DEVICE Half(__half f) {
    x = f;
  }

  FORCE_INLINE HOST_AND_DEVICE operator __half() const {
    return x;
  }
};

} // namespace rnnt
} // namespace torchaudio
```
torchaudio/csrc/rnnt/dcu/kernel_utils.h (new file, mode 100644)

```cpp
#pragma once

#include <cassert>

#include <torchaudio/csrc/rnnt/dcu/math.cuh>

namespace torchaudio {
namespace rnnt {

inline HOST_AND_DEVICE bool in_range(
    int start,
    int end, // inclusive
    int val) {
  return start <= val && val <= end;
}

#define LOG_PROBS_SKIP_IDX 0
#define LOG_PROBS_EMIT_IDX 1

struct Indexer2D {
  const int& size2_;

  FORCE_INLINE HOST_AND_DEVICE Indexer2D(const int& size2) : size2_(size2) {}

  FORCE_INLINE HOST_AND_DEVICE int operator()(int index1, int index2) {
    return index1 * size2_ + index2;
  }
};

struct Indexer3D {
  const int& size2_;
  const int& size3_;

  FORCE_INLINE HOST_AND_DEVICE Indexer3D(const int& size2, const int& size3)
      : size2_(size2), size3_(size3) {}

  FORCE_INLINE HOST_AND_DEVICE int
  operator()(int index1, int index2, int index3) {
    return (index1 * size2_ + index2) * size3_ + index3;
  }
};

struct Indexer4D {
  const int& size2_;
  const int& size3_;
  const int& size4_;

  HOST_AND_DEVICE
  Indexer4D(const int& size2, const int& size3, const int& size4)
      : size2_(size2), size3_(size3), size4_(size4) {}

  HOST_AND_DEVICE int
  operator()(int index1, int index2, int index3, int index4) {
    return ((index1 * size2_ + index2) * size3_ + index3) * size4_ + index4;
  }
};

} // namespace rnnt
} // namespace torchaudio
```
torchaudio/csrc/rnnt/dcu/kernels.h (new file, mode 100644)

```cpp
#pragma once

#include <cassert>

#include <torchaudio/csrc/rnnt/dcu/kernel_utils.h>
#include <torchaudio/csrc/rnnt/dcu/math.cuh>

namespace torchaudio {
namespace rnnt {

template <typename DTYPE, typename CAST_DTYPE>
HOST_AND_DEVICE void ComputeGradientsElement(
    int bTgt,
    int t,
    int u,
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    CAST_DTYPE clamp,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    const CAST_DTYPE* denominators,
    const CAST_DTYPE* alphas,
    const CAST_DTYPE* betas,
    DTYPE* gradients,
    int H = 1) {
  const int& maxT = maxSrcLen;
  const int& maxU = maxTgtLen;
  const int& D = numTargets;

  const int bSrc = bTgt / H;
  const int T = srcLengths[bSrc];
  const int U = tgtLengths[bTgt] + 1;

  if (t >= T || u >= U) { // out of boundary.
    if (gradients == logits && t < maxT && u < maxU) {
      // gradients and logits are pointing to the same memory location
      Indexer3D idxr3(maxT, maxU);
      int idx_b_t_u_zero = idxr3(bTgt, t, u);
      if (idx_b_t_u_zero != -1) {
        int start = idx_b_t_u_zero * D;
        for (int b_t_u_d = start; b_t_u_d < start + D; ++b_t_u_d) {
          gradients[b_t_u_d] = 0;
        }
      }
    }
    return;
  }

  int costIdx = bTgt * maxT * maxU;
  CAST_DTYPE cost = -(betas[costIdx]);

  Indexer2D idxr2(maxU - 1);

  int idx_b_t_u, idx_b_t_up1, idx_b_tp1_u;
  Indexer3D idxr3(maxT, maxU);
  idx_b_t_u = idxr3(bTgt, t, u);
  idx_b_t_up1 = idxr3(bTgt, t, u + 1);
  idx_b_tp1_u = idxr3(bTgt, t + 1, u);

  if (idx_b_t_u == -1) {
    return;
  }

  if (isinf(cost) || isnan(cost)) {
    for (int d = 0; d < D; ++d) {
      int b_t_u_d = idx_b_t_u * D + d;
      gradients[b_t_u_d] = 0;
    }
    return;
  }

  CAST_DTYPE c = alphas[idx_b_t_u] + cost - denominators[idx_b_t_u];
  for (int d = 0; d < D; ++d) {
    int b_t_u_d = idx_b_t_u * D + d;
    CAST_DTYPE g = CAST_DTYPE(logits[b_t_u_d]) + c;

    if (d == blank && t == T - 1 && u == U - 1) { // last blank transition.
      gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]) - std::exp(g);
    } else if (t < T - 1 && d == blank) {
      gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
      if (idx_b_tp1_u != -1) {
        gradients[b_t_u_d] =
            gradients[b_t_u_d] - std::exp(g + betas[idx_b_tp1_u]);
      }
    } else if (u < U - 1 && d == targets[idxr2(bTgt, u)]) {
      gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
      if (idx_b_t_up1 != -1) {
        gradients[b_t_u_d] =
            gradients[b_t_u_d] - std::exp(g + betas[idx_b_t_up1]);
      }
    } else {
      gradients[b_t_u_d] = std::exp(g + betas[idx_b_t_u]);
    }

    if (clamp > 0) {
      auto g = CAST_DTYPE(gradients[b_t_u_d]);
      gradients[b_t_u_d] = math::min(g, clamp);
      gradients[b_t_u_d] = math::max(g, -clamp);
    }
  }
}

} // namespace rnnt
} // namespace torchaudio
```
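For reference, the branches above are a compact restatement of the standard RNN-T gradient. Writing ŷ for the softmax of the logits at node (t, u) (the `denominators` term), log P = β(0, 0), and g = z_{t,u,d} + α(t, u) − log P − denom(t, u) as computed in the loop, the element produced is:

```latex
\frac{\partial \mathcal{L}}{\partial z_{t,u,d}}
  = e^{\,g + \beta(t,u)}
  - \begin{cases}
      e^{\,g}                & d = \text{blank},\; t = T-1,\; u = U-1 \\
      e^{\,g + \beta(t+1,u)} & d = \text{blank},\; t < T-1 \\
      e^{\,g + \beta(t,u+1)} & d = y_{u+1},\; u < U-1 \\
      0                      & \text{otherwise}
    \end{cases}
```

The last-blank case is the same subtraction with β taken as 0 past the terminal node, which is why the code subtracts a bare exp(g) there.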
torchaudio/csrc/rnnt/dcu/math.cuh (new file, mode 100644)

```cpp
#pragma once

#ifdef USE_ROCM
#include <cmath>
#endif // USE_ROCM

#include <torchaudio/csrc/rnnt/dcu/half.cuh>

namespace torchaudio {
namespace rnnt {
namespace math {

template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE max(DTYPE x, DTYPE y) {
  if (x > y)
    return x;
  else
    return y;
}

template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE min(DTYPE x, DTYPE y) {
  if (x > y)
    return y;
  else
    return x;
}

// log_sum_exp
template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE lse(DTYPE x, DTYPE y);

template <>
FORCE_INLINE HOST_AND_DEVICE float lse(float x, float y) {
  if (y > x) {
    return y + log1pf(expf(x - y));
  } else {
    return x + log1pf(expf(y - x));
  }
}

} // namespace math
} // namespace rnnt
} // namespace torchaudio
```
torchaudio/csrc/rnnt/macros.h

```diff
 #pragma once
 
-#ifdef USE_CUDA
+#if defined(USE_CUDA)
 #define WARP_SIZE 32
 #define MAX_THREADS_PER_BLOCK 1024
 #define REDUCE_THREADS 256
@@ -8,6 +8,14 @@
 #define FORCE_INLINE __forceinline__
 #include <cuda_fp16.h>
 #include <cuda_runtime.h>
+#elif defined(USE_ROCM)
+#define WARP_SIZE 64
+#define MAX_THREADS_PER_BLOCK 1024
+#define REDUCE_THREADS 256
+#define HOST_AND_DEVICE __host__ __device__
+#define FORCE_INLINE __forceinline__
+#include <hip/hip_fp16.h>
+#include <hip/hip_runtime.h>
 #else
 #define HOST_AND_DEVICE
 #define FORCE_INLINE inline
```
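The WARP_SIZE change reflects hardware: AMD GCN/CDNA wavefronts are 64 threads wide, versus 32-thread warps on NVIDIA, so the wavefront-synchronous alpha/beta kernels must use 64 here. A sketch of confirming the value at runtime instead of hard-coding it (illustrative, not part of the commit; note that RDNA parts can report 32):

```cpp
#include <cstdio>
#include <hip/hip_runtime.h>

int main() {
  hipDeviceProp_t props;
  // Query device 0; props.warpSize is the wavefront width (64 on GCN/CDNA).
  if (hipGetDeviceProperties(&props, /*device=*/0) == hipSuccess) {
    std::printf("wavefront size: %d\n", props.warpSize);
  }
  return 0;
}
```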
torchaudio/csrc/rnnt/options.h

```diff
@@ -6,6 +6,10 @@
 #include <cuda_runtime.h>
 #endif // USE_CUDA
 
+#ifdef USE_ROCM
+#include <hip/hip_runtime.h>
+#endif // USE_ROCM
+
 #include <torchaudio/csrc/rnnt/macros.h>
 #include <torchaudio/csrc/rnnt/types.h>
@@ -18,6 +22,10 @@ typedef struct Options {
 #ifdef USE_CUDA
   // the stream to launch kernels in when using GPU.
   cudaStream_t stream_;
 #endif
+#ifdef USE_ROCM
+  // the stream to launch kernels in when using GPU.
+  hipStream_t stream_;
+#endif
 
   // The maximum number of threads that can be used.
   int numThreads_;
```
torchaudio/csrc/rnnt/workspace.h

```diff
@@ -27,7 +27,7 @@ class DtypeWorkspace {
   ~DtypeWorkspace() {}
 
   static int ComputeSizeFromOptions(const Options& options) {
-    TORCH_CHECK_NE(options.device_, UNDEFINED);
+    CHECK_NE(options.device_, UNDEFINED);
     return ComputeSizeForDenominators(options) +
         ComputeSizeForLogProbs(options) + ComputeSizeForAlphas(options) +
         ComputeSizeForBetas(options);
@@ -36,7 +36,7 @@ class DtypeWorkspace {
   void Free();
 
   void Reset(const Options& options, DTYPE* data, int size) {
     int needed_size = ComputeSizeFromOptions(options);
-    TORCH_CHECK_LE(needed_size, size);
+    CHECK_LE(needed_size, size);
     options_ = options;
     data_ = data;
     size_ = size;
@@ -98,7 +98,7 @@ class IntWorkspace {
   void Reset(const Options& options, int* data, int size) {
     int needed_size = ComputeSizeFromOptions(options);
-    TORCH_CHECK_LE(needed_size, size);
+    CHECK_LE(needed_size, size);
     options_ = options;
     data_ = data;
     size_ = size;
@@ -109,11 +109,11 @@ class IntWorkspace {
   }
 
   int* GetPointerToAlphaCounters() const {
-    TORCH_CHECK_EQ(options_.device_, GPU);
+    CHECK_EQ(options_.device_, GPU);
     return data_;
   }
 
   int* GetPointerToBetaCounters() const {
-    TORCH_CHECK_EQ(options_.device_, GPU);
+    CHECK_EQ(options_.device_, GPU);
     return GetPointerToAlphaCounters() + ComputeSizeForAlphaCounters(options_);
   }
@@ -131,10 +131,22 @@ class IntWorkspace {
           ComputeSizeForBetaCounters(options_) * sizeof(int));
     }
 #endif // USE_CUDA
+#ifdef USE_ROCM
+    if (data_ != nullptr && options_.device_ == GPU) {
+      hipMemset(
+          GetPointerToAlphaCounters(),
+          0,
+          ComputeSizeForAlphaCounters(options_) * sizeof(int));
+      hipMemset(
+          GetPointerToBetaCounters(),
+          0,
+          ComputeSizeForBetaCounters(options_) * sizeof(int));
+    }
+#endif // USE_ROCM
   }
 
   static int ComputeSizeForAlphaCounters(const Options& options) {
     // B * U
-#ifdef USE_CUDA
+#if defined(USE_CUDA) || defined(USE_ROCM)
     if (options.device_ == GPU) {
       return options.BU();
     } else {
@@ -145,7 +157,7 @@ class IntWorkspace {
 #endif // USE_CUDA
   }
 
   static int ComputeSizeForBetaCounters(const Options& options) {
     // B * U
-#ifdef USE_CUDA
+#if defined(USE_CUDA) || defined(USE_ROCM)
     if (options.device_ == GPU) {
       return options.BU();
     } else {
```