Commit 9dcc7a15 (hehl2 / Torchaudio)
Authored Apr 25, 2022 by flyingdown
Message: init v0.10.0
Parent: db2b0b79
Showing 20 changed files with 2440 additions and 0 deletions (+2440 −0). The full commit touches 416+ files; only 416 are listed on the page to preserve performance.
torchaudio/csrc/pybind/sox/utils.cpp           +31   −0
torchaudio/csrc/pybind/sox/utils.h             +12   −0
torchaudio/csrc/rnnt/autograd.cpp              +56   −0
torchaudio/csrc/rnnt/compute.cpp               +25   −0
torchaudio/csrc/rnnt/compute.h                 +11   −0
torchaudio/csrc/rnnt/compute_alphas.cpp        +11   −0
torchaudio/csrc/rnnt/compute_betas.cpp         +11   −0
torchaudio/csrc/rnnt/cpu/compute.cpp           +148  −0
torchaudio/csrc/rnnt/cpu/compute_alphas.cpp    +70   −0
torchaudio/csrc/rnnt/cpu/compute_betas.cpp     +75   −0
torchaudio/csrc/rnnt/cpu/cpu_kernels.h         +498  −0
torchaudio/csrc/rnnt/cpu/cpu_transducer.h      +184  −0
torchaudio/csrc/rnnt/cpu/kernel_utils.h        +66   −0
torchaudio/csrc/rnnt/cpu/math.h                +42   −0
torchaudio/csrc/rnnt/gpu/compute.cu            +151  −0
torchaudio/csrc/rnnt/gpu/compute_alphas.cu     +73   −0
torchaudio/csrc/rnnt/gpu/compute_betas.cu      +78   −0
torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh  +98   −0
torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh       +409  −0
torchaudio/csrc/rnnt/gpu/gpu_transducer.h      +391  −0
torchaudio/csrc/pybind/sox/utils.cpp (new file, mode 100644)
#include <torchaudio/csrc/pybind/sox/utils.h>

namespace torchaudio::sox_utils {

auto read_fileobj(py::object* fileobj, const uint64_t size, char* buffer)
    -> uint64_t {
  uint64_t num_read = 0;
  while (num_read < size) {
    auto request = size - num_read;
    auto chunk = static_cast<std::string>(
        static_cast<py::bytes>(fileobj->attr("read")(request)));
    auto chunk_len = chunk.length();
    if (chunk_len == 0) {
      break;
    }
    if (chunk_len > request) {
      std::ostringstream message;
      message << "Requested up to " << request << " bytes but, "
              << "received " << chunk_len << " bytes. "
              << "The given object does not confirm to read protocol of file object.";
      throw std::runtime_error(message.str());
    }
    memcpy(buffer, chunk.data(), chunk_len);
    buffer += chunk_len;
    num_read += chunk_len;
  }
  return num_read;
}

} // namespace torchaudio::sox_utils
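For orientation, here is a minimal sketch of how read_fileobj is driven: the caller owns the destination buffer, and the function keeps calling the Python object's read(n) until size bytes arrive or EOF. The wrapper below (its name and buffer size) is hypothetical, not part of this commit.

#include <pybind11/pybind11.h>

#include <torchaudio/csrc/pybind/sox/utils.h>

namespace py = pybind11;

// Hypothetical caller: pull up to 4096 header bytes from a Python
// file-like object. `fileobj` must expose read(n) -> bytes, which is
// the only protocol read_fileobj relies on.
uint64_t read_header(py::object* fileobj, char* buffer /* >= 4096 bytes */) {
  return torchaudio::sox_utils::read_fileobj(fileobj, 4096, buffer);
}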
torchaudio/csrc/pybind/sox/utils.h (new file, mode 100644)
#ifndef TORCHAUDIO_PYBIND_SOX_UTILS_H
#define TORCHAUDIO_PYBIND_SOX_UTILS_H

#include <torch/extension.h>

namespace torchaudio::sox_utils {

auto read_fileobj(py::object* fileobj, uint64_t size, char* buffer)
    -> uint64_t;

} // namespace torchaudio::sox_utils

#endif
torchaudio/csrc/rnnt/autograd.cpp (new file, mode 100644)
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/compute.h>

namespace torchaudio {
namespace rnnt {

class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
 public:
  static torch::autograd::tensor_list forward(
      torch::autograd::AutogradContext* ctx,
      torch::Tensor& logits,
      const torch::Tensor& targets,
      const torch::Tensor& logit_lengths,
      const torch::Tensor& target_lengths,
      int64_t blank,
      double clamp) {
    torch::Tensor undef;
    auto result =
        rnnt_loss(logits, targets, logit_lengths, target_lengths, blank, clamp);
    auto costs = std::get<0>(result);
    auto grads = std::get<1>(result).value_or(undef);
    ctx->save_for_backward({grads});
    return {costs, grads};
  }

  static torch::autograd::tensor_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::tensor_list grad_outputs) {
    auto saved = ctx->get_saved_variables();
    auto grad = saved[0];
    auto grad_out = grad_outputs[0].view({-1, 1, 1, 1});
    auto result = grad * grad_out;
    torch::Tensor undef;
    return {result, undef, undef, undef, undef, undef, undef, undef};
  }
};

std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp) {
  at::AutoDispatchBelowADInplaceOrView guard;
  auto results = RNNTLossFunction::apply(
      logits, targets, logit_lengths, target_lengths, blank, clamp);
  return std::make_tuple(results[0], results[1]);
}

TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) {
  m.impl("rnnt_loss", rnnt_loss_autograd);
}

} // namespace rnnt
} // namespace torchaudio
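The {-1, 1, 1, 1} view in backward() is a broadcasting step: costs has shape (B,) while the saved gradient tensor is (B, maxT, maxU, D), so each utterance's upstream gradient must scale its whole (maxT, maxU, D) slab. A standalone illustration with made-up shapes (not from the commit):

#include <iostream>
#include <torch/torch.h>

int main() {
  // Stand-in shapes: B=2 utterances, maxT=10, maxU=5, D=20.
  auto grads = torch::randn({2, 10, 5, 20}); // saved d(cost)/d(logits)
  auto grad_out = torch::randn({2}); // upstream d(loss)/d(costs), one scalar per utterance
  auto result = grads * grad_out.view({-1, 1, 1, 1}); // broadcast over (maxT, maxU, D)
  std::cout << result.sizes() << std::endl; // [2, 10, 5, 20]
  return 0;
}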
torchaudio/csrc/rnnt/compute.cpp (new file, mode 100644)
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/compute.h>

std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp) {
  static auto op = torch::Dispatcher::singleton()
                       .findSchemaOrThrow("torchaudio::rnnt_loss", "")
                       .typed<decltype(rnnt_loss)>();
  return op.call(logits, targets, logit_lengths, target_lengths, blank, clamp);
}

TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
  m.def(
      "rnnt_loss(Tensor logits,"
      "Tensor targets,"
      "Tensor logit_lengths,"
      "Tensor target_lengths,"
      "int blank,"
      "float clamp) -> (Tensor, Tensor?)");
}
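Because rnnt_loss only routes through the dispatcher, any backend registered for "torchaudio::rnnt_loss" (the CPU, CUDA, and Autograd impls in this commit) can service a call like the sketch below. Shapes and values are illustrative only; per the schema, targets and both length tensors are int32 and logits is (batch, time, target, class). This assumes the torchaudio extension with these registrations is linked in.

#include <iostream>
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/compute.h>

int main() {
  const int B = 2, T = 10, U = 5, D = 20; // U counts the prepended blank position
  torch::Tensor logits = torch::randn({B, T, U, D});
  torch::Tensor targets = torch::randint(1, D, {B, U - 1}, torch::kInt32);
  torch::Tensor logit_lengths = torch::full({B}, T, torch::kInt32);
  torch::Tensor target_lengths = torch::full({B}, U - 1, torch::kInt32);
  auto result = rnnt_loss(
      logits, targets, logit_lengths, target_lengths, /*blank=*/0, /*clamp=*/-1.0);
  std::cout << std::get<0>(result) << std::endl; // per-utterance costs, shape (B,)
  return 0;
}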
torchaudio/csrc/rnnt/compute.h (new file, mode 100644)
#pragma once

#include <torch/script.h>

std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp);
torchaudio/csrc/rnnt/compute_alphas.cpp (new file, mode 100644)
#include <torch/script.h>

TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
  m.def(
      "rnnt_loss_alphas(Tensor logits,"
      "Tensor targets,"
      "Tensor logit_lengths,"
      "Tensor target_lengths,"
      "int blank,"
      "float clamp) -> Tensor");
}
torchaudio/csrc/rnnt/compute_betas.cpp (new file, mode 100644)
#include <torch/script.h>

TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
  m.def(
      "rnnt_loss_betas(Tensor logits,"
      "Tensor targets,"
      "Tensor logit_lengths,"
      "Tensor target_lengths,"
      "int blank,"
      "float clamp) -> Tensor");
}
torchaudio/csrc/rnnt/cpu/compute.cpp (new file, mode 100644)
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/cpu/cpu_transducer.h>

namespace torchaudio {
namespace rnnt {
namespace cpu {

// Entry point into RNNT Loss
std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp) {
  TORCH_CHECK(
      logits.device().type() == targets.device().type(),
      "logits and targets must be on the same device");
  TORCH_CHECK(
      logits.device().type() == logit_lengths.device().type(),
      "logits and logit_lengths must be on the same device");
  TORCH_CHECK(
      logits.device().type() == target_lengths.device().type(),
      "logits and target_lengths must be on the same device");

  TORCH_CHECK(
      logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16,
      "logits must be float32 or float16 (half) type");
  TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type");
  TORCH_CHECK(
      logit_lengths.dtype() == torch::kInt32,
      "logit_lengths must be int32 type");
  TORCH_CHECK(
      target_lengths.dtype() == torch::kInt32,
      "target_lengths must be int32 type");

  TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
  TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
  TORCH_CHECK(
      logit_lengths.is_contiguous(), "logit_lengths must be contiguous");
  TORCH_CHECK(
      target_lengths.is_contiguous(), "target_lengths must be contiguous");

  TORCH_CHECK(
      logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
  TORCH_CHECK(
      targets.dim() == 2, "targets must be 2-D (batch, max target length)");
  TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D");
  TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D");

  TORCH_CHECK(
      logit_lengths.size(0) == logits.size(0),
      "batch dimension mismatch between logits and logit_lengths");
  TORCH_CHECK(
      target_lengths.size(0) == logits.size(0),
      "batch dimension mismatch between logits and target_lengths");
  TORCH_CHECK(
      targets.size(0) == logits.size(0),
      "batch dimension mismatch between logits and targets");

  TORCH_CHECK(
      blank >= 0 && blank < logits.size(-1),
      "blank must be within [0, logits.shape[-1])");

  TORCH_CHECK(
      logits.size(1) == at::max(logit_lengths).item().toInt(),
      "input length mismatch");
  TORCH_CHECK(
      logits.size(2) == at::max(target_lengths).item().toInt() + 1,
      "output length mismatch");
  TORCH_CHECK(
      targets.size(1) == at::max(target_lengths).item().toInt(),
      "target length mismatch");

  Options options;
  options.batchSize_ = logit_lengths.size(0);
  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
  options.maxSrcLen_ = logits.size(1);
  options.maxTgtLen_ = logits.size(2);
  options.numTargets_ = logits.size(3);
  options.blank_ = blank;
  options.clamp_ = clamp;

  CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
  options.device_ = CPU;

  torch::Tensor costs = torch::empty(
      options.batchSize_ * options.nHypos_,
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
  c10::optional<torch::Tensor> gradients = torch::zeros_like(logits);

  torch::Tensor int_workspace = torch::empty(
      IntWorkspace::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Int));

  torch::Tensor float_workspace = torch::empty(
      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Float));

  Workspace<float> workspace(
      /*options=*/options,
      /*dtype_data=*/float_workspace.data_ptr<float>(),
      /*dtype_size=*/float_workspace.numel(),
      /*int_data=*/int_workspace.data_ptr<int>(),
      /*int_size=*/int_workspace.numel());

  switch (logits.scalar_type()) {
    case torch::ScalarType::Float: {
      Compute</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
          /*workspace=*/workspace,
          /*logits=*/logits.data_ptr<float>(),
          /*targets=*/targets.data_ptr<int>(),
          /*logit_lengths=*/logit_lengths.data_ptr<int>(),
          /*target_lengths=*/target_lengths.data_ptr<int>(),
          /*costs=*/costs.data_ptr<float>(),
          /*gradients=*/gradients->data_ptr<float>());
      break;
    }
    case torch::ScalarType::Half: {
      Compute</*DTYPE=*/c10::Half, /*CAST_DTYPE=*/float>(
          /*workspace=*/workspace,
          /*logits=*/logits.data_ptr<c10::Half>(),
          /*targets=*/targets.data_ptr<int>(),
          /*logit_lengths=*/logit_lengths.data_ptr<int>(),
          /*target_lengths=*/target_lengths.data_ptr<int>(),
          /*costs=*/costs.data_ptr<c10::Half>(),
          /*gradients=*/gradients->data_ptr<c10::Half>());
      break;
    }
    default: {
      break;
    }
  };

  return std::make_tuple(costs, gradients);
}

TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
  m.impl("rnnt_loss", &compute);
}

} // namespace cpu
} // namespace rnnt
} // namespace torchaudio
torchaudio/csrc/rnnt/cpu/compute_alphas.cpp (new file, mode 100644)
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/cpu/cpu_transducer.h>

namespace torchaudio {
namespace rnnt {
namespace cpu {

torch::Tensor compute_alphas(
    const torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp) {
  Options options;
  options.batchSize_ = logit_lengths.size(0);
  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
  options.maxSrcLen_ = logits.size(1);
  options.maxTgtLen_ = logits.size(2);
  options.numTargets_ = logits.size(3);
  options.blank_ = blank;
  options.clamp_ = clamp;

  CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
  options.device_ = CPU;

  torch::Tensor alphas = torch::zeros(
      {options.batchSize_ * options.nHypos_,
       options.maxSrcLen_,
       options.maxTgtLen_},
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));

  torch::Tensor int_workspace = torch::empty(
      IntWorkspace::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Int));

  torch::Tensor float_workspace = torch::empty(
      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Float));

  Workspace<float> workspace(
      /*options=*/options,
      /*dtype_data=*/float_workspace.data_ptr<float>(),
      /*dtype_size=*/float_workspace.numel(),
      /*int_data=*/int_workspace.data_ptr<int>(),
      /*int_size=*/int_workspace.numel());

  // Only support float, this is mainly to enable easy
  // unit-testing
  ComputeAlphas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
      /*workspace=*/workspace,
      /*logits=*/logits.data_ptr<float>(),
      /*targets=*/targets.data_ptr<int>(),
      /*logit_lengths=*/logit_lengths.data_ptr<int>(),
      /*target_lengths=*/target_lengths.data_ptr<int>(),
      /*alphas=*/alphas.data_ptr<float>());
  return alphas;
}

TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
  m.impl("rnnt_loss_alphas", &compute_alphas);
}

} // namespace cpu
} // namespace rnnt
} // namespace torchaudio
torchaudio/csrc/rnnt/cpu/compute_betas.cpp (new file, mode 100644)
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/cpu/cpu_transducer.h>

namespace torchaudio {
namespace rnnt {
namespace cpu {

torch::Tensor compute_betas(
    const torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp) {
  Options options;
  options.batchSize_ = logit_lengths.size(0);
  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
  options.maxSrcLen_ = logits.size(1);
  options.maxTgtLen_ = logits.size(2);
  options.numTargets_ = logits.size(3);
  options.blank_ = blank;
  options.clamp_ = clamp;

  CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
  options.device_ = CPU;

  torch::Tensor costs = torch::empty(
      target_lengths.size(0),
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));

  torch::Tensor betas = torch::zeros(
      {options.batchSize_ * options.nHypos_,
       options.maxSrcLen_,
       options.maxTgtLen_},
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));

  torch::Tensor int_workspace = torch::empty(
      IntWorkspace::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Int));

  torch::Tensor float_workspace = torch::empty(
      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Float));

  Workspace<float> workspace(
      /*options=*/options,
      /*dtype_data=*/float_workspace.data_ptr<float>(),
      /*dtype_size=*/float_workspace.numel(),
      /*int_data=*/int_workspace.data_ptr<int>(),
      /*int_size=*/int_workspace.numel());

  // Only support float, this is mainly to enable easy
  // unit-testing
  ComputeBetas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
      /*workspace=*/workspace,
      /*logits=*/logits.data_ptr<float>(),
      /*targets=*/targets.data_ptr<int>(),
      /*logit_lengths=*/logit_lengths.data_ptr<int>(),
      /*target_lengths=*/target_lengths.data_ptr<int>(),
      /*costs=*/costs.data_ptr<float>(),
      /*betas=*/betas.data_ptr<float>());
  return betas;
}

TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
  m.impl("rnnt_loss_betas", &compute_betas);
}

} // namespace cpu
} // namespace rnnt
} // namespace torchaudio
torchaudio/csrc/rnnt/cpu/cpu_kernels.h (new file, mode 100644)
#pragma once

#include <torchaudio/csrc/rnnt/cpu/math.h>
#include <torchaudio/csrc/rnnt/options.h>
#include <torchaudio/csrc/rnnt/types.h>

#include <cstring>
#include <limits>
#include <vector>

namespace torchaudio {
namespace rnnt {
namespace cpu {

template <typename DTYPE>
struct LogProbs {
  DTYPE skip_; // blank.
  DTYPE emit_; // target.

  LogProbs(DTYPE skip, DTYPE emit) : skip_(skip), emit_(emit) {}

  DTYPE& skip() {
    return skip_;
  }
  DTYPE& emit() {
    return emit_;
  }

  const DTYPE& skip() const {
    return skip_;
  }
  const DTYPE& emit() const {
    return emit_;
  }
};

// TensorView: view a block of allocated memory as a tensor.
template <typename DTYPE>
class TensorView {
 public:
  TensorView(const std::vector<int>& dims, DTYPE* data)
      : dims_(dims), data_(data) {
    strides_.resize(dims.size());
    strides_.back() = 1;
    for (int i = dims.size() - 2; i >= 0; --i) {
      strides_[i] = strides_[i + 1] * dims[i + 1];
    }
  }

  DTYPE& operator()(const std::vector<int>& indices) {
    CHECK_EQ(indices.size(), dims_.size());
    int index = indices.back();
    for (int i = indices.size() - 2; i >= 0; --i) {
      index += indices[i] * strides_[i];
    }
    return data_[index];
  }

  void SetZero() {
    int size = dims_[0] * strides_[0];
    std::memset(data_, 0, sizeof(DTYPE) * size);
  }

 private:
  std::vector<int> dims_;
  std::vector<int> strides_;
  DTYPE* data_;
};
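TensorView's constructor builds standard row-major strides (innermost stride 1, each outer stride the product of the inner dims), and operator() folds an index vector with those strides. A self-contained check of the same arithmetic, with illustrative dims:

#include <cassert>
#include <vector>

int main() {
  std::vector<int> dims = {4, 3, 2}; // e.g. {maxT, maxU, D}
  std::vector<int> strides(dims.size());
  strides.back() = 1;
  for (int i = dims.size() - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * dims[i + 1];
  }
  assert(strides[0] == 6 && strides[1] == 2 && strides[2] == 1);
  // Index (t=1, u=2, d=0) flattens to 1*6 + 2*2 + 0*1 = 10.
  int index = 1 * strides[0] + 2 * strides[1] + 0 * strides[2];
  assert(index == 10);
  return 0;
}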
template <typename DTYPE, typename CAST_DTYPE>
status_t LogSumExp2D(int N, int D, const DTYPE* logits, CAST_DTYPE* outputs) {
  for (int i = 0; i < N * D; i += D) {
    CAST_DTYPE max = logits[i];
    for (int j = 1; j < D; ++j) {
      max = std::max(max, CAST_DTYPE(logits[i + j]));
    }
    CAST_DTYPE sum = 0;
    for (int j = 0; j < D; ++j) {
      sum = sum + std::exp(CAST_DTYPE(logits[i + j]) - max);
    }
    outputs[i / D] = max + std::log(sum);
  }
  return SUCCESS;
}
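LogSumExp2D is the standard max-shifted log-sum-exp, applied per row of the flattened (N, D) logits; here it produces the log-softmax denominators:

$$\operatorname{LSE}(x_1,\dots,x_D) = \log\sum_{j=1}^{D} e^{x_j} = m + \log\sum_{j=1}^{D} e^{x_j - m}, \qquad m = \max_j x_j,$$

so every exponent is at most 0 and cannot overflow. This denominator is what gets subtracted from each logit when forming the skip/emit log-probabilities below.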
template <typename DTYPE, typename CAST_DTYPE>
void ComputeLogProbsOneSequence(
    const Options& options,
    TensorView<const DTYPE>& logits,
    const int* targets,
    int srcLen,
    int tgtLen,
    TensorView<const CAST_DTYPE>& denom,
    TensorView<LogProbs<CAST_DTYPE>>& logProbs) {
  const int& T = srcLen;
  const int& U = tgtLen;
  const int& blank = options.blank_;

  for (int t = 0; t < T; ++t) {
    for (int u = 0; u < U; ++u) {
      if (u < U - 1) {
        logProbs({t, u}).emit() =
            CAST_DTYPE(logits({t, u, targets[u]})) - denom({t, u});
      }
      logProbs({t, u}).skip() =
          CAST_DTYPE(logits({t, u, blank})) - denom({t, u});
    }
  }
}

template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeLogProbs(
    const Options& options,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    const CAST_DTYPE* denominators,
    CAST_DTYPE* logProbs) {
  std::vector<TensorView<const DTYPE>> seqLogits;
  std::vector<const int*> seqTargets;
  std::vector<TensorView<const CAST_DTYPE>> seqDenoms;
  std::vector<TensorView<LogProbs<CAST_DTYPE>>> seqlogProbs;

  const int& B = options.batchSize_;
  const int& maxT = options.maxSrcLen_;
  const int& maxU = options.maxTgtLen_;
  const int& D = options.numTargets_;

  for (int b = 0; b < B; ++b) {
    seqLogits.push_back(
        TensorView<const DTYPE>({maxT, maxU, D}, logits + b * maxT * maxU * D));
    seqTargets.push_back(targets + b * (maxU - 1));
    seqDenoms.push_back(TensorView<const CAST_DTYPE>(
        {maxT, maxU}, denominators + b * maxT * maxU));
    seqlogProbs.push_back(TensorView<LogProbs<CAST_DTYPE>>(
        {maxT, maxU},
        reinterpret_cast<LogProbs<CAST_DTYPE>*>(logProbs) + b * maxT * maxU));
  }

  //#pragma omp parallel for
  for (int b = 0; b < B; ++b) { // use max 2 * B threads.
    ComputeLogProbsOneSequence<DTYPE, CAST_DTYPE>(
        /*options=*/options,
        /*logits=*/seqLogits[b],
        /*targets=*/seqTargets[b],
        /*srcLen=*/srcLengths[b],
        /*tgtLen=*/tgtLengths[b] + 1, // with prepended blank.
        /*denom=*/seqDenoms[b],
        /*logProbs=*/seqlogProbs[b]);
  }

  return SUCCESS;
}
template <typename DTYPE>
DTYPE ComputeAlphaOneSequence(
    const Options& options,
    TensorView<const LogProbs<DTYPE>>& logProbs,
    int srcLen,
    int tgtLen,
    TensorView<DTYPE>& alpha) {
  const int& T = srcLen;
  const int& U = tgtLen;

  alpha({0, 0}) = DTYPE(0);

  for (int t = 1; t < T; ++t) { // u == 0.
    alpha({t, 0}) = alpha({t - 1, 0}) + logProbs({t - 1, 0}).skip();
  }

  for (int u = 1; u < U; ++u) { // t == 0.
    alpha({0, u}) = alpha({0, u - 1}) + logProbs({0, u - 1}).emit();
  }

  for (int t = 1; t < T; ++t) {
    for (int u = 1; u < U; ++u) {
      alpha({t, u}) = math::lse(
          alpha({t - 1, u}) + logProbs({t - 1, u}).skip(),
          alpha({t, u - 1}) + logProbs({t, u - 1}).emit());
    }
  }

  DTYPE forward_score =
      alpha({T - 1, U - 1}) + logProbs({T - 1, U - 1}).skip();

  return forward_score;
}
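In equation form, the forward pass above is the standard RNN-T lattice recursion over time t and label position u:

$$\alpha(0,0) = 0, \qquad \alpha(t,u) = \operatorname{lse}\bigl(\alpha(t-1,u) + \mathrm{skip}(t-1,u),\ \alpha(t,u-1) + \mathrm{emit}(t,u-1)\bigr),$$

with the boundary rows using only the available term, and the total forward score (the log-probability of the target sequence) given by $\alpha(T-1,U-1) + \mathrm{skip}(T-1,U-1)$, the final mandatory blank.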
template <typename DTYPE>
DTYPE ComputeBetaOneSequence(
    const Options& options,
    TensorView<const LogProbs<DTYPE>>& logProbs,
    int srcLen,
    int tgtLen,
    TensorView<DTYPE>& beta) {
  const int& T = srcLen;
  const int& U = tgtLen;

  beta({T - 1, U - 1}) = logProbs({T - 1, U - 1}).skip();

  for (int t = T - 2; t >= 0; --t) { // u == U - 1.
    beta({t, U - 1}) = beta({t + 1, U - 1}) + logProbs({t, U - 1}).skip();
  }

  for (int u = U - 2; u >= 0; --u) { // t == T - 1.
    beta({T - 1, u}) = beta({T - 1, u + 1}) + logProbs({T - 1, u}).emit();
  }

  for (int t = T - 2; t >= 0; --t) {
    for (int u = U - 2; u >= 0; --u) {
      beta({t, u}) = math::lse(
          beta({t + 1, u}) + logProbs({t, u}).skip(),
          beta({t, u + 1}) + logProbs({t, u}).emit());
    }
  }

  DTYPE backward_score = beta({0, 0});
  return backward_score;
}
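The backward pass mirrors it from the top-right corner of the lattice:

$$\beta(T-1,U-1) = \mathrm{skip}(T-1,U-1), \qquad \beta(t,u) = \operatorname{lse}\bigl(\beta(t+1,u) + \mathrm{skip}(t,u),\ \beta(t,u+1) + \mathrm{emit}(t,u)\bigr),$$

and $\beta(0,0)$ equals the forward score up to numerical error, which is why the cost is taken as $-\beta(0,0)$ in ComputeAlphasBetas and in the gradient code below.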
template <typename DTYPE>
DTYPE ComputeAlphaOrBetaOneSequence(
    int thread,
    const Options& options,
    TensorView<const LogProbs<DTYPE>>& logProbs,
    int srcLen,
    int tgtLen,
    TensorView<DTYPE>& alpha,
    TensorView<DTYPE>& beta) {
  if (thread & 1) {
    return ComputeAlphaOneSequence<DTYPE>(
        /*options=*/options,
        /*logProbs=*/logProbs,
        /*srcLen=*/srcLen,
        /*tgtLen=*/tgtLen,
        /*alpha=*/alpha);
  } else {
    return ComputeBetaOneSequence<DTYPE>(
        /*options=*/options,
        /*logProbs=*/logProbs,
        /*srcLen=*/srcLen,
        /*tgtLen=*/tgtLen,
        /*beta=*/beta);
  }
}

template <typename DTYPE, typename CAST_DTYPE>
void ComputeAlphasBetas(
    const Options& options,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    CAST_DTYPE* alphas,
    CAST_DTYPE* betas,
    DTYPE* costs) {
  std::vector<TensorView<const LogProbs<CAST_DTYPE>>> seqlogProbs;
  std::vector<TensorView<CAST_DTYPE>> seq_alphas;
  std::vector<TensorView<CAST_DTYPE>> seq_betas;

  const int& B = options.batchSize_;
  const int& maxT = options.maxSrcLen_;
  const int& maxU = options.maxTgtLen_;

  for (int b = 0; b < B; ++b) {
    seqlogProbs.push_back(TensorView<const LogProbs<CAST_DTYPE>>(
        {maxT, maxU},
        reinterpret_cast<LogProbs<CAST_DTYPE>*>(
            const_cast<CAST_DTYPE*>(logProbs)) +
            b * maxT * maxU));
    seq_alphas.push_back(
        TensorView<CAST_DTYPE>({maxT, maxU}, alphas + b * maxT * maxU));
    seq_betas.push_back(
        TensorView<CAST_DTYPE>({maxT, maxU}, betas + b * maxT * maxU));
  }

  std::vector<CAST_DTYPE> scores(B << 1);
  //#pragma omp parallel for
  for (int t = 0; t < (B << 1); ++t) { // use max 2 * B threads.
    int i = (t >> 1);
    scores[t] = ComputeAlphaOrBetaOneSequence<CAST_DTYPE>(
        /*thread=*/t,
        /*options=*/options,
        /*logProbs=*/seqlogProbs[i],
        /*srcLen=*/srcLengths[i],
        /*tgtLen=*/tgtLengths[i] + 1, // with prepended blank.
        /*alpha=*/seq_alphas[i],
        /*beta=*/seq_betas[i]);
  }

  for (int b = 0; b < B; ++b) {
    costs[b] = -scores[b << 1];
  }
}

template <typename DTYPE, typename CAST_DTYPE>
void ComputeGradientsOneSequence(
    const Options& options,
    TensorView<const DTYPE>& logits,
    const int* targets,
    int srcLen,
    int tgtLen,
    TensorView<const CAST_DTYPE>& denom,
    TensorView<const CAST_DTYPE>& alpha,
    TensorView<const CAST_DTYPE>& beta,
    TensorView<DTYPE>& gradients) {
  // don't set gradients to zero to here as gradients might reuse memory from
  // logits
  const int& T = srcLen;
  const int& U = tgtLen;
  const int& D = options.numTargets_;
  const int& blank = options.blank_;
  const CAST_DTYPE clamp = options.clamp_;

  CAST_DTYPE cost = -beta({0, 0});

  // Note - below gradient is different from numpy_transducer, since we
  // compute log_softmax more efficiently within the loss, to save memory. The
  // details of the below implementation / equations can be found in Sec 3.2
  // (function merging) in below paper:
  // https://www.microsoft.com/en-us/research/uploads/prod/2019/10/RNNT.pdf
  for (int t = 0; t < T; ++t) {
    for (int u = 0; u < U; ++u) {
      CAST_DTYPE c = alpha({t, u}) + cost - denom({t, u});
      for (int d = 0; d < D; ++d) {
        CAST_DTYPE g = CAST_DTYPE(logits({t, u, d})) + c;
        if (d == blank && t == T - 1 && u == U - 1) { // last blank transition.
          gradients({t, u, d}) = std::exp(g + beta({t, u})) - std::exp(g);
        } else if (d == blank && t < T - 1) {
          gradients({t, u, d}) =
              std::exp(g + beta({t, u})) - std::exp(g + beta({t + 1, u}));
        } else if (u < U - 1 && d == targets[u]) {
          gradients({t, u, d}) =
              std::exp(g + beta({t, u})) - std::exp(g + beta({t, u + 1}));
        } else {
          gradients({t, u, d}) = std::exp(g + beta({t, u}));
        }

        if (clamp > 0) {
          gradients({t, u, d}) =
              math::min(CAST_DTYPE(gradients({t, u, d})), clamp);
          gradients({t, u, d}) =
              math::max(CAST_DTYPE(gradients({t, u, d})), -clamp);
        }
      }
    }
  }

  // zero out the rest of the gradients, necessary when reusing logits memory
  // check the memory location to see if it's necessary
  if (&gradients({0, 0, 0}) == &logits({0, 0, 0})) {
    const int& maxT = options.maxSrcLen_;
    const int& maxU = options.maxTgtLen_;
    for (int t = T; t < maxT; ++t) {
      for (int u = 0; u < maxU; ++u) {
        for (int d = 0; d < D; ++d) {
          gradients({t, u, d}) = 0.;
        }
      }
    }
    for (int t = 0; t < T; ++t) {
      for (int u = U; u < maxU; ++u) {
        for (int d = 0; d < D; ++d) {
          gradients({t, u, d}) = 0.;
        }
      }
    }
  }
}
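Writing out what the branches compute: with $g = \mathrm{logits}(t,u,d) + \alpha(t,u) - \beta(0,0) - \mathrm{denom}(t,u)$, the merged softmax-plus-loss gradient is

$$\frac{\partial \ell}{\partial\, \mathrm{logits}(t,u,d)} = e^{g + \beta(t,u)} - \begin{cases} e^{g} & d = \text{blank},\ t = T{-}1,\ u = U{-}1 \\ e^{g + \beta(t+1,u)} & d = \text{blank},\ t < T{-}1 \\ e^{g + \beta(t,u+1)} & d = \mathrm{targets}(u),\ u < U{-}1 \\ 0 & \text{otherwise,} \end{cases}$$

i.e. the softmax probability reweighted by posterior occupancy, minus the flow through the matching transition (Sec. 3.2, "function merging", of the Microsoft RNNT paper linked in the comment above).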
template <typename DTYPE, typename CAST_DTYPE>
void ComputeGradients(
    const Options& options,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    const CAST_DTYPE* denominators,
    const CAST_DTYPE* alphas,
    const CAST_DTYPE* betas,
    DTYPE* gradients) {
  std::vector<TensorView<const DTYPE>> seqLogits;
  std::vector<const int*> seqTargets;
  std::vector<TensorView<const CAST_DTYPE>> seqDenoms;
  std::vector<TensorView<const CAST_DTYPE>> seq_alphas;
  std::vector<TensorView<const CAST_DTYPE>> seq_betas;
  std::vector<TensorView<DTYPE>> seq_gradients;

  const int& B = options.batchSize_;
  const int& maxT = options.maxSrcLen_;
  const int& maxU = options.maxTgtLen_;
  const int& D = options.numTargets_;

  for (int b = 0; b < B; ++b) {
    seqLogits.push_back(
        TensorView<const DTYPE>({maxT, maxU, D}, logits + b * maxT * maxU * D));
    seqTargets.push_back(targets + b * (maxU - 1));
    seqDenoms.push_back(TensorView<const CAST_DTYPE>(
        {maxT, maxU}, denominators + b * maxT * maxU));
    seq_alphas.push_back(
        TensorView<const CAST_DTYPE>({maxT, maxU}, alphas + b * maxT * maxU));
    seq_betas.push_back(
        TensorView<const CAST_DTYPE>({maxT, maxU}, betas + b * maxT * maxU));
    seq_gradients.push_back(
        TensorView<DTYPE>({maxT, maxU, D}, gradients + b * maxT * maxU * D));
  }

  //#pragma omp parallel for
  for (int b = 0; b < B; ++b) { // use max 2 * B threads.
    ComputeGradientsOneSequence<DTYPE, CAST_DTYPE>(
        /*options=*/options,
        /*logits=*/seqLogits[b],
        /*targets=*/seqTargets[b],
        /*srcLen=*/srcLengths[b],
        /*tgtLen=*/tgtLengths[b] + 1, // with prepended blank.
        /*denom=*/seqDenoms[b],
        /*alpha=*/seq_alphas[b],
        /*beta=*/seq_betas[b],
        /*gradients=*/seq_gradients[b]);
  }
}

template <typename DTYPE, typename CAST_DTYPE>
void ComputeAlphas(
    const Options& options,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    CAST_DTYPE* alphas) {
  std::vector<TensorView<const LogProbs<CAST_DTYPE>>> seqlogProbs;
  std::vector<TensorView<CAST_DTYPE>> seq_alphas;

  const int& B = options.batchSize_;
  const int& maxT = options.maxSrcLen_;
  const int& maxU = options.maxTgtLen_;

  for (int b = 0; b < B; ++b) {
    seqlogProbs.push_back(TensorView<const LogProbs<CAST_DTYPE>>(
        {maxT, maxU},
        reinterpret_cast<LogProbs<CAST_DTYPE>*>(
            const_cast<CAST_DTYPE*>(logProbs)) +
            b * maxT * maxU));
    seq_alphas.push_back(
        TensorView<CAST_DTYPE>({maxT, maxU}, alphas + b * maxT * maxU));
  }

  std::vector<CAST_DTYPE> scores(B << 1);
  //#pragma omp parallel for
  for (int i = 0; i < B; ++i) { // use max 2 * B threads.
    ComputeAlphaOneSequence<DTYPE>(
        options,
        /*logProbs=*/seqlogProbs[i],
        /*srcLen=*/srcLengths[i],
        /*tgtLen=*/tgtLengths[i] + 1, // with prepended blank.
        /*alpha=*/seq_alphas[i]);
  }
}

template <typename DTYPE, typename CAST_DTYPE>
void ComputeBetas(
    const Options& options,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    CAST_DTYPE* costs,
    CAST_DTYPE* betas) {
  std::vector<TensorView<const LogProbs<CAST_DTYPE>>> seqlogProbs;
  std::vector<TensorView<CAST_DTYPE>> seq_betas;

  const int& B = options.batchSize_;
  const int& maxT = options.maxSrcLen_;
  const int& maxU = options.maxTgtLen_;

  for (int b = 0; b < B; ++b) {
    seqlogProbs.push_back(TensorView<const LogProbs<CAST_DTYPE>>(
        {maxT, maxU},
        reinterpret_cast<LogProbs<CAST_DTYPE>*>(
            const_cast<CAST_DTYPE*>(logProbs)) +
            b * maxT * maxU));
    seq_betas.push_back(
        TensorView<CAST_DTYPE>({maxT, maxU}, betas + b * maxT * maxU));
  }

  std::vector<CAST_DTYPE> scores(B << 1);
  //#pragma omp parallel for
  for (int i = 0; i < B; ++i) { // use max 2 * B threads.
    ComputeBetaOneSequence<DTYPE>(
        options,
        /*logProbs=*/seqlogProbs[i],
        /*srcLen=*/srcLengths[i],
        /*tgtLen=*/tgtLengths[i] + 1, // with prepended blank.
        /*betas=*/seq_betas[i]);
  }
}

} // namespace cpu
} // namespace rnnt
} // namespace torchaudio
torchaudio/csrc/rnnt/cpu/cpu_transducer.h (new file, mode 100644)
#pragma once

#include <torchaudio/csrc/rnnt/cpu/cpu_kernels.h>
#include <torchaudio/csrc/rnnt/workspace.h>

namespace torchaudio {
namespace rnnt {
namespace cpu {

// Inputs:
//   workspace: workspace.
//   logits: pointer to (B, maxT, maxU, D) logits.
//   targets: pointer to (B, maxU - 1) targets in the batch.
//   srcLengths: pointer to (B, ) source lengths in the batch.
//   tgtLengths: pointer to (B, ) target lengths in the batch.
//
// Outputs:
//   costs: pointer to (B, ) costs in the batch.
//   gradients: pointer to (B, maxT, maxU, D) gradients in the batch.
template <typename DTYPE, typename CAST_DTYPE>
status_t Compute(
    const Workspace<CAST_DTYPE>& workspace,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    DTYPE* costs,
    DTYPE* gradients = nullptr) {
  const Options& options = workspace.GetOptions();

  CHECK_EQ(options.device_, CPU);

  const int& B = options.batchSize_;
  const int& maxT = options.maxSrcLen_;
  const int& maxU = options.maxTgtLen_;
  const int& D = options.numTargets_;

  { // compute denominators.
    LogSumExp2D<DTYPE, CAST_DTYPE>(
        /*N=*/B * maxT * maxU,
        /*D=*/D,
        /*logits=*/logits,
        /*denominators=*/workspace.GetPointerToDenominators());
  }

  { // compute log prob pairs.
    ComputeLogProbs<DTYPE, CAST_DTYPE>(
        /*options=*/options,
        /*logits=*/logits,
        /*targets=*/targets,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*denominators=*/workspace.GetPointerToDenominators(),
        /*log_probs=*/workspace.GetPointerToLogProbs());
  }

  { // compute alphas and betas.
    ComputeAlphasBetas<DTYPE, CAST_DTYPE>(
        /*options=*/options,
        /*log_probs=*/workspace.GetPointerToLogProbs(),
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*alphas=*/workspace.GetPointerToAlphas(),
        /*betas=*/workspace.GetPointerToBetas(),
        /*costs=*/costs);
  }

  if (gradients != nullptr) {
    ComputeGradients<DTYPE, CAST_DTYPE>(
        /*options=*/options,
        /*logits=*/logits,
        /*targets=*/targets,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*denominators=*/workspace.GetPointerToDenominators(),
        /*alphas=*/workspace.GetPointerToAlphas(),
        /*betas=*/workspace.GetPointerToBetas(),
        /*gradients=*/gradients);
  }

  return SUCCESS;
}

template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeAlphas(
    const Workspace<CAST_DTYPE>& workspace,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    DTYPE* alphas) {
  const Options& options = workspace.GetOptions();

  CHECK_EQ(options.device_, CPU);

  const int& B = options.batchSize_;
  const int& maxT = options.maxSrcLen_;
  const int& maxU = options.maxTgtLen_;
  const int& D = options.numTargets_;

  { // compute denominators.
    LogSumExp2D<DTYPE, CAST_DTYPE>(
        /*N=*/B * maxT * maxU,
        /*D=*/D,
        /*logits=*/logits,
        /*denominators=*/workspace.GetPointerToDenominators());
  }

  { // compute log prob pairs.
    ComputeLogProbs<DTYPE, CAST_DTYPE>(
        /*options=*/options,
        /*logits=*/logits,
        /*targets=*/targets,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*denominators=*/workspace.GetPointerToDenominators(),
        /*log_probs=*/workspace.GetPointerToLogProbs());
  }

  { // compute alphas.
    ComputeAlphas<DTYPE, CAST_DTYPE>(
        /*options=*/options,
        /*log_probs=*/workspace.GetPointerToLogProbs(),
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*alphas=*/alphas);
  }

  return SUCCESS;
}

template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeBetas(
    const Workspace<CAST_DTYPE>& workspace,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    DTYPE* costs,
    DTYPE* betas) {
  const Options& options = workspace.GetOptions();

  CHECK_EQ(options.device_, CPU);

  const int& B = options.batchSize_;
  const int& maxT = options.maxSrcLen_;
  const int& maxU = options.maxTgtLen_;
  const int& D = options.numTargets_;

  { // compute denominators.
    LogSumExp2D<DTYPE, CAST_DTYPE>(
        /*N=*/B * maxT * maxU,
        /*D=*/D,
        /*logits=*/logits,
        /*denominators=*/workspace.GetPointerToDenominators());
  }

  { // compute log prob pairs.
    ComputeLogProbs<DTYPE, CAST_DTYPE>(
        /*options=*/options,
        /*logits=*/logits,
        /*targets=*/targets,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*denominators=*/workspace.GetPointerToDenominators(),
        /*log_probs=*/workspace.GetPointerToLogProbs());
  }

  { // compute betas.
    ComputeBetas<DTYPE, CAST_DTYPE>(
        /*options=*/options,
        /*log_probs=*/workspace.GetPointerToLogProbs(),
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*costs=*/costs,
        /*betas=*/betas);
  }

  return SUCCESS;
}

} // namespace cpu
} // namespace rnnt
} // namespace torchaudio
torchaudio/csrc/rnnt/cpu/kernel_utils.h (new file, mode 100644)
#pragma once

#include <cassert>

#include <torchaudio/csrc/rnnt/cpu/math.h>

namespace torchaudio {
namespace rnnt {

inline HOST_AND_DEVICE bool in_range(
    int start,
    int end, // inclusive
    int val) {
  return start <= val && val <= end;
}

#define LOG_PROBS_SKIP_IDX 0
#define LOG_PROBS_EMIT_IDX 1

struct Indexer2D {
  const int& size2_;

  FORCE_INLINE HOST_AND_DEVICE Indexer2D(const int& size2) : size2_(size2) {}

  FORCE_INLINE HOST_AND_DEVICE int operator()(int index1, int index2) {
    return index1 * size2_ + index2;
  }
};

struct Indexer3D {
  const int& size2_;
  const int& size3_;

  FORCE_INLINE HOST_AND_DEVICE Indexer3D(const int& size2, const int& size3)
      : size2_(size2), size3_(size3) {}

  FORCE_INLINE HOST_AND_DEVICE int operator()(
      int index1,
      int index2,
      int index3) {
    return (index1 * size2_ + index2) * size3_ + index3;
  }
};

struct Indexer4D {
  const int& size2_;
  const int& size3_;
  const int& size4_;

  HOST_AND_DEVICE Indexer4D(
      const int& size2,
      const int& size3,
      const int& size4)
      : size2_(size2), size3_(size3), size4_(size4) {}

  HOST_AND_DEVICE int operator()(
      int index1,
      int index2,
      int index3,
      int index4) {
    return ((index1 * size2_ + index2) * size3_ + index3) * size4_ + index4;
  }
};

} // namespace rnnt
} // namespace torchaudio
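The Indexer structs apply Horner-style flattening: each added dimension multiplies the running index by that dimension's size before adding the new coordinate. A quick standalone check with illustrative sizes:

#include <cassert>

int main() {
  int maxT = 5, maxU = 3;
  int b = 1, t = 2, u = 1;
  // Indexer3D(maxT, maxU)(b, t, u) computes (b * maxT + t) * maxU + u.
  int idx = (b * maxT + t) * maxU + u;
  assert(idx == 22);
  return 0;
}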
torchaudio/csrc/rnnt/cpu/math.h (new file, mode 100644)
#pragma once

#include <torchaudio/csrc/rnnt/macros.h>

namespace torchaudio {
namespace rnnt {
namespace math {

template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE max(DTYPE x, DTYPE y) {
  if (x > y)
    return x;
  else
    return y;
}

template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE min(DTYPE x, DTYPE y) {
  if (x > y)
    return y;
  else
    return x;
}

// log_sum_exp
template <typename DTYPE>
FORCE_INLINE HOST_AND_DEVICE DTYPE lse(DTYPE x, DTYPE y);

template <>
FORCE_INLINE HOST_AND_DEVICE float lse(float x, float y) {
  if (y > x) {
    return y + log1pf(expf(x - y));
  } else {
    return x + log1pf(expf(y - x));
  }
}

} // namespace math
} // namespace rnnt
} // namespace torchaudio
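The float specialization is the stable pairwise form of log-sum-exp:

$$\operatorname{lse}(x, y) = \log(e^x + e^y) = \max(x, y) + \log\!\left(1 + e^{-|x - y|}\right),$$

which log1pf evaluates accurately even when the exponential term is tiny. For example, lse(log 2, log 3) = log 5, and lse(0, -1000) is simply 0 rather than underflowing.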
torchaudio/csrc/rnnt/gpu/compute.cu (new file, mode 100644)
#include <c10/cuda/CUDAStream.h>
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h>

namespace torchaudio {
namespace rnnt {
namespace gpu {

// Entry point into RNNT Loss
std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp) {
  TORCH_CHECK(
      logits.device().type() == targets.device().type(),
      "logits and targets must be on the same device");
  TORCH_CHECK(
      logits.device().type() == logit_lengths.device().type(),
      "logits and logit_lengths must be on the same device");
  TORCH_CHECK(
      logits.device().type() == target_lengths.device().type(),
      "logits and target_lengths must be on the same device");

  TORCH_CHECK(
      logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16,
      "logits must be float32 or float16 (half) type");
  TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type");
  TORCH_CHECK(
      logit_lengths.dtype() == torch::kInt32,
      "logit_lengths must be int32 type");
  TORCH_CHECK(
      target_lengths.dtype() == torch::kInt32,
      "target_lengths must be int32 type");

  TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
  TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
  TORCH_CHECK(
      logit_lengths.is_contiguous(), "logit_lengths must be contiguous");
  TORCH_CHECK(
      target_lengths.is_contiguous(), "target_lengths must be contiguous");

  TORCH_CHECK(
      logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
  TORCH_CHECK(
      targets.dim() == 2, "targets must be 2-D (batch, max target length)");
  TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D");
  TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D");

  TORCH_CHECK(
      logit_lengths.size(0) == logits.size(0),
      "batch dimension mismatch between logits and logit_lengths");
  TORCH_CHECK(
      target_lengths.size(0) == logits.size(0),
      "batch dimension mismatch between logits and target_lengths");
  TORCH_CHECK(
      targets.size(0) == logits.size(0),
      "batch dimension mismatch between logits and targets");

  TORCH_CHECK(
      blank >= 0 && blank < logits.size(-1),
      "blank must be within [0, logits.shape[-1])");

  TORCH_CHECK(
      logits.size(1) == at::max(logit_lengths).item().toInt(),
      "input length mismatch");
  TORCH_CHECK(
      logits.size(2) == at::max(target_lengths).item().toInt() + 1,
      "output length mismatch");
  TORCH_CHECK(
      targets.size(1) == at::max(target_lengths).item().toInt(),
      "target length mismatch");

  Options options;
  options.batchSize_ = logit_lengths.size(0);
  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
  options.maxSrcLen_ = logits.size(1);
  options.maxTgtLen_ = logits.size(2);
  options.numTargets_ = logits.size(3);
  options.blank_ = blank;
  options.clamp_ = clamp;

  CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
  options.stream_ = at::cuda::getCurrentCUDAStream();
  cudaSetDevice(logits.get_device());
  options.device_ = GPU;

  torch::Tensor costs = torch::empty(
      options.batchSize_ * options.nHypos_,
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
  c10::optional<torch::Tensor> gradients = torch::zeros_like(logits);

  torch::Tensor int_workspace = torch::empty(
      IntWorkspace::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Int));

  torch::Tensor float_workspace = torch::empty(
      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Float));

  Workspace<float> workspace(
      /*options=*/options,
      /*dtype_data=*/float_workspace.data_ptr<float>(),
      /*dtype_size=*/float_workspace.numel(),
      /*int_data=*/int_workspace.data_ptr<int>(),
      /*int_size=*/int_workspace.numel());

  switch (logits.scalar_type()) {
    case torch::ScalarType::Float: {
      Compute</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
          /*workspace=*/workspace,
          /*logits=*/logits.data_ptr<float>(),
          /*targets=*/targets.data_ptr<int>(),
          /*logit_lengths=*/logit_lengths.data_ptr<int>(),
          /*target_lengths=*/target_lengths.data_ptr<int>(),
          /*costs=*/costs.data_ptr<float>(),
          /*gradients=*/gradients->data_ptr<float>());
      break;
    }
    case torch::ScalarType::Half: {
      Compute</*DTYPE=*/c10::Half, /*CAST_DTYPE=*/float>(
          /*workspace=*/workspace,
          /*logits=*/logits.data_ptr<c10::Half>(),
          /*targets=*/targets.data_ptr<int>(),
          /*logit_lengths=*/logit_lengths.data_ptr<int>(),
          /*target_lengths=*/target_lengths.data_ptr<int>(),
          /*costs=*/costs.data_ptr<c10::Half>(),
          /*gradients=*/gradients->data_ptr<c10::Half>());
      break;
    }
    default: {
      break;
    }
  };

  return std::make_tuple(costs, gradients);
}

TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
  m.impl("rnnt_loss", &compute);
}

} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
torchaudio/csrc/rnnt/gpu/compute_alphas.cu (new file, mode 100644)
#include <c10/cuda/CUDAStream.h>
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h>

namespace torchaudio {
namespace rnnt {
namespace gpu {

torch::Tensor compute_alphas(
    const torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp) {
  Options options;
  options.batchSize_ = logit_lengths.size(0);
  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
  options.maxSrcLen_ = logits.size(1);
  options.maxTgtLen_ = logits.size(2);
  options.numTargets_ = logits.size(3);
  options.blank_ = blank;
  options.clamp_ = clamp;

  CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
  options.stream_ = at::cuda::getCurrentCUDAStream();
  cudaSetDevice(logits.get_device());
  options.device_ = GPU;

  torch::Tensor alphas = torch::zeros(
      {options.batchSize_ * options.nHypos_,
       options.maxSrcLen_,
       options.maxTgtLen_},
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));

  torch::Tensor int_workspace = torch::empty(
      IntWorkspace::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Int));

  torch::Tensor float_workspace = torch::empty(
      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Float));

  Workspace<float> workspace(
      /*options=*/options,
      /*dtype_data=*/float_workspace.data_ptr<float>(),
      /*dtype_size=*/float_workspace.numel(),
      /*int_data=*/int_workspace.data_ptr<int>(),
      /*int_size=*/int_workspace.numel());

  // Only support float, this is mainly to enable easy
  // unit-testing
  ComputeAlphas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
      /*workspace=*/workspace,
      /*logits=*/logits.data_ptr<float>(),
      /*targets=*/targets.data_ptr<int>(),
      /*logit_lengths=*/logit_lengths.data_ptr<int>(),
      /*target_lengths=*/target_lengths.data_ptr<int>(),
      /*alphas=*/alphas.data_ptr<float>());
  return alphas;
}

TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
  m.impl("rnnt_loss_alphas", &compute_alphas);
}

} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
torchaudio/csrc/rnnt/gpu/compute_betas.cu (new file, mode 100644)
#include <c10/cuda/CUDAStream.h>
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/gpu/gpu_transducer.h>

namespace torchaudio {
namespace rnnt {
namespace gpu {

torch::Tensor compute_betas(
    const torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp) {
  Options options;
  options.batchSize_ = logit_lengths.size(0);
  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
  options.maxSrcLen_ = logits.size(1);
  options.maxTgtLen_ = logits.size(2);
  options.numTargets_ = logits.size(3);
  options.blank_ = blank;
  options.clamp_ = clamp;

  CHECK_EQ(logits.device().type(), torch::DeviceType::CUDA);
  options.stream_ = at::cuda::getCurrentCUDAStream();
  cudaSetDevice(logits.get_device());
  options.device_ = GPU;

  torch::Tensor costs = torch::empty(
      target_lengths.size(0),
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));

  torch::Tensor betas = torch::zeros(
      {options.batchSize_ * options.nHypos_,
       options.maxSrcLen_,
       options.maxTgtLen_},
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));

  torch::Tensor int_workspace = torch::empty(
      IntWorkspace::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Int));

  torch::Tensor float_workspace = torch::empty(
      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Float));

  Workspace<float> workspace(
      /*options=*/options,
      /*dtype_data=*/float_workspace.data_ptr<float>(),
      /*dtype_size=*/float_workspace.numel(),
      /*int_data=*/int_workspace.data_ptr<int>(),
      /*int_size=*/int_workspace.numel());

  // Only support float, this is mainly to enable easy
  // unit-testing
  ComputeBetas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
      /*workspace=*/workspace,
      /*logits=*/logits.data_ptr<float>(),
      /*targets=*/targets.data_ptr<int>(),
      /*logit_lengths=*/logit_lengths.data_ptr<int>(),
      /*target_lengths=*/target_lengths.data_ptr<int>(),
      /*costs=*/costs.data_ptr<float>(),
      /*betas=*/betas.data_ptr<float>());
  return betas;
}

TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
  m.impl("rnnt_loss_betas", &compute_betas);
}

} // namespace gpu
} // namespace rnnt
} // namespace torchaudio
torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh (new file, mode 100644)
#pragma once

#ifdef USE_CUDA

#include <torchaudio/csrc/rnnt/gpu/math.cuh>

namespace torchaudio {
namespace rnnt {

template <int NUM_THREADS, typename DTYPE, typename CAST_DTYPE>
__global__ void ReduceMax2D(
    int dim,
    const DTYPE* inputs, // [N, dim]
    CAST_DTYPE* outputs) {
  __shared__ CAST_DTYPE shared[NUM_THREADS];

  // each thread reduces one matrix row
  int offset = blockIdx.x * dim; // [n, 0]
  CAST_DTYPE val = inputs[offset]; // default = inputs(n, 0)
  for (int d = threadIdx.x; d < dim; d += NUM_THREADS) {
    CAST_DTYPE next = inputs[offset + d];
    if (next > val) {
      val = next;
    }
  }
  shared[threadIdx.x] = val;
  __syncthreads();

  for (int stride = (NUM_THREADS >> 1); stride >= WARP_SIZE; stride >>= 1) {
    if (threadIdx.x < stride && threadIdx.x + stride < dim) {
      if (shared[threadIdx.x + stride] > shared[threadIdx.x]) {
        shared[threadIdx.x] = shared[threadIdx.x + stride];
        val = shared[threadIdx.x];
      }
    }
    __syncthreads();
  }

  CAST_DTYPE shf;
  for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) {
    shf = __shfl_down_sync(0xFFFFFFFF, val, stride);
    if (threadIdx.x < stride && threadIdx.x + stride < dim) {
      if (shf > val) {
        val = shf;
      }
    }
  }

  if (threadIdx.x == 0) {
    outputs[blockIdx.x] = val;
  }
}

template <int NUM_THREADS, typename DTYPE, typename CAST_DTYPE>
__global__ void ReduceLogSumExpGivenMax2D(
    int dim,
    const DTYPE* inputs, // [N, dim]
    CAST_DTYPE* outputs) { // in: max -> out: logsum
  __shared__ CAST_DTYPE shared[NUM_THREADS];

  CAST_DTYPE max = outputs[blockIdx.x];
  CAST_DTYPE val = 0;

  int offset = blockIdx.x * dim;
  for (int d = threadIdx.x; d < dim; d += NUM_THREADS) {
    val = val + std::exp(CAST_DTYPE(inputs[offset + d]) - max);
  }
  shared[threadIdx.x] = val;
  __syncthreads();

  for (int stride = (NUM_THREADS >> 1); stride >= WARP_SIZE; stride >>= 1) {
    if (threadIdx.x < stride && threadIdx.x + stride < dim) {
      val = shared[threadIdx.x] + shared[threadIdx.x + stride];
      shared[threadIdx.x] = val;
    }
    __syncthreads();
  }

  CAST_DTYPE shf;
  for (int stride = (WARP_SIZE >> 1); stride > 0; stride >>= 1) {
    shf = __shfl_down_sync(0xFFFFFFFF, val, stride);
    if (threadIdx.x < stride && threadIdx.x + stride < dim) {
      val = val + shf;
    }
  }

  if (threadIdx.x == 0) {
    outputs[blockIdx.x] = max + std::log(val);
  }
}

} // namespace rnnt
} // namespace torchaudio

#endif // USE_CUDA
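As a reference for what this two-kernel pipeline produces: ReduceMax2D leaves each row's max in outputs, and ReduceLogSumExpGivenMax2D overwrites it with the shifted log-sum-exp. A plain C++ host-side equivalent (illustrative, not part of the commit) that can serve as a correctness oracle for the kernels:

#include <algorithm>
#include <cmath>
#include <vector>

// Per row n of an (N, dim) matrix:
// out[n] = max_d in[n][d] + log(sum_d exp(in[n][d] - max)).
std::vector<float> logsumexp_rows(const std::vector<float>& in, int N, int dim) {
  std::vector<float> out(N);
  for (int n = 0; n < N; ++n) {
    float m = in[n * dim]; // pass 1: row max (what ReduceMax2D computes)
    for (int d = 1; d < dim; ++d) {
      m = std::max(m, in[n * dim + d]);
    }
    float s = 0.f; // pass 2: shifted sum (what ReduceLogSumExpGivenMax2D computes)
    for (int d = 0; d < dim; ++d) {
      s += std::exp(in[n * dim + d] - m);
    }
    out[n] = m + std::log(s);
  }
  return out;
}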
torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh (new file, mode 100644)
#pragma once

#ifdef USE_CUDA

#include <cassert>

#include <torchaudio/csrc/rnnt/gpu/kernel_utils.h>
#include <torchaudio/csrc/rnnt/gpu/kernels.h>
#include <torchaudio/csrc/rnnt/gpu/math.cuh>

namespace torchaudio {
namespace rnnt {

template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeLogProbs(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    const CAST_DTYPE* denominators,
    CAST_DTYPE* logProbs,
    int H = 1) {
  const int& maxT = maxSrcLen;
  const int& maxU = maxTgtLen;
  const int& D = numTargets;

  const int bTgt = blockIdx.z; // 0 <= b < B
  const int bSrc = bTgt / H;
  const int T = srcLengths[bSrc];
  const int U = tgtLengths[bTgt] + 1;

  const int t = blockIdx.x * blockDim.x + threadIdx.x;
  const int u = blockIdx.y;

  if (t >= T || u >= U) { // out of boundary.
    return;
  }

  Indexer3D indexer(maxT, maxU);

  int idx = indexer(bTgt, t, u);

  // skip: log_prob(b, t, u).skip() = logits(b, t, u, blank) - denom(b, t, u).
  logProbs[(idx << 1) + LOG_PROBS_SKIP_IDX] =
      CAST_DTYPE(logits[idx * D + blank]) - denominators[idx];

  if (u < U - 1) {
    // emit: log_prob(b, t, u).emit() = logits(b, t, u, tgt[u]) - denom(b, t,
    // u).
    int target = targets[Indexer2D(maxU - 1)(bTgt, u)];
    logProbs[(idx << 1) + LOG_PROBS_EMIT_IDX] =
        CAST_DTYPE(logits[idx * D + target]) - denominators[idx];
  }
}
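Note the logProbs layout the GPU kernels share: for each flattened lattice cell idx, entry 2*idx + LOG_PROBS_SKIP_IDX holds the blank (skip) log-probability and 2*idx + LOG_PROBS_EMIT_IDX the target (emit) one — the device-side analogue of the CPU LogProbs pair. A standalone check of the index math (illustrative values only):

#include <cassert>

#define LOG_PROBS_SKIP_IDX 0
#define LOG_PROBS_EMIT_IDX 1

int main() {
  int idx = 7; // flattened (bTgt, t, u) from Indexer3D
  assert(((idx << 1) + LOG_PROBS_SKIP_IDX) == 14); // skip slot
  assert(((idx << 1) + LOG_PROBS_EMIT_IDX) == 15); // emit slot
  return 0;
}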
template
<
typename
DTYPE
,
typename
CAST_DTYPE
>
__device__
void
ComputeAlphas
(
int
maxSrcLen
,
int
maxTgtLen
,
int
numTargets
,
int
blank
,
const
CAST_DTYPE
*
logProbs
,
const
int
*
srcLengths
,
const
int
*
tgtLengths
,
int
*
alpha_counters
,
volatile
CAST_DTYPE
*
alphas
,
int
H
=
1
)
{
const
int
&
maxT
=
maxSrcLen
;
const
int
&
maxU
=
maxTgtLen
;
const
int
bTgt
=
blockIdx
.
z
;
// 0 <= b < B
const
int
bSrc
=
bTgt
/
H
;
const
int
T
=
srcLengths
[
bSrc
];
const
int
U
=
tgtLengths
[
bTgt
]
+
1
;
const
int
t
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
+
1
;
const
int
u
=
blockIdx
.
y
+
1
;
if
(
t
>=
T
||
u
>=
U
)
{
// out of boundary.
return
;
}
int
*
counter
=
alpha_counters
+
Indexer2D
(
maxU
)(
bTgt
,
blockIdx
.
y
);
Indexer3D
idxr
(
maxT
,
maxU
);
if
(
t
==
1
&&
u
==
1
)
{
alphas
[
idxr
(
bTgt
,
0
,
0
)]
=
0
;
}
if
(
blockIdx
.
x
>
0
)
{
// wait for previous warp (in t-axis) is ready.
while
(
atomicAdd
(
counter
,
0
)
<
blockIdx
.
x
)
{
}
}
if
(
blockIdx
.
y
>
0
)
{
// wait for previous warp (in u-axis) is ready.
while
(
atomicAdd
(
counter
-
1
,
0
)
<=
blockIdx
.
x
)
{
}
}
if
(
t
==
1
&&
u
<
U
)
{
// alpha(0, u) = alpha(0, u - 1) + logProbs(0, u - 1).emit().
alphas
[
idxr
(
bTgt
,
0
,
u
)]
=
alphas
[
idxr
(
bTgt
,
0
,
u
-
1
)]
+
logProbs
[(
idxr
(
bTgt
,
0
,
u
-
1
)
<<
1
)
+
LOG_PROBS_EMIT_IDX
];
}
if
(
blockIdx
.
y
==
0
&&
t
<
T
)
{
CAST_DTYPE
skip_prob
=
logProbs
[(
idxr
(
bTgt
,
t
-
1
,
0
)
<<
1
)
+
LOG_PROBS_SKIP_IDX
];
CAST_DTYPE
val
;
#pragma unroll
for
(
int
i
=
1
;
i
<
warpSize
;
i
<<=
1
)
{
val
=
__shfl_up_sync
(
0xffffffff
,
skip_prob
,
i
);
if
(
i
<=
threadIdx
.
x
)
{
skip_prob
=
skip_prob
+
val
;
}
}
val
=
alphas
[
idxr
(
bTgt
,
blockIdx
.
x
*
blockDim
.
x
,
0
)];
alphas
[
idxr
(
bTgt
,
t
,
0
)]
=
skip_prob
+
val
;
}
if
(
t
<
T
&&
u
<
U
)
{
CAST_DTYPE
skip_prob
=
logProbs
[(
idxr
(
bTgt
,
t
-
1
,
u
)
<<
1
)
+
LOG_PROBS_SKIP_IDX
];
CAST_DTYPE
emit_prob
=
logProbs
[(
idxr
(
bTgt
,
t
,
u
-
1
)
<<
1
)
+
LOG_PROBS_EMIT_IDX
];
CAST_DTYPE
skip
=
alphas
[
idxr
(
bTgt
,
blockIdx
.
x
*
blockDim
.
x
,
u
)]
+
skip_prob
;
CAST_DTYPE
emit
=
alphas
[
idxr
(
bTgt
,
t
,
u
-
1
)]
+
emit_prob
;
CAST_DTYPE
val
=
math
::
lse
(
skip
,
emit
);
CAST_DTYPE
out
=
val
;
for
(
int
i
=
1
;
i
<
warpSize
;
++
i
)
{
val
=
__shfl_up_sync
(
0xffffffff
,
val
,
1
);
if
(
i
==
threadIdx
.
x
)
{
val
=
math
::
lse
(
val
+
skip_prob
,
emit
);
out
=
val
;
}
}
alphas
[
idxr
(
bTgt
,
t
,
u
)]
=
out
;
}
if
(
threadIdx
.
x
==
0
)
{
__threadfence
();
atomicAdd
(
counter
,
1
);
}
}
template <typename DTYPE, typename CAST_DTYPE>
__device__ void ComputeBetasCosts(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    int* betaCounters,
    volatile CAST_DTYPE* betas,
    DTYPE* costs,
    int H = 1) {
  const int& maxT = maxSrcLen;
  const int& maxU = maxTgtLen;
  const int bTgt = blockIdx.z; // 0 <= b < B
  const int bSrc = bTgt / H;
  const int T = srcLengths[bSrc];
  const int U = tgtLengths[bTgt] + 1;

  const int t = T - 2 - blockIdx.x * blockDim.x - threadIdx.x;
  const int u = U - 2 - blockIdx.y;

  if (t < 0 || u < 0) {
    // out of boundary.
    return;
  }

  int* counter = betaCounters + Indexer2D(maxU)(bTgt, blockIdx.y);
  Indexer3D idxr(maxT, maxU);

  if (t == T - 2 && u == U - 2) {
    betas[idxr(bTgt, T - 1, U - 1)] =
        logProbs[(idxr(bTgt, T - 1, U - 1) << 1) + LOG_PROBS_SKIP_IDX];
  }

  if (blockIdx.x > 0) {
    // wait until the previous warp (in the t-axis) is ready.
    while (atomicAdd(counter, 0) < blockIdx.x) {
    }
  }

  if (blockIdx.y > 0) {
    // wait until the previous warp (in the u-axis) is ready.
    while (atomicAdd(counter - 1, 0) <= blockIdx.x) {
    }
  }

  if (t == T - 2 && u >= 0) {
    betas[idxr(bTgt, T - 1, u)] = betas[idxr(bTgt, T - 1, u + 1)] +
        logProbs[(idxr(bTgt, T - 1, u) << 1) + LOG_PROBS_EMIT_IDX];
  }

  if (blockIdx.y == 0 && t >= 0) {
    CAST_DTYPE skip_prob =
        logProbs[(idxr(bTgt, t, U - 1) << 1) + LOG_PROBS_SKIP_IDX];
    CAST_DTYPE val;

#pragma unroll
    for (int i = 1; i < warpSize; i <<= 1) {
      val = __shfl_up_sync(0xffffffff, skip_prob, i);
      if (i <= threadIdx.x) {
        skip_prob = skip_prob + val;
      }
    }

    betas[idxr(bTgt, t, U - 1)] =
        betas[idxr(bTgt, T - 1 - blockIdx.x * blockDim.x, U - 1)] + skip_prob;
  }

  if (t >= 0 && u >= 0) {
    CAST_DTYPE skip_prob =
        logProbs[(idxr(bTgt, t, u) << 1) + LOG_PROBS_SKIP_IDX];
    CAST_DTYPE emit_prob =
        logProbs[(idxr(bTgt, t, u) << 1) + LOG_PROBS_EMIT_IDX];

    CAST_DTYPE skip = betas[idxr(bTgt, t + threadIdx.x + 1, u)] + skip_prob;
    CAST_DTYPE emit = betas[idxr(bTgt, t, u + 1)] + emit_prob;

    CAST_DTYPE val = math::lse(skip, emit);
    CAST_DTYPE out = val;

    for (int i = 1; i < warpSize; ++i) {
      val = __shfl_up_sync(0xffffffff, val, 1);
      if (i == threadIdx.x) {
        val = math::lse(val + skip_prob, emit);
        out = val;
      }
    }

    betas[idxr(bTgt, t, u)] = out;

    if (t == 0 && u == 0) {
      // use -beta(0, 0) as cost.
      costs[bTgt] = DTYPE(-out);
    }
  }

  if (threadIdx.x == 0) {
    __threadfence();
    atomicAdd(counter, 1);
  }
}
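
// math::lse above is a numerically stable log-add-exp. A sketch of the usual
// formulation (an assumption for exposition, not necessarily the exact
// implementation this codebase ships) computes log(exp(a) + exp(b)) without
// overflow by factoring out the larger argument:
template <typename T>
__host__ __device__ inline T LogSumExpPairSketch(T a, T b) {
  T hi = a > b ? a : b;
  T lo = a > b ? b : a;
  return hi + log1p(exp(lo - hi)); // exp argument is <= 0, so it cannot overflow
}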
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeAlphasBetasCosts(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    int* alpha_counters,
    volatile CAST_DTYPE* alphas,
    int* betaCounters,
    volatile CAST_DTYPE* betas,
    DTYPE* costs,
    int warpSize = 0,
    int numWarps = 0,
    int H = 1) {
  assert(threadIdx.y == 0 || threadIdx.y == 1);

  if (threadIdx.y == 0) {
    ComputeAlphas<DTYPE, CAST_DTYPE>(
        /*maxSrcLen=*/maxSrcLen,
        /*maxTgtLen=*/maxTgtLen,
        /*numTargets=*/numTargets,
        /*blank=*/blank,
        /*logProbs=*/logProbs,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*alpha_counters=*/alpha_counters,
        /*alphas=*/alphas,
        H);
  } else { // threadIdx.y == 1
    ComputeBetasCosts<DTYPE, CAST_DTYPE>(
        /*maxSrcLen=*/maxSrcLen,
        /*maxTgtLen=*/maxTgtLen,
        /*numTargets=*/numTargets,
        /*blank=*/blank,
        /*logProbs=*/logProbs,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*betaCounters=*/betaCounters,
        /*betas=*/betas,
        /*costs=*/costs,
        H);
  }
}
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeGradients(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    CAST_DTYPE clamp,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    const CAST_DTYPE* denominators,
    const CAST_DTYPE* alphas,
    const CAST_DTYPE* betas,
    DTYPE* gradients,
    int H = 1) {
  const int bTgt = blockIdx.z; // 0 <= b < B
  const int t = blockIdx.x * blockDim.x + threadIdx.x;
  const int u = blockIdx.y;

  ComputeGradientsElement(
      bTgt,
      t,
      u,
      maxSrcLen,
      maxTgtLen,
      numTargets,
      blank,
      clamp,
      logits,
      targets,
      srcLengths,
      tgtLengths,
      denominators,
      alphas,
      betas,
      gradients,
      H);
}
// This is a __global__ wrapper around the ComputeAlphas
// device kernel to enable unit testing.
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeAlphasWrapper(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    int* alpha_counters,
    volatile CAST_DTYPE* alphas,
    int H = 1) {
  ComputeAlphas<DTYPE, CAST_DTYPE>(
      maxSrcLen,
      maxTgtLen,
      numTargets,
      blank,
      logProbs,
      srcLengths,
      tgtLengths,
      alpha_counters,
      alphas,
      H);
}
// This is a __global__ wrapper around the ComputeBetas
// device kernel to enable unit testing.
template <typename DTYPE, typename CAST_DTYPE>
__global__ void ComputeBetasWrapper(
    int maxSrcLen,
    int maxTgtLen,
    int numTargets,
    int blank,
    const CAST_DTYPE* logProbs,
    const int* srcLengths,
    const int* tgtLengths,
    int* betaCounters,
    volatile CAST_DTYPE* betas,
    DTYPE* costs,
    int H = 1) {
  ComputeBetasCosts<DTYPE, CAST_DTYPE>(
      maxSrcLen,
      maxTgtLen,
      numTargets,
      blank,
      logProbs,
      srcLengths,
      tgtLengths,
      betaCounters,
      betas,
      costs,
      H);
}
// #undef LOG_PROBS_SKIP_IDX
// #undef LOG_PROBS_EMIT_IDX

} // namespace rnnt
} // namespace torchaudio

#endif // USE_CUDA
torchaudio/csrc/rnnt/gpu/gpu_transducer.h
0 → 100644
View file @
9dcc7a15
#pragma once

#ifdef USE_CUDA

#include <torchaudio/csrc/rnnt/workspace.h>
#include <torchaudio/csrc/rnnt/gpu/gpu_kernel_utils.cuh>
#include <torchaudio/csrc/rnnt/gpu/gpu_kernels.cuh>

namespace torchaudio {
namespace rnnt {
namespace gpu {

#define gpuErrchk(ans) \
  { gpuAssert((ans), __FILE__, __LINE__); }

inline void gpuAssert(
    cudaError_t code,
    const char* file,
    int line,
    bool abort = true) {
  if (code != cudaSuccess) {
    fprintf(
        stderr,
        "\nGPUassert: %s %s %d\n",
        cudaGetErrorString(code),
        file,
        line);
    if (abort)
      exit(code);
  }
}
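
// gpuErrchk is not exercised below (the kernel launches check
// cudaGetLastError directly); a minimal usage sketch, with a hypothetical
// scratch buffer:
inline void GpuErrchkUsageSketch() {
  float* scratch = nullptr; // hypothetical device buffer
  gpuErrchk(cudaMalloc(&scratch, 1024 * sizeof(float)));
  gpuErrchk(cudaFree(scratch));
}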
template <typename DTYPE, typename CAST_DTYPE>
status_t LogSumExp2D(
    cudaStream_t stream,
    int N,
    int D,
    const DTYPE* logits, // [N, D]
    CAST_DTYPE* outputs) {
  { // compute max among D.
    dim3 block_dims(N);
    dim3 thread_dims(REDUCE_THREADS);

    ReduceMax2D<REDUCE_THREADS, DTYPE, CAST_DTYPE>
        <<<block_dims, thread_dims, 0, stream>>>(
            /*dim=*/D,
            /*inputs=*/logits,
            /*outputs=*/outputs);

    // BUGBUG: These error codes are only accurate when launching with
    // blocking. Otherwise they usually reflect earlier errors.
    if (cudaGetLastError() != cudaSuccess) {
      return COMPUTE_DENOMINATOR_REDUCE_MAX_FAILED;
    }
  }

  { // compute log(sum(exp(d_i - max))).
    dim3 block_dims(N);
    dim3 thread_dims(REDUCE_THREADS);

    ReduceLogSumExpGivenMax2D<REDUCE_THREADS, DTYPE, CAST_DTYPE>
        <<<block_dims, thread_dims, 0, stream>>>(
            /*dim=*/D,
            /*inputs=*/logits,
            /*outputs=*/outputs);

    if (cudaGetLastError() != cudaSuccess) {
      return COMPUTE_DENOMINATOR_REDUCE_SUM_FAILED;
    }
  }

  return SUCCESS;
}
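
// A serial reference sketch (a hypothetical helper for exposition, assuming
// <cmath> and <algorithm> are available) of what the two reductions above
// compute per row: outputs[n] = log(sum_d exp(logits[n * D + d])), evaluated
// max-shifted for numerical stability.
template <typename DTYPE, typename CAST_DTYPE>
void LogSumExp2DReferenceSketch(
    int N,
    int D,
    const DTYPE* logits, // [N, D]
    CAST_DTYPE* outputs) { // [N]
  for (int n = 0; n < N; ++n) {
    CAST_DTYPE mx = CAST_DTYPE(logits[n * D]);
    for (int d = 1; d < D; ++d) {
      mx = std::max(mx, CAST_DTYPE(logits[n * D + d]));
    }
    CAST_DTYPE sum = 0;
    for (int d = 0; d < D; ++d) {
      sum += std::exp(CAST_DTYPE(logits[n * D + d]) - mx);
    }
    outputs[n] = mx + std::log(sum);
  }
}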
// Inputs:
//   workspace: workspace.
//   logits: pointer to (B, max_T, max_U, D) logits.
//   targets: pointer to (B, max_U - 1) targets in the batch.
//   srcLengths: pointer to (B, ) source lengths in the batch.
//   tgtLengths: pointer to (B, ) target lengths in the batch.
//
// Outputs:
//   costs: pointer to (B, ) costs in the batch.
//   gradients: pointer to (B, max_T, max_U, D) gradients in the batch.
template <typename DTYPE, typename CAST_DTYPE>
status_t Compute(
    const Workspace<CAST_DTYPE>& workspace,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    DTYPE* costs,
    DTYPE* gradients = nullptr) {
  const Options& options = workspace.GetOptions();

  const cudaStream_t& stream = options.stream_;
  const int& B = options.batchSize_;
  const int& H = options.nHypos_;
  const int& max_T = options.maxSrcLen_;
  const int& max_U = options.maxTgtLen_;
  const int& D = options.numTargets_;
  const int& blank = options.blank_;
  const CAST_DTYPE clamp = options.clamp_;

  { // compute denominators.
    status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
        /*stream=*/stream,
        /*N=*/B * H * max_T * max_U,
        /*D=*/D,
        /*logits=*/logits,
        /*denominators=*/workspace.GetPointerToDenominators());
    if (status != SUCCESS) {
      return status;
    }
  }

  { // compute log probability pairs (blank and target).
    int num_segments =
        (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
    dim3 block_dims(num_segments, max_U, B * H);
    dim3 thread_dims(MAX_THREADS_PER_BLOCK);

    ComputeLogProbs<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
        /*max_src_len=*/max_T,
        /*max_tgt_len=*/max_U,
        /*num_targets=*/D,
        /*blank=*/blank,
        /*logits=*/logits,
        /*targets=*/targets,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*denominators=*/workspace.GetPointerToDenominators(),
        /*log_probs=*/workspace.GetPointerToLogProbs(),
        H);

    if (cudaGetLastError() != cudaSuccess) {
      return COMPUTE_LOG_PROBS_FAILED;
    }
  }

  { // compute alphas, betas and costs.
    // a warp is a group of (usually 32) threads.
    int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE;

    // each block is identified by a 3-d tuple:
    // we are using num_warps * max_U * (B * H) blocks,
    // where num_warps is the division along the time axis.
    dim3 block_dims(num_warps, max_U, B * H);

    // each thread is identified by a 2-d tuple;
    // the 2nd dim is 2: 1 for alphas, 1 for betas.
    dim3 thread_dims(WARP_SIZE, 2);

    ComputeAlphasBetasCosts<DTYPE, CAST_DTYPE>
        <<<block_dims, thread_dims, 0, stream>>>(
            /*max_src_len=*/max_T,
            /*max_tgt_len=*/max_U,
            /*num_targets=*/D,
            /*blank=*/blank,
            /*log_probs=*/workspace.GetPointerToLogProbs(),
            /*srcLengths=*/srcLengths,
            /*tgtLengths=*/tgtLengths,
            /*alpha_counters=*/workspace.GetPointerToAlphaCounters(),
            /*alphas=*/workspace.GetPointerToAlphas(),
            /*beta_counters=*/workspace.GetPointerToBetaCounters(),
            /*betas=*/workspace.GetPointerToBetas(),
            /*costs=*/costs,
            /*warp_size=*/WARP_SIZE,
            /*num_warps=*/num_warps,
            H);

    if (cudaGetLastError() != cudaSuccess) {
      return COMPUTE_ALPHAS_BETAS_COSTS_FAILED;
    }
  }

  if (gradients != nullptr) { // compute gradients.
    // don't zero out gradients here, as gradients might reuse memory from
    // logits.
    int num_blocks =
        (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
    dim3 block_dims(num_blocks, max_U, B * H);
    dim3 thread_dims(MAX_THREADS_PER_BLOCK);

    ComputeGradients<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
        /*max_src_len=*/max_T,
        /*max_tgt_len=*/max_U,
        /*num_targets=*/D,
        /*blank=*/blank,
        /*clamp=*/clamp,
        /*logits=*/logits,
        /*targets=*/targets,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*denominators=*/workspace.GetPointerToDenominators(),
        /*alphas=*/workspace.GetPointerToAlphas(),
        /*betas=*/workspace.GetPointerToBetas(),
        /*gradients=*/gradients,
        H);

    if (cudaGetLastError() != cudaSuccess) {
      return COMPUTE_GRADIENTS_FAILED;
    }
  }

  return SUCCESS;
}
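
// A call-site sketch (a hypothetical wrapper, not part of this header)
// showing how the status_t codes returned by Compute are meant to be
// consumed rather than ignored:
template <typename DTYPE, typename CAST_DTYPE>
bool ComputeOrReport(
    const Workspace<CAST_DTYPE>& workspace,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    DTYPE* costs,
    DTYPE* gradients = nullptr) {
  status_t status = Compute<DTYPE, CAST_DTYPE>(
      workspace, logits, targets, srcLengths, tgtLengths, costs, gradients);
  if (status != SUCCESS) {
    fprintf(
        stderr,
        "rnnt gpu Compute failed with status %d\n",
        static_cast<int>(status));
    return false;
  }
  return true;
}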
template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeAlphas(
    const Workspace<CAST_DTYPE>& workspace,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    DTYPE* alphas) {
  const Options& options = workspace.GetOptions();

  const cudaStream_t& stream = options.stream_;
  const int& B = options.batchSize_;
  const int& H = options.nHypos_;
  const int& max_T = options.maxSrcLen_;
  const int& max_U = options.maxTgtLen_;
  const int& D = options.numTargets_;
  const int& blank = options.blank_;

  { // compute denominators.
    status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
        /*stream=*/stream,
        /*N=*/B * H * max_T * max_U,
        /*D=*/D,
        /*logits=*/logits,
        /*denominators=*/workspace.GetPointerToDenominators());
    if (status != SUCCESS) {
      return status;
    }
  }

  { // compute log probability pairs (blank and target).
    int num_segments =
        (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
    dim3 block_dims(num_segments, max_U, B * H);
    dim3 thread_dims(MAX_THREADS_PER_BLOCK);

    ComputeLogProbs<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
        /*max_src_len=*/max_T,
        /*max_tgt_len=*/max_U,
        /*num_targets=*/D,
        /*blank=*/blank,
        /*logits=*/logits,
        /*targets=*/targets,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*denominators=*/workspace.GetPointerToDenominators(),
        /*log_probs=*/workspace.GetPointerToLogProbs(),
        H);

    if (cudaGetLastError() != cudaSuccess) {
      return COMPUTE_LOG_PROBS_FAILED;
    }
  }

  { // compute alphas.
    // a warp is a group of (usually 32) threads.
    int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE;

    // each block is identified by a 3-d tuple:
    // we are using num_warps * max_U * (B * H) blocks,
    // where num_warps is the division along the time axis.
    dim3 block_dims(num_warps, max_U, B * H);

    // each thread is identified by a 2-d tuple;
    // the 2nd dim is 1, for alphas only.
    dim3 thread_dims(WARP_SIZE, 1);

    ComputeAlphasWrapper<DTYPE, CAST_DTYPE>
        <<<block_dims, thread_dims, 0, stream>>>(
            /*max_src_len=*/max_T,
            /*max_tgt_len=*/max_U,
            /*num_targets=*/D,
            /*blank=*/blank,
            /*log_probs=*/workspace.GetPointerToLogProbs(),
            /*srcLengths=*/srcLengths,
            /*tgtLengths=*/tgtLengths,
            /*alpha_counters=*/workspace.GetPointerToAlphaCounters(),
            /*alphas=*/(volatile DTYPE*)alphas,
            H);

    if (cudaGetLastError() != cudaSuccess) {
      return COMPUTE_ALPHAS_BETAS_COSTS_FAILED;
    }
  }

  return SUCCESS;
}
template <typename DTYPE, typename CAST_DTYPE>
status_t ComputeBetas(
    const Workspace<CAST_DTYPE>& workspace,
    const DTYPE* logits,
    const int* targets,
    const int* srcLengths,
    const int* tgtLengths,
    DTYPE* costs,
    DTYPE* betas) {
  const Options& options = workspace.GetOptions();

  const cudaStream_t& stream = options.stream_;
  const int& B = options.batchSize_;
  const int& H = options.nHypos_;
  const int& max_T = options.maxSrcLen_;
  const int& max_U = options.maxTgtLen_;
  const int& D = options.numTargets_;
  const int& blank = options.blank_;

  { // compute denominators.
    status_t status = LogSumExp2D<DTYPE, CAST_DTYPE>(
        /*stream=*/stream,
        /*N=*/B * H * max_T * max_U,
        /*D=*/D,
        /*logits=*/logits,
        /*denominators=*/workspace.GetPointerToDenominators());
    if (status != SUCCESS) {
      return status;
    }
  }

  { // compute log probability pairs (blank and target).
    int num_segments =
        (max_T + MAX_THREADS_PER_BLOCK - 1) / MAX_THREADS_PER_BLOCK;
    dim3 block_dims(num_segments, max_U, B * H);
    dim3 thread_dims(MAX_THREADS_PER_BLOCK);

    ComputeLogProbs<DTYPE, CAST_DTYPE><<<block_dims, thread_dims, 0, stream>>>(
        /*max_src_len=*/max_T,
        /*max_tgt_len=*/max_U,
        /*num_targets=*/D,
        /*blank=*/blank,
        /*logits=*/logits,
        /*targets=*/targets,
        /*srcLengths=*/srcLengths,
        /*tgtLengths=*/tgtLengths,
        /*denominators=*/workspace.GetPointerToDenominators(),
        /*log_probs=*/workspace.GetPointerToLogProbs(),
        H);

    if (cudaGetLastError() != cudaSuccess) {
      return COMPUTE_LOG_PROBS_FAILED;
    }
  }

  { // compute betas.
    // a warp is a group of (usually 32) threads.
    int num_warps = (max_T + WARP_SIZE - 1) / WARP_SIZE;

    // each block is identified by a 3-d tuple:
    // we are using num_warps * max_U * (B * H) blocks,
    // where num_warps is the division along the time axis.
    dim3 block_dims(num_warps, max_U, B * H);

    // each thread is identified by a 2-d tuple;
    // the 2nd dim is 1, for betas only.
    dim3 thread_dims(WARP_SIZE, 1);

    ComputeBetasWrapper<DTYPE, CAST_DTYPE>
        <<<block_dims, thread_dims, 0, stream>>>(
            /*max_src_len=*/max_T,
            /*max_tgt_len=*/max_U,
            /*num_targets=*/D,
            /*blank=*/blank,
            /*log_probs=*/workspace.GetPointerToLogProbs(),
            /*srcLengths=*/srcLengths,
            /*tgtLengths=*/tgtLengths,
            /*beta_counters=*/workspace.GetPointerToBetaCounters(),
            /*betas=*/(volatile DTYPE*)betas,
            costs,
            H);

    if (cudaGetLastError() != cudaSuccess) {
      return COMPUTE_ALPHAS_BETAS_COSTS_FAILED;
    }
  }

  return SUCCESS;
}
} // namespace gpu
} // namespace rnnt
} // namespace torchaudio

#endif // USE_CUDA