Commit 8682b644 authored by Caroline Chen, committed by Facebook GitHub Bot

Replace bg_iterator in examples (#2645)

Summary:
`bg_iterator` was deprecated in 0.11 because it was known to cause issues (deadlocks) without providing a speedup. This change removes the remaining uses of `bg_iterator` from the torchaudio examples.
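For context, a minimal sketch of the pattern the examples use after this change (the toy dataset, tensor shapes, and batch size below are made up for illustration and are not part of this commit): instead of wrapping the loader in `bg_iterator`, iterate over the `DataLoader` directly and let its own `num_workers` and `pin_memory` options provide background loading and asynchronous copies.

# Illustrative sketch only; the dataset and shapes are hypothetical.
import torch
from torch.utils.data import DataLoader, TensorDataset

if __name__ == "__main__":
    dataset = TensorDataset(torch.randn(64, 16000), torch.randint(0, 10, (64,)))
    data_loader = DataLoader(
        dataset,
        batch_size=8,
        shuffle=True,
        num_workers=2,    # DataLoader worker processes handle background loading
        pin_memory=True,  # enables non_blocking host-to-device copies below
    )
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Previously: for inputs, targets in bg_iterator(data_loader, maxsize=2):
    for inputs, targets in data_loader:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        # ... forward pass, loss computation, optimizer step ...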

Resolves https://github.com/pytorch/audio/issues/2642

Pull Request resolved: https://github.com/pytorch/audio/pull/2645

Reviewed By: nateanl

Differential Revision: D38954292

Pulled By: carolineechen

fbshipit-source-id: 2333ab5228c2b8511ff532057543aaf9d02b2789
parent c7e0595b
@@ -13,7 +13,6 @@ from languagemodels import LanguageModel
 from torch.optim import Adadelta, Adam, AdamW, SGD
 from torch.optim.lr_scheduler import ExponentialLR, ReduceLROnPlateau
 from torch.utils.data import DataLoader
-from torchaudio.datasets.utils import bg_iterator
 from torchaudio.functional import edit_distance
 from torchaudio.models.wav2letter import Wav2Letter
 from transforms import Normalize, UnsqueezeFirst
@@ -246,7 +245,7 @@ def train_one_epoch(
     metric = MetricLogger("train", disable=disable_logger)
     metric["epoch"] = epoch
-    for inputs, targets, tensors_lengths, target_lengths in bg_iterator(data_loader, maxsize=2):
+    for inputs, targets, tensors_lengths, target_lengths in data_loader:
         start = time()
         inputs = inputs.to(device, non_blocking=True)
@@ -314,7 +313,7 @@ def evaluate(
     metric = MetricLogger("validation", disable=disable_logger)
     metric["epoch"] = epoch
-    for inputs, targets, tensors_lengths, target_lengths in bg_iterator(data_loader, maxsize=2):
+    for inputs, targets, tensors_lengths, target_lengths in data_loader:
         inputs = inputs.to(device, non_blocking=True)
         targets = targets.to(device, non_blocking=True)
@@ -13,7 +13,6 @@ from losses import LongCrossEntropyLoss, MoLLoss
 from processing import NormalizeDB
 from torch.optim import Adam
 from torch.utils.data import DataLoader
-from torchaudio.datasets.utils import bg_iterator
 from torchaudio.models.wavernn import WaveRNN
 from utils import count_parameters, MetricLogger, save_checkpoint
@@ -209,7 +208,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch):
     metric = MetricLogger("train_iteration")
     metric["epoch"] = epoch
-    for waveform, specgram, target in bg_iterator(data_loader, maxsize=2):
+    for waveform, specgram, target in data_loader:
         start2 = time()
@@ -258,7 +257,7 @@ def validate(model, criterion, data_loader, device, epoch):
     sums = defaultdict(lambda: 0.0)
     start = time()
-    for waveform, specgram, target in bg_iterator(data_loader, maxsize=2):
+    for waveform, specgram, target in data_loader:
         waveform = waveform.to(device)
         specgram = specgram.to(device)