Commit 8682b644 authored by Caroline Chen's avatar Caroline Chen Committed by Facebook GitHub Bot
Browse files

Replace bg_iterator in examples (#2645)

Summary:
`bg_iterator` was deprecated in 0.11 because it was known to have issues (deadlock) without speed up. Remove instances of `bg_iterator` used in torchaudio examples.

Resolves https://github.com/pytorch/audio/issues/2642

Pull Request resolved: https://github.com/pytorch/audio/pull/2645

Reviewed By: nateanl

Differential Revision: D38954292

Pulled By: carolineechen

fbshipit-source-id: 2333ab5228c2b8511ff532057543aaf9d02b2789
parent c7e0595b
...@@ -13,7 +13,6 @@ from languagemodels import LanguageModel ...@@ -13,7 +13,6 @@ from languagemodels import LanguageModel
from torch.optim import Adadelta, Adam, AdamW, SGD from torch.optim import Adadelta, Adam, AdamW, SGD
from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchaudio.datasets.utils import bg_iterator
from torchaudio.functional import edit_distance from torchaudio.functional import edit_distance
from torchaudio.models.wav2letter import Wav2Letter from torchaudio.models.wav2letter import Wav2Letter
from transforms import Normalize, UnsqueezeFirst from transforms import Normalize, UnsqueezeFirst
...@@ -246,7 +245,7 @@ def train_one_epoch( ...@@ -246,7 +245,7 @@ def train_one_epoch(
metric = MetricLogger("train", disable=disable_logger) metric = MetricLogger("train", disable=disable_logger)
metric["epoch"] = epoch metric["epoch"] = epoch
for inputs, targets, tensors_lengths, target_lengths in bg_iterator(data_loader, maxsize=2): for inputs, targets, tensors_lengths, target_lengths in data_loader:
start = time() start = time()
inputs = inputs.to(device, non_blocking=True) inputs = inputs.to(device, non_blocking=True)
...@@ -314,7 +313,7 @@ def evaluate( ...@@ -314,7 +313,7 @@ def evaluate(
metric = MetricLogger("validation", disable=disable_logger) metric = MetricLogger("validation", disable=disable_logger)
metric["epoch"] = epoch metric["epoch"] = epoch
for inputs, targets, tensors_lengths, target_lengths in bg_iterator(data_loader, maxsize=2): for inputs, targets, tensors_lengths, target_lengths in data_loader:
inputs = inputs.to(device, non_blocking=True) inputs = inputs.to(device, non_blocking=True)
targets = targets.to(device, non_blocking=True) targets = targets.to(device, non_blocking=True)
......
...@@ -13,7 +13,6 @@ from losses import LongCrossEntropyLoss, MoLLoss ...@@ -13,7 +13,6 @@ from losses import LongCrossEntropyLoss, MoLLoss
from processing import NormalizeDB from processing import NormalizeDB
from torch.optim import Adam from torch.optim import Adam
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torchaudio.datasets.utils import bg_iterator
from torchaudio.models.wavernn import WaveRNN from torchaudio.models.wavernn import WaveRNN
from utils import count_parameters, MetricLogger, save_checkpoint from utils import count_parameters, MetricLogger, save_checkpoint
...@@ -209,7 +208,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch): ...@@ -209,7 +208,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch):
metric = MetricLogger("train_iteration") metric = MetricLogger("train_iteration")
metric["epoch"] = epoch metric["epoch"] = epoch
for waveform, specgram, target in bg_iterator(data_loader, maxsize=2): for waveform, specgram, target in data_loader:
start2 = time() start2 = time()
...@@ -258,7 +257,7 @@ def validate(model, criterion, data_loader, device, epoch): ...@@ -258,7 +257,7 @@ def validate(model, criterion, data_loader, device, epoch):
sums = defaultdict(lambda: 0.0) sums = defaultdict(lambda: 0.0)
start = time() start = time()
for waveform, specgram, target in bg_iterator(data_loader, maxsize=2): for waveform, specgram, target in data_loader:
waveform = waveform.to(device) waveform = waveform.to(device)
specgram = specgram.to(device) specgram = specgram.to(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment