Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
78d41d57
Unverified
Commit
78d41d57
authored
Sep 25, 2021
by
moto
Committed by
GitHub
Sep 25, 2021
Browse files
[doc] Fix return type of wav2vec2 model (#1790)
parent
b2e9f1e4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
12 deletions
+12
-12
torchaudio/models/wav2vec2/model.py
torchaudio/models/wav2vec2/model.py
+12
-12
No files found.
torchaudio/models/wav2vec2/model.py
View file @
78d41d57
...
@@ -59,13 +59,13 @@ class Wav2Vec2Model(Module):
...
@@ -59,13 +59,13 @@ class Wav2Vec2Model(Module):
intermediate layers are returned.
intermediate layers are returned.
Returns:
Returns:
List of Tensor:
List of
Tensors and an optional
Tensor:
Features from corresponding laye
rs
.
List of Tenso
rs
Shape: ``(batch, frames, feature dimention)``
Features from requested layers.
Tensor, op
tion
al:
Each Tensor is of shape: ``(batch, frames, feature dimen
tion
)``
Indicates the valid length of each feature in the batch, computed
Tensor or None
based on the given
``lengths`` argument
.
If
``lengths`` argument
was provided, a Tensor of shape ``(batch, )``
Shape: ``(batch, )``
.
is retuned. It indicates the valid length of each feature in the batch
.
"""
"""
x
,
lengths
=
self
.
feature_extractor
(
waveforms
,
lengths
)
x
,
lengths
=
self
.
feature_extractor
(
waveforms
,
lengths
)
x
=
self
.
encoder
.
extract_features
(
x
,
lengths
,
num_layers
)
x
=
self
.
encoder
.
extract_features
(
x
,
lengths
,
num_layers
)
...
@@ -85,13 +85,13 @@ class Wav2Vec2Model(Module):
...
@@ -85,13 +85,13 @@ class Wav2Vec2Model(Module):
Shape: ``(batch, )``.
Shape: ``(batch, )``.
Returns:
Returns:
Tensor:
Tensor and an optional Tensor:
Tensor
The sequences of probability distribution (in logit) over labels.
The sequences of probability distribution (in logit) over labels.
Shape: ``(batch, frames, num labels)``.
Shape: ``(batch, frames, num labels)``.
Tensor, optional:
Tensor or None
Indicates the valid length of each feature in the batch, computed
If ``lengths`` argument was provided, a Tensor of shape ``(batch, )``
based on the given ``lengths`` argument.
is retuned. It indicates the valid length of each feature in the batch.
Shape: ``(batch, )``.
"""
"""
x
,
lengths
=
self
.
feature_extractor
(
waveforms
,
lengths
)
x
,
lengths
=
self
.
feature_extractor
(
waveforms
,
lengths
)
x
=
self
.
encoder
(
x
,
lengths
)
x
=
self
.
encoder
(
x
,
lengths
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment