Unverified Commit 4da84008 authored by Zach Mueller's avatar Zach Mueller Committed by GitHub
Browse files

Finish dataloader integration (#24201)

parent 0675600a
......@@ -176,7 +176,6 @@ if is_datasets_available():
if is_torch_tpu_available(check_device=False):
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
if is_fairscale_available():
dep_version_check("fairscale")
......@@ -1762,10 +1761,6 @@ class Trainer:
total_batched_samples = 0
for epoch in range(epochs_trained, num_train_epochs):
if is_torch_tpu_available():
parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device)
epoch_iterator = parallel_loader
else:
epoch_iterator = train_dataloader
# Reset the past mems state at the beginning of each epoch if necessary.
......@@ -3088,9 +3083,6 @@ class Trainer:
# Do this before wrapping.
eval_dataset = getattr(dataloader, "dataset", None)
if is_torch_tpu_available():
dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
if args.past_index >= 0:
self._past = None
......@@ -3694,9 +3686,6 @@ class Trainer:
model.eval()
if is_torch_tpu_available():
dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
if args.past_index >= 0:
self._past = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment