Unverified Commit 4da84008 authored by Zach Mueller's avatar Zach Mueller Committed by GitHub
Browse files

Finish dataloader integration (#24201)

parent 0675600a
@@ -176,7 +176,6 @@ if is_datasets_available():
 if is_torch_tpu_available(check_device=False):
     import torch_xla.core.xla_model as xm
     import torch_xla.debug.metrics as met
-    import torch_xla.distributed.parallel_loader as pl
 if is_fairscale_available():
     dep_version_check("fairscale")
@@ -1762,11 +1761,7 @@ class Trainer:
         total_batched_samples = 0
         for epoch in range(epochs_trained, num_train_epochs):
-            if is_torch_tpu_available():
-                parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device)
-                epoch_iterator = parallel_loader
-            else:
-                epoch_iterator = train_dataloader
+            epoch_iterator = train_dataloader

             # Reset the past mems state at the beginning of each epoch if necessary.
             if args.past_index >= 0:
@@ -3088,9 +3083,6 @@ class Trainer:
         # Do this before wrapping.
         eval_dataset = getattr(dataloader, "dataset", None)

-        if is_torch_tpu_available():
-            dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)

         if args.past_index >= 0:
             self._past = None
@@ -3694,9 +3686,6 @@ class Trainer:
         model.eval()

-        if is_torch_tpu_available():
-            dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)

         if args.past_index >= 0:
             self._past = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment