Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
hehl2
Torchaudio
Commits
bf580c75
"torchvision/vscode:/vscode.git/clone" did not exist on "d88d8961ae51507d0cb680329d985b1488b1b76b"
Unverified
Commit
bf580c75
authored
Oct 13, 2021
by
nateanl
Committed by
GitHub
Oct 13, 2021
Browse files
Refactor transforms.Fade on GPU computation (#1871)
parent
25a8adf6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
8 deletions
+11
-8
torchaudio/transforms.py
torchaudio/transforms.py
+11
-8
No files found.
torchaudio/transforms.py
View file @
bf580c75
...
@@ -1025,12 +1025,15 @@ class Fade(torch.nn.Module):
...
@@ -1025,12 +1025,15 @@ class Fade(torch.nn.Module):
"""
"""
waveform_length
=
waveform
.
size
()[
-
1
]
waveform_length
=
waveform
.
size
()[
-
1
]
device
=
waveform
.
device
device
=
waveform
.
device
return
self
.
_fade_in
(
waveform_length
).
to
(
device
)
*
\
return
(
self
.
_fade_out
(
waveform_length
).
to
(
device
)
*
waveform
self
.
_fade_in
(
waveform_length
,
device
)
*
self
.
_fade_out
(
waveform_length
,
device
)
*
waveform
)
def
_fade_in
(
self
,
waveform_length
:
int
)
->
Tensor
:
def
_fade_in
(
self
,
waveform_length
:
int
,
device
:
torch
.
device
)
->
Tensor
:
fade
=
torch
.
linspace
(
0
,
1
,
self
.
fade_in_len
)
fade
=
torch
.
linspace
(
0
,
1
,
self
.
fade_in_len
,
device
=
device
)
ones
=
torch
.
ones
(
waveform_length
-
self
.
fade_in_len
)
ones
=
torch
.
ones
(
waveform_length
-
self
.
fade_in_len
,
device
=
device
)
if
self
.
fade_shape
==
"linear"
:
if
self
.
fade_shape
==
"linear"
:
fade
=
fade
fade
=
fade
...
@@ -1049,9 +1052,9 @@ class Fade(torch.nn.Module):
...
@@ -1049,9 +1052,9 @@ class Fade(torch.nn.Module):
return
torch
.
cat
((
fade
,
ones
)).
clamp_
(
0
,
1
)
return
torch
.
cat
((
fade
,
ones
)).
clamp_
(
0
,
1
)
def
_fade_out
(
self
,
waveform_length
:
int
)
->
Tensor
:
def
_fade_out
(
self
,
waveform_length
:
int
,
device
:
torch
.
device
)
->
Tensor
:
fade
=
torch
.
linspace
(
0
,
1
,
self
.
fade_out_len
)
fade
=
torch
.
linspace
(
0
,
1
,
self
.
fade_out_len
,
device
=
device
)
ones
=
torch
.
ones
(
waveform_length
-
self
.
fade_out_len
)
ones
=
torch
.
ones
(
waveform_length
-
self
.
fade_out_len
,
device
=
device
)
if
self
.
fade_shape
==
"linear"
:
if
self
.
fade_shape
==
"linear"
:
fade
=
-
fade
+
1
fade
=
-
fade
+
1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment