Unverified Commit 457d4a32 authored by Bram Vanroy, committed by GitHub

Add Ray's scope to training arguments (#17629)



* allow scope from trainer arg

* add ray_scope to training args

* escape double quotes

* make style && quality

* attempt to solve doc style issues

* splitting up URLs for style

* make fixup

* Update src/transformers/training_args.py
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>

* make style
Co-authored-by: Antoni Baum <antoni.baum@protonmail.com>
parent 54833886
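
What this commit enables, in practice: `ray_scope` becomes a regular `TrainingArguments` field that the Ray hyperparameter-search backend reads when selecting the best trial. A minimal sketch of the intended usage; `model_init` and the dataset objects are hypothetical placeholders, not part of this commit:

```python
from transformers import Trainer, TrainingArguments

# "last" (the default) makes Ray compare the last reported result of each
# trial; other accepted scopes are listed in the Ray documentation.
training_args = TrainingArguments(
    output_dir="hp_search_output",
    ray_scope="last",
)

trainer = Trainer(
    model_init=model_init,        # hypothetical: returns a fresh model per trial
    args=training_args,
    train_dataset=train_dataset,  # hypothetical dataset
    eval_dataset=eval_dataset,    # hypothetical dataset
)

# The Ray backend forwards trainer.args.ray_scope to
# ExperimentAnalysis.get_best_trial(), as shown in the diff below.
best_run = trainer.hyperparameter_search(backend="ray", n_trials=10, direction="minimize")
```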
@@ -297,7 +297,7 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
         num_samples=n_trials,
         **kwargs,
     )
-    best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
+    best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3], scope=trainer.args.ray_scope)
     best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
     if _tb_writer is not None:
         trainer.add_callback(_tb_writer)
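For context, `scope` here is Ray Tune's own argument on `ExperimentAnalysis.get_best_trial`. A short sketch of the underlying Ray call, with the accepted values summarized from the Ray documentation linked in the docstring below (`trainable` is a hypothetical placeholder):

```python
from ray import tune

# `scope` controls which reported results are compared across trials:
#   "last"        -> the last reported result of each trial (the default)
#   "all"         -> each trial's best result over all reported steps
#   "avg"         -> the average of all reported results
#   "last-5-avg"  -> the average of the last 5 reported results
#   "last-10-avg" -> the average of the last 10 reported results
analysis = tune.run(trainable, num_samples=4)  # `trainable` is hypothetical
best_trial = analysis.get_best_trial(metric="objective", mode="min", scope="avg")
```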
@@ -456,6 +456,12 @@ class TrainingArguments:
         torchdynamo (`str`, *optional*):
             The token that is used to set the backend compiler for TorchDynamo. Possible choices are ["eager",
             "nvfuser"]. This is an experimental API and subject to change.
+        ray_scope (`str`, *optional*, defaults to `"last"`):
+            The scope to use when doing hyperparameter search with Ray. By default, `"last"` will be used. Ray will
+            then use the last checkpoint of all trials, compare those, and select the best one. However, other options
+            are also available. See the [Ray documentation](
+            https://docs.ray.io/en/latest/tune/api_docs/analysis.html#ray.tune.ExperimentAnalysis.get_best_trial) for
+            more options.
     """
     output_dir: str = field(
@@ -916,6 +922,19 @@ class TrainingArguments:
             "choices": ["eager", "nvfuser"],
         },
     )
+    ray_scope: Optional[str] = field(
+        default="last",
+        metadata={
+            "help": (
+                'The scope to use when doing hyperparameter search with Ray. By default, `"last"` will be used. Ray'
+                " will then use the last checkpoint of all trials, compare those, and select the best one. However,"
+                " other options are also available. See the Ray documentation"
+                " (https://docs.ray.io/en/latest/tune/api_docs/analysis.html"
+                "#ray.tune.ExperimentAnalysis.get_best_trial)"
+                " for more options."
+            )
+        },
+    )

     def __post_init__(self):
         # Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
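Since `ray_scope` is a plain dataclass field with a `help` string, it is also picked up automatically by `HfArgumentParser`, so example scripts can set it from the command line. A quick sketch of that behavior (the argument values are illustrative):

```python
from transformers import HfArgumentParser, TrainingArguments

# Simulate a script invoked with `--ray_scope all` on the command line.
parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "out", "--ray_scope", "all"]
)
assert training_args.ray_scope == "all"
```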