Unverified commit 3e93dd29, authored by Yih-Dar, committed via GitHub
Browse files

Skip `TrainerIntegrationFSDP::test_basic_run_with_cpu_offload` if `torch < 2.1` (#26764)



* fix

* fix

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent 883ed4b3
......@@ -28,6 +28,7 @@ logger = logging.get_logger(__name__)
# Installed torch release with any local/dev build suffix stripped
# (e.g. "2.1.0.dev20230907+cu118" -> "2.1.0"), so it compares cleanly
# against plain release version strings.
parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)


def _is_at_least(minimum):
    """Return True if the installed torch base version is >= ``minimum`` (a version string)."""
    return parsed_torch_version_base >= version.parse(minimum)


# Version gates consumed elsewhere in the library to enable/skip
# version-dependent code paths.
is_torch_greater_or_equal_than_2_1 = _is_at_least("2.1")
is_torch_greater_or_equal_than_2_0 = _is_at_least("2.0")
is_torch_greater_or_equal_than_1_12 = _is_at_least("1.12")
is_torch_greater_or_equal_than_1_11 = _is_at_least("1.11")
......
......@@ -14,6 +14,7 @@
import itertools
import os
import unittest
from functools import partial
from parameterized import parameterized
......@@ -37,6 +38,11 @@ from transformers.trainer_utils import FSDPOption, set_seed
from transformers.utils import is_accelerate_available, is_torch_bf16_gpu_available
if not is_torch_available():
    # Without torch installed the version gate can never be satisfied,
    # so the flag defaults to False and the gated tests are skipped.
    is_torch_greater_or_equal_than_2_1 = False
else:
    from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_1
# Default rendezvous port used for the torch.distributed master in these tests.
DEFAULT_MASTER_PORT = "10999"
# Precision modes the FSDP tests are parameterized over (see @parameterized.expand).
dtypes = ["fp16"]
......@@ -178,6 +184,7 @@ class TrainerIntegrationFSDP(TestCasePlus, TrainerIntegrationCommon):
@parameterized.expand(dtypes)
@require_torch_multi_gpu
@slow
@unittest.skipIf(not is_torch_greater_or_equal_than_2_1, reason="This test on pytorch 2.0 takes 4 hours.")
def test_basic_run_with_cpu_offload(self, dtype):
launcher = get_launcher(distributed=True, use_accelerate=False)
output_dir = self.get_auto_remove_tmp_dir()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment