Unverified Commit b9d66f4c authored by Amog Kamsetty, committed by GitHub

Ray Tune Integration Updates (#12134)

* fix

* fixes

* add back to scheduled tests

* formatting

* Update integrations.py
parent a79585bb
@@ -33,7 +33,7 @@ jobs:
       run: |
         apt -y update && apt install -y libsndfile1-dev
         pip install --upgrade pip
-        pip install .[sklearn,testing,onnxruntime,sentencepiece,speech,vision,timm]
+        pip install .[integrations, sklearn,testing,onnxruntime,sentencepiece,speech,vision,timm]
     - name: Are GPUs recognized by our DL frameworks
       run: |
@@ -155,7 +155,7 @@ jobs:
       run: |
         apt -y update && apt install -y libsndfile1-dev
         pip install --upgrade pip
-        pip install .[sklearn,testing,onnxruntime,sentencepiece,speech,vision,timm]
+        pip install .[integrations, sklearn,testing,onnxruntime,sentencepiece,speech,vision,timm]
     - name: Are GPUs recognized by our DL frameworks
       run: |
...
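The two workflow hunks above add the integrations extra to the scheduled GPU test installs so the hyperparameter-search backends are present in that environment. As a hedged sanity check (the module names "ray" and "ray.tune" are assumptions about what the extra pulls in, not taken from this diff), the test environment can confirm the dependency is importable before the search runs:

# Hedged sketch: availability check for the Ray Tune dependency the [integrations]
# extra is expected to provide; module names are illustrative assumptions.
import importlib.util

def has_module(name: str) -> bool:
    # find_spec raises ModuleNotFoundError if a parent package is missing,
    # so guard the submodule lookup as well.
    try:
        return importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        return False

for module in ("ray", "ray.tune"):
    print(f"{module}: {'available' if has_module(module) else 'missing'}")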
@@ -163,11 +163,21 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
             local_trainer._tune_save_checkpoint()
             ray.tune.report(objective=local_trainer.objective, **metrics, done=True)

+    if not trainer._memory_tracker.skip_memory_metrics:
+        from .trainer_utils import TrainerMemoryTracker
+
+        logger.warning(
+            "Memory tracking for your Trainer is currently "
+            "enabled. Automatically disabling the memory tracker "
+            "since the memory tracker is not serializable."
+        )
+        trainer._memory_tracker = TrainerMemoryTracker(skip_memory_metrics=True)
+
     # The model and TensorBoard writer do not pickle so we have to remove them (if they exists)
     # while doing the ray hp search.
     _tb_writer = trainer.pop_callback(TensorBoardCallback)
     trainer.model = None

     # Setup default `resources_per_trial`.
     if "resources_per_trial" not in kwargs:
         # Default to 1 CPU and 1 GPU (if applicable) per trial.
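This hunk disables the Trainer's memory tracker before the search because the tracker, like the model and the TensorBoard writer mentioned in the existing comment, cannot be pickled when Ray serializes the trainer for each trial. A minimal, self-contained illustration of that failure mode, using made-up stand-ins (DummyWriter, DummyTrainer) rather than transformers objects:

# Hedged sketch: only the pickling behaviour is the point, not the transformers API.
import pickle

class DummyWriter:
    def __reduce__(self):
        # Simulate an object (e.g. a TensorBoard writer) that refuses to pickle.
        raise TypeError("cannot pickle DummyWriter")

class DummyTrainer:
    def __init__(self):
        self.writer = DummyWriter()
        self.model = "weights"

trainer = DummyTrainer()
try:
    pickle.dumps(trainer)
except TypeError as err:
    print(f"pickling fails while the writer is attached: {err}")

# Stripping the offending attributes (the patch does this via pop_callback and
# trainer.model = None) makes the object serializable again.
trainer.writer = None
trainer.model = None
pickle.dumps(trainer)
print("picklable after removing non-serializable members")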
@@ -194,7 +204,7 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
         trainer.use_tune_checkpoints = True
         if kwargs["keep_checkpoints_num"] > 1:
             logger.warning(
-                f"Currently keeping {kwargs['keep_checkpoint_num']} checkpoints for each trial. "
+                f"Currently keeping {kwargs['keep_checkpoints_num']} checkpoints for each trial. "
                 "Checkpoints are usually huge, "
                 "consider setting `keep_checkpoints_num=1`."
             )
...
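This one-character fix matters because the old f-string read kwargs['keep_checkpoint_num'] (missing an "s"), so the warning path raised a KeyError exactly when a user had set keep_checkpoints_num > 1 and the warning was supposed to fire. A standalone reproduction, where kwargs stands in for the keyword arguments forwarded to the Ray Tune search:

# Minimal reproduction of the bug fixed above.
kwargs = {"keep_checkpoints_num": 3}

try:
    # Old code: wrong key, so the warning itself crashed.
    print(f"Currently keeping {kwargs['keep_checkpoint_num']} checkpoints for each trial.")
except KeyError as err:
    print(f"old f-string raises KeyError: {err}")

# Fixed code: the key matches what the caller actually passed.
print(f"Currently keeping {kwargs['keep_checkpoints_num']} checkpoints for each trial.")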
@@ -1307,7 +1307,7 @@ class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
         self.n_epochs = args.num_train_epochs
         self.batch_size = args.train_batch_size

-    def test_hyperparameter_search(self):
+    def ray_hyperparameter_search(self):
         class MyTrialShortNamer(TrialShortNamer):
             DEFAULTS = {"a": 0, "b": 0}
@@ -1320,7 +1320,13 @@ class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
             }

         def model_init(config):
-            model_config = RegressionModelConfig(a=config["a"], b=config["b"], double_output=False)
+            if config is None:
+                a = 0
+                b = 0
+            else:
+                a = config["a"]
+                b = config["b"]
+            model_config = RegressionModelConfig(a=a, b=b, double_output=False)
             return RegressionPreTrainedModel(model_config)
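The change above makes the test's model_init tolerate config=None, which is what it receives when the Trainer builds the initial model before any trial has sampled hyperparameters. A hedged, self-contained sketch of the same pattern, where the returned dict is only a stand-in for the RegressionPreTrainedModel built in the test:

# Hedged sketch of a None-tolerant model_init.
from typing import Optional

def model_init(config: Optional[dict] = None):
    if config is None:
        # Initial construction: no trial has sampled hyperparameters yet.
        a, b = 0, 0
    else:
        # Per-trial construction: use the sampled values.
        a, b = config["a"], config["b"]
    return {"a": a, "b": b}

print(model_init())                  # {'a': 0, 'b': 0}
print(model_init({"a": 1, "b": 2}))  # {'a': 1, 'b': 2}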
@@ -1343,3 +1349,14 @@ class TrainerHyperParameterRayIntegrationTest(unittest.TestCase):
         trainer.hyperparameter_search(
             direction="minimize", hp_space=hp_space, hp_name=hp_name, backend="ray", n_trials=4
         )
+
+    def test_hyperparameter_search(self):
+        self.ray_hyperparameter_search()
+
+    def test_hyperparameter_search_ray_client(self):
+        import ray
+        from ray.util.client.ray_client_helpers import ray_start_client_server
+
+        with ray_start_client_server():
+            assert ray.util.client.ray.is_connected()
+            self.ray_hyperparameter_search()
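The new test_hyperparameter_search_ray_client case runs the same search through Ray Client, i.e. with the driver talking to Ray over a client connection rather than owning the cluster in-process. A minimal standalone sketch of that setup, using only the helpers already imported in the test above:

# Hedged sketch: start a local client server, confirm the session is routed through
# Ray Client, then tear it down; requires the same ray installation the test uses.
import ray
from ray.util.client.ray_client_helpers import ray_start_client_server

with ray_start_client_server():
    # Inside the context manager, Ray calls go through the client connection.
    assert ray.util.client.ray.is_connected()
    print("connected through Ray Client")

print("client server shut down")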