Unverified commit 3ec7a476, authored by Kshiteej K, committed by GitHub
Browse files

[neptune] fix checkpoint bug with relative out_dir (#22102)



* [neptune] fix checkpoint bug with relative out_dir

* update imports

* reformat with black

* check neptune without imports

* fix typing-related issue

* run black on code

* use os.path.sep instead of raw \

* simplify imports and remove type annotation

* make ruff happy

* apply review suggestions

---------
Co-authored-by: Aleksander Wojnarowicz <alwojnarowicz@gmail.com>
parent 19ade242
...@@ -31,6 +31,7 @@ import numpy as np ...@@ -31,6 +31,7 @@ import numpy as np
from . import __version__ as version from . import __version__ as version
from .utils import flatten_dict, is_datasets_available, is_torch_available, logging from .utils import flatten_dict, is_datasets_available, is_torch_available, logging
from .utils.versions import importlib_metadata
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
...@@ -53,9 +54,19 @@ if _has_comet: ...@@ -53,9 +54,19 @@ if _has_comet:
except (ImportError, ValueError): except (ImportError, ValueError):
_has_comet = False _has_comet = False
_has_neptune = importlib.util.find_spec("neptune") is not None _has_neptune = (
importlib.util.find_spec("neptune") is not None or importlib.util.find_spec("neptune-client") is not None
)
if TYPE_CHECKING and _has_neptune: if TYPE_CHECKING and _has_neptune:
from neptune.new.metadata_containers.run import Run try:
_neptune_version = importlib_metadata.version("neptune")
logger.info(f"Neptune version {_neptune_version} available.")
except importlib_metadata.PackageNotFoundError:
try:
_neptune_version = importlib_metadata.version("neptune-client")
logger.info(f"Neptune-client version {_neptune_version} available.")
except importlib_metadata.PackageNotFoundError:
_has_neptune = False
from .trainer_callback import ProgressCallback, TrainerCallback # noqa: E402 from .trainer_callback import ProgressCallback, TrainerCallback # noqa: E402
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402 from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy # noqa: E402
...@@ -1155,7 +1166,7 @@ class NeptuneCallback(TrainerCallback): ...@@ -1155,7 +1166,7 @@ class NeptuneCallback(TrainerCallback):
project: Optional[str] = None, project: Optional[str] = None,
name: Optional[str] = None, name: Optional[str] = None,
base_namespace: str = "finetuning", base_namespace: str = "finetuning",
run: Optional["Run"] = None, run=None,
log_parameters: bool = True, log_parameters: bool = True,
log_checkpoints: Optional[str] = None, log_checkpoints: Optional[str] = None,
**neptune_run_kwargs, **neptune_run_kwargs,
...@@ -1163,15 +1174,15 @@ class NeptuneCallback(TrainerCallback): ...@@ -1163,15 +1174,15 @@ class NeptuneCallback(TrainerCallback):
if not is_neptune_available(): if not is_neptune_available():
raise ValueError( raise ValueError(
"NeptuneCallback requires the Neptune client library to be installed. " "NeptuneCallback requires the Neptune client library to be installed. "
"To install the library, run `pip install neptune-client`." "To install the library, run `pip install neptune`."
) )
from neptune.new.metadata_containers.run import Run
try: try:
from neptune.new.integrations.utils import verify_type from neptune import Run
from neptune.internal.utils import verify_type
except ImportError: except ImportError:
from neptune.new.internal.utils import verify_type from neptune.new.internal.utils import verify_type
from neptune.new.metadata_containers.run import Run
verify_type("api_token", api_token, (str, type(None))) verify_type("api_token", api_token, (str, type(None)))
verify_type("project", project, (str, type(None))) verify_type("project", project, (str, type(None)))
...@@ -1288,7 +1299,10 @@ class NeptuneCallback(TrainerCallback): ...@@ -1288,7 +1299,10 @@ class NeptuneCallback(TrainerCallback):
if self._volatile_checkpoints_dir is not None: if self._volatile_checkpoints_dir is not None:
consistent_checkpoint_path = os.path.join(self._volatile_checkpoints_dir, checkpoint) consistent_checkpoint_path = os.path.join(self._volatile_checkpoints_dir, checkpoint)
try: try:
shutil.copytree(relative_path, os.path.join(consistent_checkpoint_path, relative_path)) # Remove leading ../ from a relative path.
cpkt_path = relative_path.replace("..", "").lstrip(os.path.sep)
copy_path = os.path.join(consistent_checkpoint_path, cpkt_path)
shutil.copytree(relative_path, copy_path)
target_path = consistent_checkpoint_path target_path = consistent_checkpoint_path
except IOError as e: except IOError as e:
logger.warning( logger.warning(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment