Unverified Commit 4469010c authored by David del Río Medina's avatar David del Río Medina Committed by GitHub
Browse files

Replace assertions with RuntimeError exceptions (#14186)

parent ba71f1b5
......@@ -382,9 +382,10 @@ class TensorBoardCallback(TrainerCallback):
def __init__(self, tb_writer=None):
has_tensorboard = is_tensorboard_available()
assert (
has_tensorboard
), "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
if not has_tensorboard:
raise RuntimeError(
"TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
)
if has_tensorboard:
try:
from torch.utils.tensorboard import SummaryWriter # noqa: F401
......@@ -465,7 +466,8 @@ class WandbCallback(TrainerCallback):
def __init__(self):
has_wandb = is_wandb_available()
assert has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`."
if not has_wandb:
raise RuntimeError("WandbCallback requires wandb to be installed. Run `pip install wandb`.")
if has_wandb:
import wandb
......@@ -587,7 +589,8 @@ class CometCallback(TrainerCallback):
"""
def __init__(self):
assert _has_comet, "CometCallback requires comet-ml to be installed. Run `pip install comet-ml`."
if not _has_comet:
raise RuntimeError("CometCallback requires comet-ml to be installed. Run `pip install comet-ml`.")
self._initialized = False
def setup(self, args, state, model):
......@@ -643,9 +646,8 @@ class AzureMLCallback(TrainerCallback):
"""
def __init__(self, azureml_run=None):
assert (
is_azureml_available()
), "AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`."
if not is_azureml_available():
raise RuntimeError("AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`.")
self.azureml_run = azureml_run
def on_init_end(self, args, state, control, **kwargs):
......@@ -667,7 +669,8 @@ class MLflowCallback(TrainerCallback):
"""
def __init__(self):
assert is_mlflow_available(), "MLflowCallback requires mlflow to be installed. Run `pip install mlflow`."
if not is_mlflow_available():
raise RuntimeError("MLflowCallback requires mlflow to be installed. Run `pip install mlflow`.")
import mlflow
self._MAX_PARAM_VAL_LENGTH = mlflow.utils.validation.MAX_PARAM_VAL_LENGTH
......@@ -753,9 +756,10 @@ class NeptuneCallback(TrainerCallback):
"""
def __init__(self):
assert (
is_neptune_available()
), "NeptuneCallback requires neptune-client to be installed. Run `pip install neptune-client`."
if not is_neptune_available():
raise ValueError(
"NeptuneCallback requires neptune-client to be installed. Run `pip install neptune-client`."
)
import neptune.new as neptune
self._neptune = neptune
......@@ -823,9 +827,10 @@ class CodeCarbonCallback(TrainerCallback):
"""
def __init__(self):
assert (
is_codecarbon_available()
), "CodeCarbonCallback requires `codecarbon` to be installed. Run `pip install codecarbon`."
if not is_codecarbon_available():
raise RuntimeError(
"CodeCarbonCallback requires `codecarbon` to be installed. Run `pip install codecarbon`."
)
import codecarbon
self._codecarbon = codecarbon
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment