Unverified commit 9996558b, authored by Volodymyr Byno and committed by GitHub
Browse files

Neptune.ai integration (#11937)

Adds an option that turns on neptune.ai logging:
--report_to 'neptune'

Additional ENV variables:
	NEPTUNE_PROJECT
	NEPTUNE_API_TOKEN
	NEPTUNE_CONNECTION_MODE (optional)
	NEPTUNE_RUN_NAME (optional)
	NEPTUNE_STOP_TIMEOUT (optional)
parent ae6ce28f
......@@ -105,6 +105,10 @@ def is_deepspeed_available():
return importlib.util.find_spec("deepspeed") is not None
def is_neptune_available():
    """Return ``True`` when an importable ``neptune`` package can be located, ``False`` otherwise."""
    neptune_spec = importlib.util.find_spec("neptune")
    return neptune_spec is not None
def hp_params(trial):
if is_optuna_available():
import optuna
......@@ -921,10 +925,80 @@ class MLflowCallback(TrainerCallback):
self._ml_flow.end_run()
class NeptuneCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that sends the logs to `Neptune <https://neptune.ai>`__.

    The Neptune run is created lazily by :meth:`setup` (on the first ``on_train_begin`` or ``on_log`` event) and
    only on the world-zero process. The run is stopped when the callback object is garbage collected.
    """

    def __init__(self):
        # NOTE(review): ``assert`` is stripped under ``python -O``; in that case the import below fails with an
        # ImportError instead. Kept as an assert so the exception type seen by existing callers is unchanged.
        assert (
            is_neptune_available()
        ), "NeptuneCallback requires neptune-client to be installed. Run `pip install neptune-client`."
        import neptune.new as neptune

        self._neptune = neptune
        # No run exists until `setup` creates one on the world-zero process.
        self._neptune_run = None
        self._initialized = False
        self._log_artifacts = False

    def setup(self, args, state, model):
        """
        Setup the Neptune integration.

        On the world-zero process this creates the Neptune run and stores the training arguments, merged with the
        model configuration when the model has one, under the ``parameters`` field of the run. Every process is
        marked as initialized afterwards so that setup is attempted only once.

        Environment:
            NEPTUNE_PROJECT (:obj:`str`, `required`):
                The project ID for neptune.ai account. Should be in format `workspace_name/project_name`
            NEPTUNE_API_TOKEN (:obj:`str`, `required`):
                API-token for neptune.ai account
            NEPTUNE_CONNECTION_MODE (:obj:`str`, `optional`):
                Neptune connection mode. `async` by default
            NEPTUNE_RUN_NAME (:obj:`str`, `optional`):
                The name of run process on Neptune dashboard
        """
        if state.is_world_process_zero:
            self._neptune_run = self._neptune.init(
                project=os.getenv("NEPTUNE_PROJECT"),
                api_token=os.getenv("NEPTUNE_API_TOKEN"),
                mode=os.getenv("NEPTUNE_CONNECTION_MODE", "async"),
                name=os.getenv("NEPTUNE_RUN_NAME"),
            )
            combined_dict = args.to_dict()
            model_config = getattr(model, "config", None)
            if model_config is not None:
                # Training arguments win over model config entries that share a key.
                combined_dict = {**model_config.to_dict(), **combined_dict}
            self._neptune_run["parameters"] = combined_dict
        self._initialized = True

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """Make sure the integration has been set up before training starts."""
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, logs, model=None, **kwargs):
        """Send every entry of ``logs`` to the Neptune run, indexed by ``state.global_step``."""
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            run = self._neptune_run
            for key, value in logs.items():
                run[key].log(value, step=state.global_step)

    def __del__(self):
        """
        Stop the Neptune run, if one was created.

        Environment:
            NEPTUNE_STOP_TIMEOUT (:obj:`int`, `optional`):
                Number of seconds to wait for all Neptune.ai tracking calls to finish, before stopping the tracked
                run. If not set it will wait for all tracking calls to finish.
        """
        # `getattr` also covers the case where `__init__` failed before the attribute was assigned.
        run = getattr(self, "_neptune_run", None)
        if run is None:
            return
        raw_timeout = os.getenv("NEPTUNE_STOP_TIMEOUT")
        try:
            stop_timeout = int(raw_timeout) if raw_timeout else None
        except ValueError:
            # A malformed timeout must not prevent the run from being stopped; fall back to the documented
            # default of waiting for all tracking calls to finish.
            stop_timeout = None
        run.stop(seconds=stop_timeout)
# Maps an integration name (e.g. the value given to ``--report_to``, such as ``'neptune'``) to the callback class
# that implements that integration.
INTEGRATION_TO_CALLBACK = {
"azure_ml": AzureMLCallback,
"comet_ml": CometCallback,
"mlflow": MLflowCallback,
"neptune": NeptuneCallback,
"tensorboard": TensorBoardCallback,
"wandb": WandbCallback,
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment