"vscode:/vscode.git/clone" did not exist on "83e5a10603ca902c266e40fc98a01dd8a9b04ac4"
Unverified commit b9d66f4c authored by Amog Kamsetty, committed by GitHub

Ray Tune Integration Updates (#12134)

* fix

* fixes

* add back to scheduled tests

* formatting

* Update integrations.py
parent a79585bb
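
For context, the code paths touched below are all driven through `Trainer.hyperparameter_search` with `backend="ray"`. A minimal usage sketch, assuming a standard text-classification setup; the checkpoint name, output directory, and dataset variables are placeholders and not part of this commit:

```python
# Hedged usage sketch of the Ray Tune backend this commit updates.
from ray import tune
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

def model_init():
    # A fresh model must be built for every trial; Ray pickles this function
    # and ships it to the trial workers.
    return AutoModelForSequenceClassification.from_pretrained("bert-base-cased")

def hp_space(trial):
    # For the Ray backend the search space is a dict of ray.tune samplers.
    return {
        "learning_rate": tune.loguniform(1e-5, 5e-4),
        "per_device_train_batch_size": tune.choice([8, 16, 32]),
    }

trainer = Trainer(
    args=TrainingArguments(output_dir="ray_hp_search"),
    model_init=model_init,
    train_dataset=train_dataset,  # placeholder: supplied by the caller
    eval_dataset=eval_dataset,    # placeholder: supplied by the caller
)

best_run = trainer.hyperparameter_search(
    direction="minimize",
    backend="ray",
    hp_space=hp_space,
    n_trials=4,
)
print(best_run.hyperparameters)
```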
@@ -33,7 +33,7 @@ jobs:
       run: |
         apt -y update && apt install -y libsndfile1-dev
         pip install --upgrade pip
-        pip install .[sklearn,testing,onnxruntime,sentencepiece,speech,vision,timm]
+        pip install .[integrations, sklearn,testing,onnxruntime,sentencepiece,speech,vision,timm]
     - name: Are GPUs recognized by our DL frameworks
       run: |
@@ -155,7 +155,7 @@ jobs:
       run: |
         apt -y update && apt install -y libsndfile1-dev
         pip install --upgrade pip
-        pip install .[sklearn,testing,onnxruntime,sentencepiece,speech,vision,timm]
+        pip install .[integrations, sklearn,testing,onnxruntime,sentencepiece,speech,vision,timm]
     - name: Are GPUs recognized by our DL frameworks
       run: |
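
Both workflow hunks add the `integrations` extra so the scheduled GPU jobs pull in Ray Tune. A small, hedged sanity check (not part of this commit) that the optional dependency is actually importable in such an environment:

```python
# Hedged sketch: verify that the optional Ray Tune dependency pulled in by the
# `integrations` extra is importable before the scheduled tests run.
import importlib.util

def ray_tune_available() -> bool:
    # ray[tune] ships the ray.tune subpackage; find_spec returns None when the
    # extra was not installed. Check "ray" first to avoid ModuleNotFoundError.
    return (
        importlib.util.find_spec("ray") is not None
        and importlib.util.find_spec("ray.tune") is not None
    )

if __name__ == "__main__":
    print("ray tune available:", ray_tune_available())
```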
@@ -163,11 +163,21 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
             local_trainer._tune_save_checkpoint()
             ray.tune.report(objective=local_trainer.objective, **metrics, done=True)
+    if not trainer._memory_tracker.skip_memory_metrics:
+        from .trainer_utils import TrainerMemoryTracker
+
+        logger.warning(
+            "Memory tracking for your Trainer is currently "
+            "enabled. Automatically disabling the memory tracker "
+            "since the memory tracker is not serializable."
+        )
+        trainer._memory_tracker = TrainerMemoryTracker(skip_memory_metrics=True)
+
    # The model and TensorBoard writer do not pickle so we have to remove them (if they exists)
    # while doing the ray hp search.
    _tb_writer = trainer.pop_callback(TensorBoardCallback)
    trainer.model = None
    # Setup default `resources_per_trial`.
    if "resources_per_trial" not in kwargs:
        # Default to 1 CPU and 1 GPU (if applicable) per trial.
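
The hunk above exists because Ray Tune pickles the trainer state it hands to trial workers, and the default memory tracker (like the model and TensorBoard writer removed just below it) does not survive pickling. A self-contained sketch of the failure mode and the swap, using stand-in classes rather than the real Trainer:

```python
# Hedged sketch with stand-in classes (not the real Trainer): objects holding
# open file handles cannot be pickled, so they are replaced with inert
# equivalents before Ray Tune serializes the trainer.
import pickle

class MemoryTracker:
    def __init__(self, skip_memory_metrics=False):
        self.skip_memory_metrics = skip_memory_metrics
        if not skip_memory_metrics:
            # An open file handle makes this object unpicklable.
            self._log = open("memory.log", "w")

class ToyTrainer:
    def __init__(self):
        self._memory_tracker = MemoryTracker()

trainer = ToyTrainer()
try:
    pickle.dumps(trainer)
except TypeError:
    # Mirror of the change above: swap in a tracker that skips metrics and
    # therefore holds no unpicklable state.
    trainer._memory_tracker = MemoryTracker(skip_memory_metrics=True)

pickle.dumps(trainer)  # now serializes cleanly
```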
@@ -194,7 +204,7 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
        trainer.use_tune_checkpoints = True
        if kwargs["keep_checkpoints_num"] > 1:
            logger.warning(
-                f"Currently keeping {kwargs['keep_checkpoint_num']} checkpoints for each trial. "
+                f"Currently keeping {kwargs['keep_checkpoints_num']} checkpoints for each trial. "
                "Checkpoints are usually huge, "
                "consider setting `keep_checkpoints_num=1`."
            )
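
The corrected f-string key refers to `keep_checkpoints_num`, one of the keyword arguments `run_hp_search_ray` forwards on to `ray.tune.run`. A hedged sketch of passing it through `hyperparameter_search`, assuming the `trainer` configured in the earlier sketch:

```python
# Hedged sketch: extra keyword arguments are forwarded to ray.tune.run, so
# checkpoint retention can be capped per trial as the warning suggests.
best_run = trainer.hyperparameter_search(
    direction="minimize",
    backend="ray",
    n_trials=4,
    keep_checkpoints_num=1,  # keep only the latest checkpoint per trial
)
```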
@@ -1307,7 +1307,7 @@ class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
        self.n_epochs = args.num_train_epochs
        self.batch_size = args.train_batch_size

-    def test_hyperparameter_search(self):
+    def ray_hyperparameter_search(self):
        class MyTrialShortNamer(TrialShortNamer):
            DEFAULTS = {"a": 0, "b": 0}
@@ -1320,7 +1320,13 @@ class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
            }

        def model_init(config):
-            model_config = RegressionModelConfig(a=config["a"], b=config["b"], double_output=False)
+            if config is None:
+                a = 0
+                b = 0
+            else:
+                a = config["a"]
+                b = config["b"]
+            model_config = RegressionModelConfig(a=a, b=b, double_output=False)
            return RegressionPreTrainedModel(model_config)
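
The guard added above handles the case where the Trainer calls `model_init` with no trial (`config is None`) to build its initial default model before the Ray search supplies sampled hyperparameters. A self-contained sketch of the same pattern with hypothetical toy classes (not the test's `RegressionPreTrainedModel`):

```python
# Hedged sketch: model_init must tolerate a missing trial/config because the
# default model is built before the search starts. ToyConfig/ToyModel are
# illustrative stand-ins.
from dataclasses import dataclass

@dataclass
class ToyConfig:
    a: float = 0.0
    b: float = 0.0

class ToyModel:
    def __init__(self, config: ToyConfig):
        self.config = config

def model_init(trial=None):
    # trial is None when the framework builds its initial/default model;
    # during the search it is the sampled hyperparameter dict.
    if trial is None:
        return ToyModel(ToyConfig())
    return ToyModel(ToyConfig(a=trial["a"], b=trial["b"]))

default_model = model_init()                    # default instantiation
trial_model = model_init({"a": 1.5, "b": -2.0}) # instantiation for one trial
```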
@@ -1343,3 +1349,14 @@ class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
        trainer.hyperparameter_search(
            direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="ray", n_trials=4
        )
+
+    def test_hyperparameter_search(self):
+        self.ray_hyperparameter_search()
+
+    def test_hyperparameter_search_ray_client(self):
+        import ray
+        from ray.util.client.ray_client_helpers import ray_start_client_server
+
+        with ray_start_client_server():
+            assert ray.util.client.ray.is_connected()
+            self.ray_hyperparameter_search()
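
The new `test_hyperparameter_search_ray_client` case runs the same search through a Ray Client connection, i.e. a driver talking to a remote cluster instead of starting Ray locally. A hedged sketch of that usage from an application's point of view; the address is a placeholder and `trainer` is assumed to be configured as in the earlier sketch:

```python
# Hedged sketch: launching the search over a Ray Client connection.
import ray

ray.init("ray://head-node:10001")  # placeholder Ray Client address
assert ray.util.client.ray.is_connected()

best_run = trainer.hyperparameter_search(
    direction="minimize",
    backend="ray",
    n_trials=4,
)
ray.shutdown()
```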