"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "14fc1a246761c563b7cfc984b87356b14df1e8c5"
Unverified Commit 2890116a authored by Matthew Hoffman's avatar Matthew Hoffman Committed by GitHub
Browse files

Fix TrainingArguments regression with torch <2.0.0 for dataloader_prefetch_factor (#29447)

* Fix TrainingArguments regression with torch <2.0.0 for dataloader_prefetch_factor

dataloader_prefetch_factor was added to TrainingArguments in #28498 with the default value None, but versions of torch < 2.0.0 do not accept None: they raise an error if num_workers == 0 and prefetch_factor != 2

* Add is_torch_available() check

* Use is_torch_greater_or_equal_than_2_0

add back check for dataloader_prefetch_factor
parent b27aa206
...@@ -66,6 +66,8 @@ if is_torch_available(): ...@@ -66,6 +66,8 @@ if is_torch_available():
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from .pytorch_utils import is_torch_greater_or_equal_than_2_0
if is_accelerate_available(): if is_accelerate_available():
from accelerate.state import AcceleratorState, PartialState from accelerate.state import AcceleratorState, PartialState
from accelerate.utils import DistributedType from accelerate.utils import DistributedType
...@@ -1023,13 +1025,13 @@ class TrainingArguments: ...@@ -1023,13 +1025,13 @@ class TrainingArguments:
) )
}, },
) )
dataloader_prefetch_factor: int = field( dataloader_prefetch_factor: Optional[int] = field(
default=None, default=None if not is_torch_available() or is_torch_greater_or_equal_than_2_0 else 2,
metadata={ metadata={
"help": ( "help": (
"Number of batches loaded in advance by each worker. " "Number of batches loaded in advance by each worker. "
"2 means there will be a total of 2 * num_workers batches prefetched across all workers. " "2 means there will be a total of 2 * num_workers batches prefetched across all workers. "
"Default is unset" "Default is 2 for PyTorch < 2.0.0 and otherwise None."
) )
}, },
) )
...@@ -1807,7 +1809,11 @@ class TrainingArguments: ...@@ -1807,7 +1809,11 @@ class TrainingArguments:
if self.use_cpu: if self.use_cpu:
self.dataloader_pin_memory = False self.dataloader_pin_memory = False
if self.dataloader_num_workers == 0 and self.dataloader_prefetch_factor is not None: if (
(not is_torch_available() or is_torch_greater_or_equal_than_2_0)
and self.dataloader_num_workers == 0
and self.dataloader_prefetch_factor is not None
):
raise ValueError( raise ValueError(
"--dataloader_prefetch_factor can only be set when data is loaded in a different process, i.e." "--dataloader_prefetch_factor can only be set when data is loaded in a different process, i.e."
" when --dataloader_num_workers > 1." " when --dataloader_num_workers > 1."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment