Commit cab3b868 (unverified), authored by Antoni Baum, committed by GitHub
Browse files

[ray] Fix `datasets_modules` ImportError with Ray Tune (#12749)

* Fix dynamic_modules ImportError with Ray Tune

* Nit
parent 534f6eb9
...@@ -14,12 +14,15 @@ ...@@ -14,12 +14,15 @@
""" """
Integrations with other Python libraries. Integrations with other Python libraries.
""" """
import functools
import importlib.util import importlib.util
import numbers import numbers
import os import os
import sys
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from .file_utils import is_datasets_available
from .utils import logging from .utils import logging
...@@ -246,8 +249,34 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR ...@@ -246,8 +249,34 @@ def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestR
"Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__) "Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__)
) )
trainable = ray.tune.with_parameters(_objective, local_trainer=trainer)
@functools.wraps(trainable)
def dynamic_modules_import_trainable(*args, **kwargs):
    """
    Call the wrapped ``trainable`` after making ``datasets_modules`` importable in the current process.

    When the ``datasets`` library is available, the dynamic modules package it initialises is loaded from its
    ``__init__.py`` and registered in ``sys.modules`` under the name ``datasets_modules`` before the wrapped
    trainable is invoked. This guards against the ImportError reported in
    https://github.com/huggingface/transformers/issues/11565.

    All positional and keyword arguments are forwarded unchanged, and the wrapped trainable's return value is
    returned as-is.
    """
    if is_datasets_available():
        import datasets.load

        modules_root = datasets.load.init_dynamic_modules()
        package_init = os.path.join(modules_root, "__init__.py")

        # Build the package from its file location rather than relying on sys.path.
        module_spec = importlib.util.spec_from_file_location("datasets_modules", package_init)
        loaded_module = importlib.util.module_from_spec(module_spec)
        # Register the module before executing it, as the original code does.
        sys.modules[module_spec.name] = loaded_module
        module_spec.loader.exec_module(loaded_module)

    return trainable(*args, **kwargs)
# special attr set by tune.with_parameters
if hasattr(trainable, "__mixins__"):
dynamic_modules_import_trainable.__mixins__ = trainable.__mixins__
analysis = ray.tune.run( analysis = ray.tune.run(
ray.tune.with_parameters(_objective, local_trainer=trainer), dynamic_modules_import_trainable,
config=trainer.hp_space(None), config=trainer.hp_space(None),
num_samples=n_trials, num_samples=n_trials,
**kwargs, **kwargs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment