# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Integrations with other Python libraries.
"""
import functools
import importlib.util
import numbers
import os
import sys
import tempfile
from pathlib import Path

from .utils import is_datasets_available, logging


logger = logging.get_logger(__name__)


# comet_ml must be imported before any ML frameworks
_has_comet = importlib.util.find_spec("comet_ml") is not None and os.getenv("COMET_MODE", "").upper() != "DISABLED"
if _has_comet:
    try:
        import comet_ml  # noqa: F401

        if hasattr(comet_ml, "config") and comet_ml.config.get_config("comet.api_key"):
            _has_comet = True
        else:
            if os.getenv("COMET_MODE", "").upper() != "DISABLED":
                logger.warning("comet_ml is installed but `COMET_API_KEY` is not set.")
            _has_comet = False
    except (ImportError, ValueError):
        _has_comet = False

from .trainer_callback import ProgressCallback, TrainerCallback  # noqa: E402
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy  # noqa: E402
from .utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available  # noqa: E402


# Integration functions:
def is_wandb_available():
    # any value of WANDB_DISABLED disables wandb
    if os.getenv("WANDB_DISABLED", "").upper() in ENV_VARS_TRUE_VALUES:
        logger.warning(
            "Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the "
            "--report_to flag to control the integrations used for logging results (for instance --report_to none)."
        )
        return False
    return importlib.util.find_spec("wandb") is not None


def is_comet_available():
    return _has_comet


def is_tensorboard_available():
    return importlib.util.find_spec("tensorboard") is not None or importlib.util.find_spec("tensorboardX") is not None


def is_optuna_available():
    return importlib.util.find_spec("optuna") is not None


def is_ray_available():
    return importlib.util.find_spec("ray") is not None


def is_ray_tune_available():
    if not is_ray_available():
        return False
    return importlib.util.find_spec("ray.tune") is not None


def is_sigopt_available():
    return importlib.util.find_spec("sigopt") is not None


def is_azureml_available():
    if importlib.util.find_spec("azureml") is None:
        return False
    if importlib.util.find_spec("azureml.core") is None:
        return False
    return importlib.util.find_spec("azureml.core.run") is not None


def is_mlflow_available():
    if os.getenv("DISABLE_MLFLOW_INTEGRATION", "FALSE").upper() == "TRUE":
        return False
    return importlib.util.find_spec("mlflow") is not None


def is_fairscale_available():
    return importlib.util.find_spec("fairscale") is not None


def is_neptune_available():
    return importlib.util.find_spec("neptune") is not None


def is_codecarbon_available():
    return importlib.util.find_spec("codecarbon") is not None


def hp_params(trial):
    if is_optuna_available():
        import optuna

        if isinstance(trial, optuna.Trial):
            return trial.params

    if is_ray_tune_available():
        if isinstance(trial, dict):
            return trial

    if is_sigopt_available():
        if isinstance(trial, dict):
            return trial

    if is_wandb_available():
        if isinstance(trial, dict):
            return trial

    raise RuntimeError(f"Unknown type for trial {trial.__class__}")


def default_hp_search_backend():
    if is_optuna_available():
        return "optuna"
    elif is_ray_tune_available():
        return "ray"
    elif is_sigopt_available():
        return "sigopt"


def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    import optuna

    def _objective(trial, checkpoint_dir=None):
        checkpoint = None
        if checkpoint_dir:
            for subdir in os.listdir(checkpoint_dir):
                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
                    checkpoint = os.path.join(checkpoint_dir, subdir)
        trainer.objective = None
        trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
        # If there hasn't been any evaluation during the training loop.
        if getattr(trainer, "objective", None) is None:
            metrics = trainer.evaluate()
            trainer.objective = trainer.compute_objective(metrics)
        return trainer.objective

    timeout = kwargs.pop("timeout", None)
    n_jobs = kwargs.pop("n_jobs", 1)
    study = optuna.create_study(direction=direction, **kwargs)
    study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs)
    best_trial = study.best_trial
    return BestRun(str(best_trial.number), best_trial.value, best_trial.params)
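
# Illustrative usage sketch (assumes a `Trainer` named `trainer` constructed with `model_init`);
# this function is normally reached through `Trainer.hyperparameter_search(backend="optuna")`:
#
#     def optuna_hp_space(trial):
#         return {"learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)}
#
#     best_run = trainer.hyperparameter_search(
#         hp_space=optuna_hp_space, backend="optuna", n_trials=20, direction="minimize"
#     )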


def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    import ray

    def _objective(trial, local_trainer, checkpoint_dir=None):
        try:
            from transformers.utils.notebook import NotebookProgressCallback

            if local_trainer.pop_callback(NotebookProgressCallback):
                local_trainer.add_callback(ProgressCallback)
        except ModuleNotFoundError:
            pass

        checkpoint = None
        if checkpoint_dir:
            for subdir in os.listdir(checkpoint_dir):
                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
                    checkpoint = os.path.join(checkpoint_dir, subdir)
        local_trainer.objective = None
        local_trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
        # If there hasn't been any evaluation during the training loop.
        if getattr(local_trainer, "objective", None) is None:
            metrics = local_trainer.evaluate()
            local_trainer.objective = local_trainer.compute_objective(metrics)
            local_trainer._tune_save_checkpoint()
            ray.tune.report(objective=local_trainer.objective, **metrics, done=True)

    if not trainer._memory_tracker.skip_memory_metrics:
        from .trainer_utils import TrainerMemoryTracker

        logger.warning(
            "Memory tracking for your Trainer is currently "
            "enabled. Automatically disabling the memory tracker "
            "since the memory tracker is not serializable."
        )
        trainer._memory_tracker = TrainerMemoryTracker(skip_memory_metrics=True)

    # The model and TensorBoard writer do not pickle so we have to remove them (if they exist)
    # while doing the ray hp search.
    _tb_writer = trainer.pop_callback(TensorBoardCallback)
    trainer.model = None

    # Setup default `resources_per_trial`.
    if "resources_per_trial" not in kwargs:
        # Default to 1 CPU and 1 GPU (if applicable) per trial.
        kwargs["resources_per_trial"] = {"cpu": 1}
        if trainer.args.n_gpu > 0:
            kwargs["resources_per_trial"]["gpu"] = 1
        resource_msg = "1 CPU" + (" and 1 GPU" if trainer.args.n_gpu > 0 else "")
        logger.info(
            "No `resources_per_trial` arg was passed into "
            "`hyperparameter_search`. Setting it to a default value "
            f"of {resource_msg} for each trial."
        )
    # Make sure each trainer only uses GPUs that were allocated per trial.
    gpus_per_trial = kwargs["resources_per_trial"].get("gpu", 0)
    trainer.args._n_gpu = gpus_per_trial

    # Setup default `progress_reporter`.
    if "progress_reporter" not in kwargs:
        from ray.tune import CLIReporter

        kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
    if "keep_checkpoints_num" in kwargs and kwargs["keep_checkpoints_num"] > 0:
        # `keep_checkpoints_num=0` would disable checkpointing
        trainer.use_tune_checkpoints = True
        if kwargs["keep_checkpoints_num"] > 1:
            logger.warning(
                f"Currently keeping {kwargs['keep_checkpoints_num']} checkpoints for each trial. "
                "Checkpoints are usually huge, "
                "consider setting `keep_checkpoints_num=1`."
            )
    if "scheduler" in kwargs:
        from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB, MedianStoppingRule, PopulationBasedTraining

        # Check if checkpointing is enabled for PopulationBasedTraining
        if isinstance(kwargs["scheduler"], PopulationBasedTraining):
            if not trainer.use_tune_checkpoints:
                logger.warning(
                    "You are using PopulationBasedTraining but you haven't enabled checkpointing. "
                    "This means your trials will train from scratch every time they are exploiting "
                    "new configurations. Consider enabling checkpointing by passing "
                    "`keep_checkpoints_num=1` as an additional argument to `Trainer.hyperparameter_search`."
                )

        # Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
        if isinstance(
            kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
        ) and (not trainer.args.do_eval or trainer.args.evaluation_strategy == IntervalStrategy.NO):
            raise RuntimeError(
                "You are using {cls} as a scheduler but you haven't enabled evaluation during training. "
                "This means your trials will not report intermediate results to Ray Tune, and "
                "can thus not be stopped early or used to exploit other trials parameters. "
                "If this is what you want, do not use {cls}. If you would like to use {cls}, "
                "make sure you pass `do_eval=True` and `evaluation_strategy='steps'` in the "
                "Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__)
            )

    trainable = ray.tune.with_parameters(_objective, local_trainer=trainer)

    @functools.wraps(trainable)
    def dynamic_modules_import_trainable(*args, **kwargs):
        """
        Wrapper around `tune.with_parameters` to ensure datasets_modules are loaded on each Actor.

        Without this, an ImportError will be thrown. See https://github.com/huggingface/transformers/issues/11565.

        Assumes that `_objective`, defined above, is a function.
        """
        if is_datasets_available():
            import datasets.load

            dynamic_modules_path = os.path.join(datasets.load.init_dynamic_modules(), "__init__.py")
            # load dynamic_modules from path
            spec = importlib.util.spec_from_file_location("datasets_modules", dynamic_modules_path)
            datasets_modules = importlib.util.module_from_spec(spec)
            sys.modules[spec.name] = datasets_modules
            spec.loader.exec_module(datasets_modules)
        return trainable(*args, **kwargs)

    # special attr set by tune.with_parameters
    if hasattr(trainable, "__mixins__"):
        dynamic_modules_import_trainable.__mixins__ = trainable.__mixins__

    analysis = ray.tune.run(
        dynamic_modules_import_trainable,
        config=trainer.hp_space(None),
        num_samples=n_trials,
        **kwargs,
    )
    best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
    best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
    if _tb_writer is not None:
        trainer.add_callback(_tb_writer)
    return best_run
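
# Illustrative usage sketch (assumes a `Trainer` named `trainer` constructed with `model_init`);
# extra keyword arguments are forwarded to `ray.tune.run`:
#
#     best_run = trainer.hyperparameter_search(
#         backend="ray",
#         n_trials=10,
#         direction="maximize",
#         resources_per_trial={"cpu": 2, "gpu": 1},
#         keep_checkpoints_num=1,
#     )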


def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    from sigopt import Connection

    conn = Connection()
    proxies = kwargs.pop("proxies", None)
    if proxies is not None:
        conn.set_proxies(proxies)

    experiment = conn.experiments().create(
        name="huggingface-tune",
        parameters=trainer.hp_space(None),
        metrics=[dict(name="objective", objective=direction, strategy="optimize")],
        parallel_bandwidth=1,
        observation_budget=n_trials,
        project="huggingface",
    )
    logger.info(f"created experiment: https://app.sigopt.com/experiment/{experiment.id}")

    while experiment.progress.observation_count < experiment.observation_budget:
        suggestion = conn.experiments(experiment.id).suggestions().create()
        trainer.objective = None
        trainer.train(resume_from_checkpoint=None, trial=suggestion)
        # If there hasn't been any evaluation during the training loop.
        if getattr(trainer, "objective", None) is None:
            metrics = trainer.evaluate()
            trainer.objective = trainer.compute_objective(metrics)

        values = [dict(name="objective", value=trainer.objective)]
        obs = conn.experiments(experiment.id).observations().create(suggestion=suggestion.id, values=values)
        logger.info(f"[suggestion_id, observation_id]: [{suggestion.id}, {obs.id}]")
        experiment = conn.experiments(experiment.id).fetch()

    best = list(conn.experiments(experiment.id).best_assignments().fetch().iterate_pages())[0]
    best_run = BestRun(best.id, best.value, best.assignments)

    return best_run
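
# Illustrative usage sketch (assumes SigOpt credentials are already configured, e.g. via the
# `SIGOPT_API_TOKEN` environment variable, and a `Trainer` named `trainer` built with `model_init`):
#
#     best_run = trainer.hyperparameter_search(backend="sigopt", n_trials=10, direction="maximize")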


def run_hp_search_wandb(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    from .integrations import is_wandb_available

    if not is_wandb_available():
        raise ImportError("This function needs wandb installed: `pip install wandb`")
    import wandb

    # add WandbCallback if not already added in trainer callbacks
    reporting_to_wandb = False
    for callback in trainer.callback_handler.callbacks:
        if isinstance(callback, WandbCallback):
            reporting_to_wandb = True
            break
    if not reporting_to_wandb:
        trainer.add_callback(WandbCallback())
    trainer.args.report_to = "wandb"
    best_trial = {"run_id": None, "objective": None, "hyperparameters": None}
    sweep_id = kwargs.pop("sweep_id", None)
    project = kwargs.pop("project", None)
    name = kwargs.pop("name", None)
    entity = kwargs.pop("entity", None)
    metric = kwargs.pop("metric", "eval/loss")

    sweep_config = trainer.hp_space(None)
    sweep_config["metric"]["goal"] = direction
    sweep_config["metric"]["name"] = metric
    if name:
        sweep_config["name"] = name

    def _objective():
        run = wandb.run if wandb.run else wandb.init()
        trainer.state.trial_name = run.name
        run.config.update({"assignments": {}, "metric": metric})
        config = wandb.config

        trainer.objective = None

        trainer.train(resume_from_checkpoint=None, trial=vars(config)["_items"])
        # If there hasn't been any evaluation during the training loop.
        if getattr(trainer, "objective", None) is None:
            metrics = trainer.evaluate()
            trainer.objective = trainer.compute_objective(metrics)
            format_metrics = rewrite_logs(metrics)
            if metric not in format_metrics:
                logger.warning(
                    f"Provided metric {metric} not found. This might result in unexpected sweeps charts. The available metrics are {format_metrics.keys()}"
                )
        best_score = False
        if best_trial["run_id"] is not None:
            if direction == "minimize":
                best_score = trainer.objective < best_trial["objective"]
            elif direction == "maximize":
                best_score = trainer.objective > best_trial["objective"]

        if best_score or best_trial["run_id"] is None:
            best_trial["run_id"] = run.id
            best_trial["objective"] = trainer.objective
            best_trial["hyperparameters"] = dict(config)

        return trainer.objective

    sweep_id = wandb.sweep(sweep_config, project=project, entity=entity) if not sweep_id else sweep_id
    logger.info(f"wandb sweep id - {sweep_id}")
    wandb.agent(sweep_id, function=_objective, count=n_trials)

    return BestRun(best_trial["run_id"], best_trial["objective"], best_trial["hyperparameters"])
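
# Illustrative usage sketch (assumes a `Trainer` named `trainer` built with `model_init`); the wandb
# backend expects `hp_space` to return a wandb sweep configuration dict:
#
#     def wandb_hp_space(trial):
#         return {
#             "method": "random",
#             "metric": {"name": "eval/loss", "goal": "minimize"},
#             "parameters": {"learning_rate": {"min": 1e-6, "max": 1e-4}},
#         }
#
#     best_run = trainer.hyperparameter_search(
#         hp_space=wandb_hp_space, backend="wandb", n_trials=5, direction="minimize", metric="eval/loss"
#     )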


def get_available_reporting_integrations():
    integrations = []
    if is_azureml_available():
        integrations.append("azure_ml")
    if is_comet_available():
        integrations.append("comet_ml")
    if is_mlflow_available():
        integrations.append("mlflow")
    if is_tensorboard_available():
        integrations.append("tensorboard")
    if is_wandb_available():
        integrations.append("wandb")
    if is_codecarbon_available():
        integrations.append("codecarbon")
    return integrations


def rewrite_logs(d):
    new_d = {}
    eval_prefix = "eval_"
    eval_prefix_len = len(eval_prefix)
    test_prefix = "test_"
    test_prefix_len = len(test_prefix)
    for k, v in d.items():
        if k.startswith(eval_prefix):
            new_d["eval/" + k[eval_prefix_len:]] = v
        elif k.startswith(test_prefix):
            new_d["test/" + k[test_prefix_len:]] = v
        else:
            new_d["train/" + k] = v
    return new_d
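
# Example: rewrite_logs({"eval_loss": 0.25, "test_accuracy": 0.9, "loss": 1.3})
# returns {"eval/loss": 0.25, "test/accuracy": 0.9, "train/loss": 1.3}.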


class TensorBoardCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [TensorBoard](https://www.tensorflow.org/tensorboard).

    Args:
        tb_writer (`SummaryWriter`, *optional*):
            The writer to use. Will instantiate one if not set.
    """

    def __init__(self, tb_writer=None):
        has_tensorboard = is_tensorboard_available()
        if not has_tensorboard:
            raise RuntimeError(
                "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
            )
        if has_tensorboard:
            try:
                from torch.utils.tensorboard import SummaryWriter  # noqa: F401

                self._SummaryWriter = SummaryWriter
            except ImportError:
                try:
                    from tensorboardX import SummaryWriter

                    self._SummaryWriter = SummaryWriter
                except ImportError:
                    self._SummaryWriter = None
        else:
            self._SummaryWriter = None
        self.tb_writer = tb_writer

    def _init_summary_writer(self, args, log_dir=None):
        log_dir = log_dir or args.logging_dir
        if self._SummaryWriter is not None:
            self.tb_writer = self._SummaryWriter(log_dir=log_dir)

    def on_train_begin(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return

        log_dir = None

        if state.is_hyper_param_search:
            trial_name = state.trial_name
            if trial_name is not None:
                log_dir = os.path.join(args.logging_dir, trial_name)

        if self.tb_writer is None:
            self._init_summary_writer(args, log_dir)

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", args.to_json_string())
            if "model" in kwargs:
                model = kwargs["model"]
                if hasattr(model, "config") and model.config is not None:
                    model_config_json = model.config.to_json_string()
                    self.tb_writer.add_text("model_config", model_config_json)
            # Version of TensorBoard coming from tensorboardX does not have this method.
            if hasattr(self.tb_writer, "add_hparams"):
                self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not state.is_world_process_zero:
            return

        if self.tb_writer is None:
            self._init_summary_writer(args)

        if self.tb_writer is not None:
            logs = rewrite_logs(logs)
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, state.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute."
                    )
            self.tb_writer.flush()

    def on_train_end(self, args, state, control, **kwargs):
        if self.tb_writer:
            self.tb_writer.close()
            self.tb_writer = None
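
# Illustrative usage sketch: TensorBoard logging is usually enabled through `TrainingArguments`
# rather than by instantiating the callback directly, e.g.
#
#     args = TrainingArguments(output_dir="out", report_to=["tensorboard"], logging_dir="out/logs")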


class WandbCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [Weights & Biases](https://www.wandb.com/).
    """

    def __init__(self):
        has_wandb = is_wandb_available()
        if not has_wandb:
            raise RuntimeError("WandbCallback requires wandb to be installed. Run `pip install wandb`.")
        if has_wandb:
            import wandb

            self._wandb = wandb
        self._initialized = False
        # log outputs
        self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})

    def setup(self, args, state, model, **kwargs):
        """
        Setup the optional Weights & Biases (*wandb*) integration.

        One can subclass and override this method to customize the setup if needed. Find more information
        [here](https://docs.wandb.ai/integrations/huggingface). You can also override the following environment
        variables:

        Environment:
            WANDB_LOG_MODEL (`bool`, *optional*, defaults to `False`):
                Whether or not to log model as artifact at the end of training. Use along with
                *TrainingArguments.load_best_model_at_end* to upload best model.
            WANDB_WATCH (`str`, *optional*, defaults to `"gradients"`):
                Can be `"gradients"`, `"all"` or `"false"`. Set to `"false"` to disable gradient logging or `"all"` to
                log gradients and parameters.
            WANDB_PROJECT (`str`, *optional*, defaults to `"huggingface"`):
                Set this to a custom string to store results in a different project.
            WANDB_DISABLED (`bool`, *optional*, defaults to `False`):
                Whether or not to disable wandb entirely. Set *WANDB_DISABLED=true* to disable.
        """
        if self._wandb is None:
            return
        self._initialized = True
        if state.is_world_process_zero:
            logger.info(
                'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
            )
            combined_dict = {**args.to_sanitized_dict()}

            if hasattr(model, "config") and model.config is not None:
                model_config = model.config.to_dict()
                combined_dict = {**model_config, **combined_dict}
            trial_name = state.trial_name
            init_args = {}
            if trial_name is not None:
                run_name = trial_name
                init_args["group"] = args.run_name
            else:
                run_name = args.run_name

            if self._wandb.run is None:
                self._wandb.init(
                    project=os.getenv("WANDB_PROJECT", "huggingface"),
                    name=run_name,
                    **init_args,
                )
            # add config parameters (run may have been created manually)
            self._wandb.config.update(combined_dict, allow_val_change=True)

            # define default x-axis (for latest wandb versions)
            if getattr(self._wandb, "define_metric", None):
                self._wandb.define_metric("train/global_step")
                self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True)

            # keep track of model topology and gradients, unsupported on TPU
            if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
                self._wandb.watch(
                    model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps)
                )

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if self._wandb is None:
            return
        hp_search = state.is_hyper_param_search
        if hp_search:
            self._wandb.finish()
            self._initialized = False
            args.run_name = None
        if not self._initialized:
            self.setup(args, state, model, **kwargs)

    def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
        if self._wandb is None:
            return
        if self._log_model and self._initialized and state.is_world_process_zero:
            from .trainer import Trainer

            fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
            with tempfile.TemporaryDirectory() as temp_dir:
                fake_trainer.save_model(temp_dir)
                metadata = (
                    {
                        k: v
                        for k, v in dict(self._wandb.summary).items()
                        if isinstance(v, numbers.Number) and not k.startswith("_")
                    }
                    if not args.load_best_model_at_end
                    else {
                        f"eval/{args.metric_for_best_model}": state.best_metric,
                        "train/total_floss": state.total_flos,
                    }
                )
                artifact = self._wandb.Artifact(name=f"model-{self._wandb.run.id}", type="model", metadata=metadata)
                for f in Path(temp_dir).glob("*"):
                    if f.is_file():
                        with artifact.new_file(f.name, mode="wb") as fa:
                            fa.write(f.read_bytes())
                self._wandb.run.log_artifact(artifact)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        if self._wandb is None:
            return
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            logs = rewrite_logs(logs)
            self._wandb.log({**logs, "train/global_step": state.global_step})
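
# Illustrative usage sketch: the callback is driven by the environment variables documented in
# `WandbCallback.setup`, e.g.
#
#     os.environ["WANDB_PROJECT"] = "my-project"   # illustrative project name
#     os.environ["WANDB_LOG_MODEL"] = "true"
#     os.environ["WANDB_WATCH"] = "all"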


class CometCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [Comet ML](https://www.comet.ml/site/).
    """

    def __init__(self):
        if not _has_comet:
            raise RuntimeError("CometCallback requires comet-ml to be installed. Run `pip install comet-ml`.")
        self._initialized = False
        self._log_assets = False

    def setup(self, args, state, model):
        """
        Setup the optional Comet.ml integration.

        Environment:
            COMET_MODE (`str`, *optional*):
                Whether to create an online or offline experiment, or disable Comet logging. Can be "OFFLINE",
                "ONLINE", or "DISABLED". Defaults to "ONLINE".
            COMET_PROJECT_NAME (`str`, *optional*):
                Comet project name for experiments.
            COMET_OFFLINE_DIRECTORY (`str`, *optional*):
                Folder to use for saving offline experiments when `COMET_MODE` is "OFFLINE".
            COMET_LOG_ASSETS (`str`, *optional*):
                Whether or not to log training assets (tf event logs, checkpoints, etc.) to Comet. Can be "TRUE" or
                "FALSE". Defaults to "TRUE".

        For a number of configurable items in the environment, see
        [here](https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables).
        """
        self._initialized = True
        log_assets = os.getenv("COMET_LOG_ASSETS", "FALSE").upper()
        if log_assets in {"TRUE", "1"}:
            self._log_assets = True
        if state.is_world_process_zero:
            comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
            experiment = None
            experiment_kwargs = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
            if comet_mode == "ONLINE":
                experiment = comet_ml.Experiment(**experiment_kwargs)
                experiment.log_other("Created from", "transformers")
                logger.info("Automatic Comet.ml online logging enabled")
            elif comet_mode == "OFFLINE":
                experiment_kwargs["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
                experiment = comet_ml.OfflineExperiment(**experiment_kwargs)
                experiment.log_other("Created from", "transformers")
                logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
            if experiment is not None:
                experiment._set_model_graph(model, framework="transformers")
                experiment._log_parameters(args, prefix="args/", framework="transformers")
                if hasattr(model, "config"):
                    experiment._log_parameters(model.config, prefix="config/", framework="transformers")

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            experiment = comet_ml.config.get_global_experiment()
            if experiment is not None:
                experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers")

    def on_train_end(self, args, state, control, **kwargs):
        if self._initialized and state.is_world_process_zero:
            experiment = comet_ml.config.get_global_experiment()
            if (experiment is not None) and (self._log_assets is True):
                logger.info("Logging checkpoints. This may take time.")
                experiment.log_asset_folder(
                    args.output_dir, recursive=True, log_file_name=True, step=state.global_step
                )
            if experiment is not None:
                experiment.end()
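
# Illustrative usage sketch: Comet logging is configured through the environment variables documented
# in `CometCallback.setup`, e.g.
#
#     os.environ["COMET_MODE"] = "OFFLINE"
#     os.environ["COMET_OFFLINE_DIRECTORY"] = "./comet-logs"   # illustrative path
#     os.environ["COMET_PROJECT_NAME"] = "my-project"          # illustrative project name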


class AzureMLCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [AzureML](https://pypi.org/project/azureml-sdk/).
    """

    def __init__(self, azureml_run=None):
        if not is_azureml_available():
            raise RuntimeError("AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`.")
        self.azureml_run = azureml_run

    def on_init_end(self, args, state, control, **kwargs):
        from azureml.core.run import Run

        if self.azureml_run is None and state.is_world_process_zero:
            self.azureml_run = Run.get_context()

    def on_log(self, args, state, control, logs=None, **kwargs):
        if self.azureml_run and state.is_world_process_zero:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.azureml_run.log(k, v, description=k)


class MLflowCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [MLflow](https://www.mlflow.org/). Can be disabled by setting
    environment variable `DISABLE_MLFLOW_INTEGRATION = TRUE`.
    """

    def __init__(self):
        if not is_mlflow_available():
            raise RuntimeError("MLflowCallback requires mlflow to be installed. Run `pip install mlflow`.")
        import mlflow

        self._MAX_PARAM_VAL_LENGTH = mlflow.utils.validation.MAX_PARAM_VAL_LENGTH
        self._MAX_PARAMS_TAGS_PER_BATCH = mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH

        self._initialized = False
        self._log_artifacts = False
        self._ml_flow = mlflow

    def setup(self, args, state, model):
        """
        Setup the optional MLflow integration.

        Environment:
            HF_MLFLOW_LOG_ARTIFACTS (`str`, *optional*):
                Whether to use MLflow's `.log_artifact()` facility to log artifacts.

                This only makes sense if logging to a remote server, e.g. s3 or GCS. If set to `True` or *1*, will copy
                whatever is in [`TrainingArguments`]'s `output_dir` to the local or remote artifact storage. Using it
                without a remote storage will just copy the files to your artifact location.
        """
        log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper()
        if log_artifacts in {"TRUE", "1"}:
            self._log_artifacts = True
        if state.is_world_process_zero:
            if self._ml_flow.active_run() is None:
                self._ml_flow.start_run(run_name=args.run_name)
            combined_dict = args.to_dict()
            if hasattr(model, "config") and model.config is not None:
                model_config = model.config.to_dict()
                combined_dict = {**model_config, **combined_dict}
            # remove params that are too long for MLflow
            for name, value in list(combined_dict.items()):
                # internally, all values are converted to str in MLflow
                if len(str(value)) > self._MAX_PARAM_VAL_LENGTH:
                    logger.warning(
                        f"Trainer is attempting to log a value of "
                        f'"{value}" for key "{name}" as a parameter. '
                        f"MLflow's log_param() only accepts values no longer than "
                        f"250 characters so we dropped this attribute."
                    )
                    del combined_dict[name]
            # MLflow cannot log more than 100 values in one go, so we have to split it
            combined_dict_items = list(combined_dict.items())
            for i in range(0, len(combined_dict_items), self._MAX_PARAMS_TAGS_PER_BATCH):
                self._ml_flow.log_params(dict(combined_dict_items[i : i + self._MAX_PARAMS_TAGS_PER_BATCH]))
        self._initialized = True

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, logs, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            metrics = {}
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    metrics[k] = v
                else:
                    logger.warning(
                        f"Trainer is attempting to log a value of "
                        f'"{v}" of type {type(v)} for key "{k}" as a metric. '
                        f"MLflow's log_metric() only accepts float and "
                        f"int types so we dropped this attribute."
                    )
            self._ml_flow.log_metrics(metrics=metrics, step=state.global_step)

    def on_train_end(self, args, state, control, **kwargs):
        if self._initialized and state.is_world_process_zero:
            if self._log_artifacts:
                logger.info("Logging artifacts. This may take time.")
                self._ml_flow.log_artifacts(args.output_dir)

    def __del__(self):
        # if the previous run is not terminated correctly, the fluent API will
        # not let you start a new run before the previous one is killed
        if self._ml_flow.active_run() is not None:
            self._ml_flow.end_run()
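
# Illustrative usage sketch: artifact logging is opt-in via the environment variable documented in
# `MLflowCallback.setup`, e.g.
#
#     os.environ["HF_MLFLOW_LOG_ARTIFACTS"] = "1"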


class NeptuneCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [Neptune](https://neptune.ai).
    """

    def __init__(self):
        if not is_neptune_available():
            raise ValueError(
                "NeptuneCallback requires neptune-client to be installed. Run `pip install neptune-client`."
            )
        import neptune.new as neptune

        self._neptune = neptune
        self._initialized = False
        self._log_artifacts = False

    def setup(self, args, state, model):
        """
        Setup the Neptune integration.

        Environment:
            NEPTUNE_PROJECT (`str`, *required*):
                The project ID for the neptune.ai account. Should be in the format *workspace_name/project_name*.
            NEPTUNE_API_TOKEN (`str`, *required*):
                API token for the neptune.ai account.
            NEPTUNE_CONNECTION_MODE (`str`, *optional*):
                Neptune connection mode. *async* by default.
            NEPTUNE_RUN_NAME (`str`, *optional*):
                The name of the run process on the Neptune dashboard.
        """
        if state.is_world_process_zero:
            self._neptune_run = self._neptune.init(
                project=os.getenv("NEPTUNE_PROJECT"),
                api_token=os.getenv("NEPTUNE_API_TOKEN"),
                mode=os.getenv("NEPTUNE_CONNECTION_MODE", "async"),
                name=os.getenv("NEPTUNE_RUN_NAME", None),
                run=os.getenv("NEPTUNE_RUN_ID", None),
            )
            combined_dict = args.to_dict()
            if hasattr(model, "config") and model.config is not None:
                model_config = model.config.to_dict()
                combined_dict = {**model_config, **combined_dict}
            self._neptune_run["parameters"] = combined_dict
        self._initialized = True

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, logs, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            for k, v in logs.items():
                self._neptune_run[k].log(v, step=state.global_step)

    def __del__(self):
        """
        Environment:
            NEPTUNE_STOP_TIMEOUT (`int`, *optional*):
                Number of seconds to wait for all Neptune.ai tracking calls to finish, before stopping the tracked
                run. If not set it will wait for all tracking calls to finish.
        """
        try:
            stop_timeout = os.getenv("NEPTUNE_STOP_TIMEOUT")
            stop_timeout = int(stop_timeout) if stop_timeout else None
            self._neptune_run.stop(seconds=stop_timeout)
        except AttributeError:
            pass
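
# Illustrative usage sketch: Neptune logging is configured through the environment variables
# documented in `NeptuneCallback.setup`, e.g.
#
#     os.environ["NEPTUNE_PROJECT"] = "my-workspace/my-project"   # illustrative project ID
#     os.environ["NEPTUNE_API_TOKEN"] = "..."                     # your Neptune API token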


class CodeCarbonCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that tracks the CO2 emission of training.
    """

    def __init__(self):
        if not is_codecarbon_available():
            raise RuntimeError(
                "CodeCarbonCallback requires `codecarbon` to be installed. Run `pip install codecarbon`."
            )
        import codecarbon

        self._codecarbon = codecarbon
        self.tracker = None

    def on_init_end(self, args, state, control, **kwargs):
        if self.tracker is None and state.is_local_process_zero:
            # CodeCarbon will automatically handle environment variables for configuration
            self.tracker = self._codecarbon.EmissionsTracker(output_dir=args.output_dir)

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if self.tracker and state.is_local_process_zero:
            self.tracker.start()

    def on_train_end(self, args, state, control, **kwargs):
        if self.tracker and state.is_local_process_zero:
            self.tracker.stop()


INTEGRATION_TO_CALLBACK = {
    "azure_ml": AzureMLCallback,
    "comet_ml": CometCallback,
    "mlflow": MLflowCallback,
    "neptune": NeptuneCallback,
    "tensorboard": TensorBoardCallback,
    "wandb": WandbCallback,
    "codecarbon": CodeCarbonCallback,
}


def get_reporting_integration_callbacks(report_to):
    for integration in report_to:
        if integration not in INTEGRATION_TO_CALLBACK:
            raise ValueError(
                f"{integration} is not supported, only {', '.join(INTEGRATION_TO_CALLBACK.keys())} are supported."
            )
    return [INTEGRATION_TO_CALLBACK[integration] for integration in report_to]
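
# Example: get_reporting_integration_callbacks(["wandb", "tensorboard"]) returns
# [WandbCallback, TensorBoardCallback]; an unsupported name raises a ValueError listing the
# supported integrations.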