# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Integrations with other Python libraries.
"""
import functools
import importlib.util
import json
import numbers
import os
import sys
import tempfile
from pathlib import Path

from .utils import flatten_dict, is_datasets_available, logging


logger = logging.get_logger(__name__)


# comet_ml must be imported before any ML frameworks
_has_comet = importlib.util.find_spec("comet_ml") is not None and os.getenv("COMET_MODE", "").upper() != "DISABLED"
if _has_comet:
    try:
        import comet_ml  # noqa: F401

        if hasattr(comet_ml, "config") and comet_ml.config.get_config("comet.api_key"):
            _has_comet = True
        else:
            if os.getenv("COMET_MODE", "").upper() != "DISABLED":
                logger.warning("comet_ml is installed but `COMET_API_KEY` is not set.")
            _has_comet = False
    except (ImportError, ValueError):
        _has_comet = False

from .trainer_callback import ProgressCallback, TrainerCallback  # noqa: E402
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy  # noqa: E402
from .utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available  # noqa: E402


# Integration functions:
def is_wandb_available():
    # any value of WANDB_DISABLED disables wandb
    if os.getenv("WANDB_DISABLED", "").upper() in ENV_VARS_TRUE_VALUES:
        logger.warning(
            "Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the "
            "--report_to flag to control the integrations used for logging results (for instance --report_to none)."
        )
        return False
    return importlib.util.find_spec("wandb") is not None


def is_comet_available():
    return _has_comet


def is_tensorboard_available():
    return importlib.util.find_spec("tensorboard") is not None or importlib.util.find_spec("tensorboardX") is not None


def is_optuna_available():
    return importlib.util.find_spec("optuna") is not None


def is_ray_available():
    return importlib.util.find_spec("ray") is not None


def is_ray_tune_available():
    if not is_ray_available():
        return False
    return importlib.util.find_spec("ray.tune") is not None


def is_sigopt_available():
    return importlib.util.find_spec("sigopt") is not None


def is_azureml_available():
    if importlib.util.find_spec("azureml") is None:
        return False
    if importlib.util.find_spec("azureml.core") is None:
        return False
    return importlib.util.find_spec("azureml.core.run") is not None


def is_mlflow_available():
    if os.getenv("DISABLE_MLFLOW_INTEGRATION", "FALSE").upper() == "TRUE":
        return False
    return importlib.util.find_spec("mlflow") is not None


def is_fairscale_available():
    return importlib.util.find_spec("fairscale") is not None


def is_neptune_available():
    return importlib.util.find_spec("neptune") is not None


def is_codecarbon_available():
    return importlib.util.find_spec("codecarbon") is not None


def hp_params(trial):
    if is_optuna_available():
        import optuna

        if isinstance(trial, optuna.Trial):
            return trial.params
    if is_ray_tune_available():
        if isinstance(trial, dict):
            return trial

    if is_sigopt_available():
        if isinstance(trial, dict):
            return trial

    if is_wandb_available():
        if isinstance(trial, dict):
            return trial

    raise RuntimeError(f"Unknown type for trial {trial.__class__}")


def default_hp_search_backend():
    if is_optuna_available():
        return "optuna"
    elif is_ray_tune_available():
        return "ray"
    elif is_sigopt_available():
        return "sigopt"


def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    import optuna

    def _objective(trial, checkpoint_dir=None):
        checkpoint = None
        if checkpoint_dir:
            for subdir in os.listdir(checkpoint_dir):
                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
                    checkpoint = os.path.join(checkpoint_dir, subdir)
        trainer.objective = None
        trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
        # If there hasn't been any evaluation during the training loop.
        if getattr(trainer, "objective", None) is None:
            metrics = trainer.evaluate()
            trainer.objective = trainer.compute_objective(metrics)
        return trainer.objective

    timeout = kwargs.pop("timeout", None)
    n_jobs = kwargs.pop("n_jobs", 1)
    study = optuna.create_study(direction=direction, **kwargs)
    study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs)
    best_trial = study.best_trial
    return BestRun(str(best_trial.number), best_trial.value, best_trial.params)
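# Example (sketch, not executed): this function is normally reached through
# `Trainer.hyperparameter_search`; the `my_hp_space` helper below is hypothetical
# and the search range is illustrative.
#
#     def my_hp_space(trial):
#         return {"learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)}
#
#     best_run = trainer.hyperparameter_search(
#         hp_space=my_hp_space, backend="optuna", n_trials=20, direction="minimize"
#     )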


def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    import ray

    def _objective(trial, local_trainer, checkpoint_dir=None):
        try:
            from transformers.utils.notebook import NotebookProgressCallback

            if local_trainer.pop_callback(NotebookProgressCallback):
                local_trainer.add_callback(ProgressCallback)
        except ModuleNotFoundError:
            pass

        checkpoint = None
        if checkpoint_dir:
            for subdir in os.listdir(checkpoint_dir):
                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
                    checkpoint = os.path.join(checkpoint_dir, subdir)
        local_trainer.objective = None
        local_trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
        # If there hasn't been any evaluation during the training loop.
        if getattr(local_trainer, "objective", None) is None:
            metrics = local_trainer.evaluate()
            local_trainer.objective = local_trainer.compute_objective(metrics)
            local_trainer._tune_save_checkpoint()
            ray.tune.report(objective=local_trainer.objective, **metrics, done=True)

    if not trainer._memory_tracker.skip_memory_metrics:
        from .trainer_utils import TrainerMemoryTracker

        logger.warning(
            "Memory tracking for your Trainer is currently "
            "enabled. Automatically disabling the memory tracker "
            "since the memory tracker is not serializable."
        )
        trainer._memory_tracker = TrainerMemoryTracker(skip_memory_metrics=True)

    # The model and TensorBoard writer do not pickle so we have to remove them (if they exist)
    # while doing the ray hp search.
    _tb_writer = trainer.pop_callback(TensorBoardCallback)
    trainer.model = None

    # Setup default `resources_per_trial`.
    if "resources_per_trial" not in kwargs:
        # Default to 1 CPU and 1 GPU (if applicable) per trial.
        kwargs["resources_per_trial"] = {"cpu": 1}
        if trainer.args.n_gpu > 0:
            kwargs["resources_per_trial"]["gpu"] = 1
        resource_msg = "1 CPU" + (" and 1 GPU" if trainer.args.n_gpu > 0 else "")
        logger.info(
            "No `resources_per_trial` arg was passed into "
            "`hyperparameter_search`. Setting it to a default value "
            f"of {resource_msg} for each trial."
        )
    # Make sure each trainer only uses GPUs that were allocated per trial.
    gpus_per_trial = kwargs["resources_per_trial"].get("gpu", 0)
    trainer.args._n_gpu = gpus_per_trial

    # Setup default `progress_reporter`.
    if "progress_reporter" not in kwargs:
        from ray.tune import CLIReporter

        kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
    if "keep_checkpoints_num" in kwargs and kwargs["keep_checkpoints_num"] > 0:
        # `keep_checkpoints_num=0` would disable checkpointing
        trainer.use_tune_checkpoints = True
        if kwargs["keep_checkpoints_num"] > 1:
            logger.warning(
                f"Currently keeping {kwargs['keep_checkpoints_num']} checkpoints for each trial. "
                "Checkpoints are usually huge, "
                "consider setting `keep_checkpoints_num=1`."
            )
    if "scheduler" in kwargs:
        from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB, MedianStoppingRule, PopulationBasedTraining

        # Check if checkpointing is enabled for PopulationBasedTraining
        if isinstance(kwargs["scheduler"], PopulationBasedTraining):
            if not trainer.use_tune_checkpoints:
                logger.warning(
                    "You are using PopulationBasedTraining but you haven't enabled checkpointing. "
                    "This means your trials will train from scratch every time they are exploiting "
                    "new configurations. Consider enabling checkpointing by passing "
                    "`keep_checkpoints_num=1` as an additional argument to `Trainer.hyperparameter_search`."
                )

        # Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
        if isinstance(
            kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
        ) and (not trainer.args.do_eval or trainer.args.evaluation_strategy == IntervalStrategy.NO):
            raise RuntimeError(
                "You are using {cls} as a scheduler but you haven't enabled evaluation during training. "
                "This means your trials will not report intermediate results to Ray Tune, and "
                "can thus not be stopped early or used to exploit other trials parameters. "
                "If this is what you want, do not use {cls}. If you would like to use {cls}, "
                "make sure you pass `do_eval=True` and `evaluation_strategy='steps'` in the "
                "Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__)
            )

    trainable = ray.tune.with_parameters(_objective, local_trainer=trainer)

    @functools.wraps(trainable)
    def dynamic_modules_import_trainable(*args, **kwargs):
        """
        Wrapper around `tune.with_parameters` to ensure datasets_modules are loaded on each Actor.

        Without this, an ImportError will be thrown. See https://github.com/huggingface/transformers/issues/11565.

        Assumes that `_objective`, defined above, is a function.
        """
        if is_datasets_available():
            import datasets.load

            dynamic_modules_path = os.path.join(datasets.load.init_dynamic_modules(), "__init__.py")
            # load dynamic_modules from path
            spec = importlib.util.spec_from_file_location("datasets_modules", dynamic_modules_path)
            datasets_modules = importlib.util.module_from_spec(spec)
            sys.modules[spec.name] = datasets_modules
            spec.loader.exec_module(datasets_modules)
        return trainable(*args, **kwargs)

    # special attr set by tune.with_parameters
    if hasattr(trainable, "__mixins__"):
        dynamic_modules_import_trainable.__mixins__ = trainable.__mixins__

    analysis = ray.tune.run(
        dynamic_modules_import_trainable,
        config=trainer.hp_space(None),
        num_samples=n_trials,
        **kwargs,
    )
    best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3], scope=trainer.args.ray_scope)
    best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
    if _tb_writer is not None:
        trainer.add_callback(_tb_writer)
    return best_run
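# Example (sketch, not executed): invoking this backend through `Trainer.hyperparameter_search`;
# `my_ray_hp_space` is hypothetical and the extra keyword arguments are forwarded to `ray.tune.run`
# as shown above.
#
#     from ray import tune
#
#     def my_ray_hp_space(trial):
#         return {"learning_rate": tune.loguniform(1e-6, 1e-4)}
#
#     best_run = trainer.hyperparameter_search(
#         hp_space=my_ray_hp_space, backend="ray", n_trials=8,
#         resources_per_trial={"cpu": 1, "gpu": 1},
#     )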


def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:

    from sigopt import Connection

    conn = Connection()
    proxies = kwargs.pop("proxies", None)
    if proxies is not None:
        conn.set_proxies(proxies)

    experiment = conn.experiments().create(
        name="huggingface-tune",
        parameters=trainer.hp_space(None),
        metrics=[dict(name="objective", objective=direction, strategy="optimize")],
        parallel_bandwidth=1,
        observation_budget=n_trials,
        project="huggingface",
    )
    logger.info(f"created experiment: https://app.sigopt.com/experiment/{experiment.id}")

    while experiment.progress.observation_count < experiment.observation_budget:
        suggestion = conn.experiments(experiment.id).suggestions().create()
        trainer.objective = None
        trainer.train(resume_from_checkpoint=None, trial=suggestion)
        # If there hasn't been any evaluation during the training loop.
        if getattr(trainer, "objective", None) is None:
            metrics = trainer.evaluate()
            trainer.objective = trainer.compute_objective(metrics)

        values = [dict(name="objective", value=trainer.objective)]
        obs = conn.experiments(experiment.id).observations().create(suggestion=suggestion.id, values=values)
        logger.info(f"[suggestion_id, observation_id]: [{suggestion.id}, {obs.id}]")
        experiment = conn.experiments(experiment.id).fetch()

    best = list(conn.experiments(experiment.id).best_assignments().fetch().iterate_pages())[0]
    best_run = BestRun(best.id, best.value, best.assignments)

    return best_run
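# Example (sketch, not executed): a SigOpt-style search space returned by `trainer.hp_space`;
# the parameter definition follows SigOpt's experiment schema and the values are illustrative.
#
#     def my_sigopt_hp_space(trial):
#         return [
#             {"name": "learning_rate", "type": "double", "bounds": {"min": 1e-6, "max": 1e-4}},
#         ]
#
#     best_run = trainer.hyperparameter_search(
#         hp_space=my_sigopt_hp_space, backend="sigopt", n_trials=20, direction="maximize"
#     )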


def run_hp_search_wandb(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    from .integrations import is_wandb_available

    if not is_wandb_available():
        raise ImportError("This function needs wandb installed: `pip install wandb`")
    import wandb

    # add WandbCallback if not already added in trainer callbacks
    reporting_to_wandb = False
    for callback in trainer.callback_handler.callbacks:
        if isinstance(callback, WandbCallback):
            reporting_to_wandb = True
            break
    if not reporting_to_wandb:
        trainer.add_callback(WandbCallback())
    trainer.args.report_to = "wandb"
    best_trial = {"run_id": None, "objective": None, "hyperparameters": None}
    sweep_id = kwargs.pop("sweep_id", None)
    project = kwargs.pop("project", None)
    name = kwargs.pop("name", None)
    entity = kwargs.pop("entity", None)
    metric = kwargs.pop("metric", "eval/loss")

    sweep_config = trainer.hp_space(None)
    sweep_config["metric"]["goal"] = direction
    sweep_config["metric"]["name"] = metric
    if name:
        sweep_config["name"] = name

    def _objective():

        run = wandb.run if wandb.run else wandb.init()
        trainer.state.trial_name = run.name
        run.config.update({"assignments": {}, "metric": metric})
        config = wandb.config

        trainer.objective = None

        trainer.train(resume_from_checkpoint=None, trial=vars(config)["_items"])
        # If there hasn't been any evaluation during the training loop.
        if getattr(trainer, "objective", None) is None:
            metrics = trainer.evaluate()
            trainer.objective = trainer.compute_objective(metrics)
            format_metrics = rewrite_logs(metrics)
            if metric not in format_metrics:
                logger.warning(
                    f"Provided metric {metric} not found. This might result in unexpected sweeps charts. The available"
                    f" metrics are {format_metrics.keys()}"
                )
        best_score = False
        if best_trial["run_id"] is not None:
            if direction == "minimize":
                best_score = trainer.objective < best_trial["objective"]
            elif direction == "maximize":
                best_score = trainer.objective > best_trial["objective"]

        if best_score or best_trial["run_id"] is None:
            best_trial["run_id"] = run.id
            best_trial["objective"] = trainer.objective
            best_trial["hyperparameters"] = dict(config)

        return trainer.objective

    sweep_id = wandb.sweep(sweep_config, project=project, entity=entity) if not sweep_id else sweep_id
    logger.info(f"wandb sweep id - {sweep_id}")
    wandb.agent(sweep_id, function=_objective, count=n_trials)

    return BestRun(best_trial["run_id"], best_trial["objective"], best_trial["hyperparameters"])
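# Example (sketch, not executed): a wandb sweep configuration returned by `trainer.hp_space`;
# the keys follow wandb's sweep-configuration format and the values are illustrative.
#
#     def my_wandb_hp_space(trial):
#         return {
#             "method": "random",
#             "metric": {"name": "eval/loss", "goal": "minimize"},
#             "parameters": {"learning_rate": {"min": 1e-6, "max": 1e-4}},
#         }
#
#     best_run = trainer.hyperparameter_search(
#         hp_space=my_wandb_hp_space, backend="wandb", n_trials=20, direction="minimize"
#     )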


def get_available_reporting_integrations():
    integrations = []
    if is_azureml_available():
        integrations.append("azure_ml")
    if is_comet_available():
        integrations.append("comet_ml")
    if is_mlflow_available():
        integrations.append("mlflow")
    if is_tensorboard_available():
        integrations.append("tensorboard")
    if is_wandb_available():
        integrations.append("wandb")
    if is_codecarbon_available():
        integrations.append("codecarbon")
    return integrations


def rewrite_logs(d):
    new_d = {}
    eval_prefix = "eval_"
    eval_prefix_len = len(eval_prefix)
    test_prefix = "test_"
    test_prefix_len = len(test_prefix)
    for k, v in d.items():
        if k.startswith(eval_prefix):
            new_d["eval/" + k[eval_prefix_len:]] = v
        elif k.startswith(test_prefix):
            new_d["test/" + k[test_prefix_len:]] = v
        else:
            new_d["train/" + k] = v
    return new_d
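# Example (sketch): the prefix rewriting performed above, with illustrative values.
#
#     rewrite_logs({"eval_loss": 0.42, "test_accuracy": 0.9, "loss": 1.3})
#     # -> {"eval/loss": 0.42, "test/accuracy": 0.9, "train/loss": 1.3}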


class TensorBoardCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [TensorBoard](https://www.tensorflow.org/tensorboard).

    Args:
        tb_writer (`SummaryWriter`, *optional*):
            The writer to use. Will instantiate one if not set.
    """

    def __init__(self, tb_writer=None):
        has_tensorboard = is_tensorboard_available()
        if not has_tensorboard:
            raise RuntimeError(
                "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or"
                " install tensorboardX."
            )
        if has_tensorboard:
            try:
                from torch.utils.tensorboard import SummaryWriter  # noqa: F401

                self._SummaryWriter = SummaryWriter
            except ImportError:
                try:
                    from tensorboardX import SummaryWriter

                    self._SummaryWriter = SummaryWriter
                except ImportError:
                    self._SummaryWriter = None
        else:
            self._SummaryWriter = None
        self.tb_writer = tb_writer

    def _init_summary_writer(self, args, log_dir=None):
        log_dir = log_dir or args.logging_dir
        if self._SummaryWriter is not None:
            self.tb_writer = self._SummaryWriter(log_dir=log_dir)

    def on_train_begin(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return

        log_dir = None

        if state.is_hyper_param_search:
            trial_name = state.trial_name
            if trial_name is not None:
                log_dir = os.path.join(args.logging_dir, trial_name)

        if self.tb_writer is None:
            self._init_summary_writer(args, log_dir)

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", args.to_json_string())
            if "model" in kwargs:
                model = kwargs["model"]
                if hasattr(model, "config") and model.config is not None:
                    model_config_json = model.config.to_json_string()
                    self.tb_writer.add_text("model_config", model_config_json)
            # Version of TensorBoard coming from tensorboardX does not have this method.
            if hasattr(self.tb_writer, "add_hparams"):
                self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not state.is_world_process_zero:
            return

        if self.tb_writer is None:
            self._init_summary_writer(args)

        if self.tb_writer is not None:
            logs = rewrite_logs(logs)
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, state.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute."
                    )
            self.tb_writer.flush()

    def on_train_end(self, args, state, control, **kwargs):
        if self.tb_writer:
            self.tb_writer.close()
            self.tb_writer = None


class WandbCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [Weights & Biases](https://www.wandb.com/).
    """

    def __init__(self):
        has_wandb = is_wandb_available()
        if not has_wandb:
            raise RuntimeError("WandbCallback requires wandb to be installed. Run `pip install wandb`.")
        if has_wandb:
            import wandb

            self._wandb = wandb
        self._initialized = False
        # log outputs
        self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})

    def setup(self, args, state, model, **kwargs):
        """
        Setup the optional Weights & Biases (*wandb*) integration.

        One can subclass and override this method to customize the setup if needed. Find more information
        [here](https://docs.wandb.ai/integrations/huggingface). You can also override the following environment
        variables:

        Environment:
            WANDB_LOG_MODEL (`bool`, *optional*, defaults to `False`):
                Whether or not to log model as artifact at the end of training. Use along with
                *TrainingArguments.load_best_model_at_end* to upload best model.
            WANDB_WATCH (`str`, *optional* defaults to `"gradients"`):
                Can be `"gradients"`, `"all"` or `"false"`. Set to `"false"` to disable gradient logging or `"all"` to
                log gradients and parameters.
            WANDB_PROJECT (`str`, *optional*, defaults to `"huggingface"`):
                Set this to a custom string to store results in a different project.
            WANDB_DISABLED (`bool`, *optional*, defaults to `False`):
                Whether or not to disable wandb entirely. Set *WANDB_DISABLED=true* to disable.
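
        Example (sketch): configuring the integration through environment variables before training;
        the values below are illustrative.

            os.environ["WANDB_PROJECT"] = "my-project"
            os.environ["WANDB_LOG_MODEL"] = "true"
            os.environ["WANDB_WATCH"] = "all"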
        """
        if self._wandb is None:
            return
        self._initialized = True
        if state.is_world_process_zero:
            logger.info(
                'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
            )
            combined_dict = {**args.to_sanitized_dict()}

            if hasattr(model, "config") and model.config is not None:
                model_config = model.config.to_dict()
                combined_dict = {**model_config, **combined_dict}
            trial_name = state.trial_name
            init_args = {}
            if trial_name is not None:
                run_name = trial_name
                init_args["group"] = args.run_name
            else:
                run_name = args.run_name

            if self._wandb.run is None:
                self._wandb.init(
                    project=os.getenv("WANDB_PROJECT", "huggingface"),
                    name=run_name,
                    **init_args,
                )
            # add config parameters (run may have been created manually)
            self._wandb.config.update(combined_dict, allow_val_change=True)

            # define default x-axis (for latest wandb versions)
            if getattr(self._wandb, "define_metric", None):
                self._wandb.define_metric("train/global_step")
                self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True)

            # keep track of model topology and gradients, unsupported on TPU
            if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
                self._wandb.watch(
                    model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps)
                )

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if self._wandb is None:
            return
        hp_search = state.is_hyper_param_search
        if hp_search:
            self._wandb.finish()
            self._initialized = False
            args.run_name = None
        if not self._initialized:
            self.setup(args, state, model, **kwargs)

    def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
        if self._wandb is None:
            return
        if self._log_model and self._initialized and state.is_world_process_zero:
            from .trainer import Trainer

            fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
            with tempfile.TemporaryDirectory() as temp_dir:
                fake_trainer.save_model(temp_dir)
                metadata = (
                    {
                        k: v
                        for k, v in dict(self._wandb.summary).items()
                        if isinstance(v, numbers.Number) and not k.startswith("_")
                    }
                    if not args.load_best_model_at_end
                    else {
                        f"eval/{args.metric_for_best_model}": state.best_metric,
                        "train/total_flos": state.total_flos,
                    }
                )
                artifact = self._wandb.Artifact(name=f"model-{self._wandb.run.id}", type="model", metadata=metadata)
                for f in Path(temp_dir).glob("*"):
                    if f.is_file():
                        with artifact.new_file(f.name, mode="wb") as fa:
                            fa.write(f.read_bytes())
                self._wandb.run.log_artifact(artifact)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        if self._wandb is None:
            return
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            logs = rewrite_logs(logs)
            self._wandb.log({**logs, "train/global_step": state.global_step})


class CometCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [Comet ML](https://www.comet.ml/site/).
    """

    def __init__(self):
        if not _has_comet:
            raise RuntimeError("CometCallback requires comet-ml to be installed. Run `pip install comet-ml`.")
        self._initialized = False
        self._log_assets = False

    def setup(self, args, state, model):
        """
        Setup the optional Comet.ml integration.

        Environment:
            COMET_MODE (`str`, *optional*):
                Whether to create an online, offline experiment or disable Comet logging. Can be "OFFLINE", "ONLINE",
                or "DISABLED". Defaults to "ONLINE".
            COMET_PROJECT_NAME (`str`, *optional*):
                Comet project name for experiments
            COMET_OFFLINE_DIRECTORY (`str`, *optional*):
                Folder to use for saving offline experiments when `COMET_MODE` is "OFFLINE"
            COMET_LOG_ASSETS (`str`, *optional*):
                Whether or not to log training assets (tf event logs, checkpoints, etc), to Comet. Can be "TRUE", or
                "FALSE". Defaults to "TRUE".

        For a number of configurable items in the environment, see
        [here](https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables).
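
        Example (sketch): running Comet fully offline; the directory and project name are illustrative.

            os.environ["COMET_MODE"] = "OFFLINE"
            os.environ["COMET_OFFLINE_DIRECTORY"] = "./comet-logs"
            os.environ["COMET_PROJECT_NAME"] = "my-project"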
        """
        self._initialized = True
        log_assets = os.getenv("COMET_LOG_ASSETS", "FALSE").upper()
        if log_assets in {"TRUE", "1"}:
            self._log_assets = True
        if state.is_world_process_zero:
            comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
            experiment = None
            experiment_kwargs = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
            if comet_mode == "ONLINE":
                experiment = comet_ml.Experiment(**experiment_kwargs)
                experiment.log_other("Created from", "transformers")
                logger.info("Automatic Comet.ml online logging enabled")
            elif comet_mode == "OFFLINE":
                experiment_kwargs["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
                experiment = comet_ml.OfflineExperiment(**experiment_kwargs)
                experiment.log_other("Created from", "transformers")
                logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
            if experiment is not None:
                experiment._set_model_graph(model, framework="transformers")
                experiment._log_parameters(args, prefix="args/", framework="transformers")
                if hasattr(model, "config"):
                    experiment._log_parameters(model.config, prefix="config/", framework="transformers")

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            experiment = comet_ml.config.get_global_experiment()
            if experiment is not None:
                experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers")

    def on_train_end(self, args, state, control, **kwargs):
        if self._initialized and state.is_world_process_zero:
            experiment = comet_ml.config.get_global_experiment()
            if experiment is not None:
                if self._log_assets is True:
                    logger.info("Logging checkpoints. This may take time.")
                    experiment.log_asset_folder(
                        args.output_dir, recursive=True, log_file_name=True, step=state.global_step
                    )
                experiment.end()


class AzureMLCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [AzureML](https://pypi.org/project/azureml-sdk/).
    """

    def __init__(self, azureml_run=None):
        if not is_azureml_available():
            raise RuntimeError("AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`.")
        self.azureml_run = azureml_run

    def on_init_end(self, args, state, control, **kwargs):
        from azureml.core.run import Run

        if self.azureml_run is None and state.is_world_process_zero:
            self.azureml_run = Run.get_context()

    def on_log(self, args, state, control, logs=None, **kwargs):
        if self.azureml_run and state.is_world_process_zero:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.azureml_run.log(k, v, description=k)


class MLflowCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [MLflow](https://www.mlflow.org/). Can be disabled by setting
    environment variable `DISABLE_MLFLOW_INTEGRATION = TRUE`.
    """

    def __init__(self):
        if not is_mlflow_available():
            raise RuntimeError("MLflowCallback requires mlflow to be installed. Run `pip install mlflow`.")
        import mlflow

        self._MAX_PARAM_VAL_LENGTH = mlflow.utils.validation.MAX_PARAM_VAL_LENGTH
        self._MAX_PARAMS_TAGS_PER_BATCH = mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH

        self._initialized = False
        self._auto_end_run = False
        self._log_artifacts = False
        self._ml_flow = mlflow

    def setup(self, args, state, model):
        """
        Setup the optional MLflow integration.

        Environment:
            HF_MLFLOW_LOG_ARTIFACTS (`str`, *optional*):
                Whether to use MLflow .log_artifact() facility to log artifacts. This only makes sense if logging to a
                remote server, e.g. s3 or GCS. If set to `True` or *1*, will copy whatever is in
                [`TrainingArguments`]'s `output_dir` to the local or remote artifact storage. Using it without a remote
                storage will just copy the files to your artifact location.
            MLFLOW_EXPERIMENT_NAME (`str`, *optional*):
                Name of the MLflow experiment under which to launch the run. Defaults to "None" which will
                point to the "Default" experiment in MLflow. Otherwise, it is a case sensitive name of the experiment
                to be activated. If an experiment with this name does not exist, a new experiment with this name is
                created.
            MLFLOW_TAGS (`str`, *optional*):
                A string dump of a dictionary of key/value pair to be added to the MLflow run as tags. Example:
                os.environ['MLFLOW_TAGS']='{"release.candidate": "RC1", "release.version": "2.2.0"}'
            MLFLOW_NESTED_RUN (`str`, *optional*):
                Whether to use MLflow nested runs. If set to `True` or *1*, will create a nested run inside the current
                run.
            MLFLOW_RUN_ID (`str`, *optional*):
                Allows reattaching to an existing run, which can be useful when resuming training from a checkpoint.
                When the MLFLOW_RUN_ID environment variable is set, start_run attempts to resume a run with the specified
                run ID and other parameters are ignored.
            MLFLOW_FLATTEN_PARAMS (`str`, *optional*):
                Whether to flatten the parameters dictionary before logging. Defaults to `False`.
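
        Example (sketch): a typical environment configuration; the experiment name is illustrative and
        the tags string mirrors the MLFLOW_TAGS example above.

            os.environ["MLFLOW_EXPERIMENT_NAME"] = "my-experiment"
            os.environ["HF_MLFLOW_LOG_ARTIFACTS"] = "1"
            os.environ["MLFLOW_TAGS"] = '{"release.candidate": "RC1", "release.version": "2.2.0"}'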
        """
        self._log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
        self._nested_run = os.getenv("MLFLOW_NESTED_RUN", "FALSE").upper() in ENV_VARS_TRUE_VALUES
        self._experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None)
        self._flatten_params = os.getenv("MLFLOW_FLATTEN_PARAMS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
        self._run_id = os.getenv("MLFLOW_RUN_ID", None)
        logger.debug(
            f"MLflow experiment_name={self._experiment_name}, run_name={args.run_name}, nested={self._nested_run},"
            f" tags={self._nested_run}"
        )
        if state.is_world_process_zero:
            if self._ml_flow.active_run() is None or self._nested_run or self._run_id:
                if self._experiment_name:
                    # Use of set_experiment() ensures that the experiment is created if it does not exist
                    self._ml_flow.set_experiment(self._experiment_name)
                self._ml_flow.start_run(run_name=args.run_name, nested=self._nested_run)
                logger.debug(f"MLflow run started with run_id={self._ml_flow.active_run().info.run_id}")
                self._auto_end_run = True
            combined_dict = args.to_dict()
            if hasattr(model, "config") and model.config is not None:
                model_config = model.config.to_dict()
                combined_dict = {**model_config, **combined_dict}
            combined_dict = flatten_dict(combined_dict) if self._flatten_params else combined_dict
            # remove params that are too long for MLflow
            for name, value in list(combined_dict.items()):
                # internally, all values are converted to str in MLflow
                if len(str(value)) > self._MAX_PARAM_VAL_LENGTH:
                    logger.warning(
                        f'Trainer is attempting to log a value of "{value}" for key "{name}" as a parameter. MLflow\'s'
                        " log_param() only accepts values no longer than 250 characters so we dropped this attribute."
                        " You can use `MLFLOW_FLATTEN_PARAMS` environment variable to flatten the parameters and"
                        " avoid this message."
                    )
                    del combined_dict[name]
            # MLflow cannot log more than 100 values in one go, so we have to split it
            combined_dict_items = list(combined_dict.items())
            for i in range(0, len(combined_dict_items), self._MAX_PARAMS_TAGS_PER_BATCH):
                self._ml_flow.log_params(dict(combined_dict_items[i : i + self._MAX_PARAMS_TAGS_PER_BATCH]))
            mlflow_tags = os.getenv("MLFLOW_TAGS", None)
            if mlflow_tags:
                mlflow_tags = json.loads(mlflow_tags)
                self._ml_flow.set_tags(mlflow_tags)
        self._initialized = True

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, logs, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            metrics = {}
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    metrics[k] = v
                else:
                    logger.warning(
                        f'Trainer is attempting to log a value of "{v}" of type {type(v)} for key "{k}" as a metric. '
                        "MLflow's log_metric() only accepts float and int types so we dropped this attribute."
                    )
            self._ml_flow.log_metrics(metrics=metrics, step=state.global_step)

    def on_train_end(self, args, state, control, **kwargs):
        if self._initialized and state.is_world_process_zero:
            if self._log_artifacts:
                logger.info("Logging artifacts. This may take time.")
                self._ml_flow.log_artifacts(args.output_dir)
            if self._auto_end_run and self._ml_flow.active_run():
                self._ml_flow.end_run()

    def __del__(self):
        # if the previous run is not terminated correctly, the fluent API will
        # not let you start a new run before the previous one is killed
        if (
            self._auto_end_run
            and callable(getattr(self._ml_flow, "active_run", None))
            and self._ml_flow.active_run() is not None
        ):
            self._ml_flow.end_run()


class NeptuneCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [Neptune](https://neptune.ai).
    """

    def __init__(self):
        if not is_neptune_available():
            raise ValueError(
                "NeptuneCallback requires neptune-client to be installed. Run `pip install neptune-client`."
            )
        import neptune.new as neptune

        self._neptune = neptune
        self._initialized = False
        self._log_artifacts = False

    def setup(self, args, state, model):
        """
        Setup the Neptune integration.

        Environment:
            NEPTUNE_PROJECT (`str`, *required*):
                The project ID for neptune.ai account. Should be in format *workspace_name/project_name*
            NEPTUNE_API_TOKEN (`str`, *required*):
                API-token for neptune.ai account
            NEPTUNE_CONNECTION_MODE (`str`, *optional*):
                Neptune connection mode. *async* by default
            NEPTUNE_RUN_NAME (`str`, *optional*):
                The name of the run displayed on the Neptune dashboard
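
        Example (sketch): the minimal environment expected by the `init` call below; the project
        path is illustrative and a real API token is required.

            os.environ["NEPTUNE_PROJECT"] = "my-workspace/my-project"
            os.environ["NEPTUNE_API_TOKEN"] = "<your-token>"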
        """
        if state.is_world_process_zero:
            self._neptune_run = self._neptune.init(
                project=os.getenv("NEPTUNE_PROJECT"),
                api_token=os.getenv("NEPTUNE_API_TOKEN"),
                mode=os.getenv("NEPTUNE_CONNECTION_MODE", "async"),
                name=os.getenv("NEPTUNE_RUN_NAME", None),
                run=os.getenv("NEPTUNE_RUN_ID", None),
            )
            combined_dict = args.to_dict()
            if hasattr(model, "config") and model.config is not None:
                model_config = model.config.to_dict()
                combined_dict = {**model_config, **combined_dict}
            self._neptune_run["parameters"] = combined_dict
        self._initialized = True

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, logs, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            for k, v in logs.items():
                self._neptune_run[k].log(v, step=state.global_step)

    def __del__(self):
        """
        Environment:
            NEPTUNE_STOP_TIMEOUT (`int`, *optional*):
                Number of seconds to wait for all Neptune.ai tracking calls to finish, before stopping the tracked
                run. If not set it will wait for all tracking calls to finish.
        """
        try:
            stop_timeout = os.getenv("NEPTUNE_STOP_TIMEOUT")
            stop_timeout = int(stop_timeout) if stop_timeout else None
            self._neptune_run.stop(seconds=stop_timeout)
        except AttributeError:
            pass


class CodeCarbonCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that tracks the CO2 emission of training.
    """

    def __init__(self):
        if not is_codecarbon_available():
            raise RuntimeError(
                "CodeCarbonCallback requires `codecarbon` to be installed. Run `pip install codecarbon`."
            )
        import codecarbon

        self._codecarbon = codecarbon
        self.tracker = None

    def on_init_end(self, args, state, control, **kwargs):
        if self.tracker is None and state.is_local_process_zero:
            # CodeCarbon will automatically handle environment variables for configuration
            self.tracker = self._codecarbon.EmissionsTracker(output_dir=args.output_dir)

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if self.tracker and state.is_local_process_zero:
            self.tracker.start()

    def on_train_end(self, args, state, control, **kwargs):
        if self.tracker and state.is_local_process_zero:
            self.tracker.stop()


INTEGRATION_TO_CALLBACK = {
    "azure_ml": AzureMLCallback,
    "comet_ml": CometCallback,
    "mlflow": MLflowCallback,
    "neptune": NeptuneCallback,
    "tensorboard": TensorBoardCallback,
    "wandb": WandbCallback,
    "codecarbon": CodeCarbonCallback,
}


def get_reporting_integration_callbacks(report_to):
    for integration in report_to:
        if integration not in INTEGRATION_TO_CALLBACK:
            raise ValueError(
                f"{integration} is not supported, only {', '.join(INTEGRATION_TO_CALLBACK.keys())} are supported."
            )
    return [INTEGRATION_TO_CALLBACK[integration] for integration in report_to]
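# Example (sketch, not executed): resolving callback classes for a `--report_to` setting;
# the integration names are illustrative.
#
#     callbacks = get_reporting_integration_callbacks(["tensorboard", "wandb"])
#     # -> [TensorBoardCallback, WandbCallback]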