# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Integrations with other Python libraries.
"""
import functools
import importlib.util
import json
import numbers
import os
import sys
import tempfile
from pathlib import Path

from .utils import flatten_dict, is_datasets_available, logging


logger = logging.get_logger(__name__)


# comet_ml needs to be imported before any ML frameworks
_has_comet = importlib.util.find_spec("comet_ml") is not None and os.getenv("COMET_MODE", "").upper() != "DISABLED"
if _has_comet:
    try:
        import comet_ml  # noqa: F401

        if hasattr(comet_ml, "config") and comet_ml.config.get_config("comet.api_key"):
            _has_comet = True
        else:
            if os.getenv("COMET_MODE", "").upper() != "DISABLED":
                logger.warning("comet_ml is installed but `COMET_API_KEY` is not set.")
            _has_comet = False
    except (ImportError, ValueError):
        _has_comet = False

from .trainer_callback import ProgressCallback, TrainerCallback  # noqa: E402
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy  # noqa: E402
from .utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available  # noqa: E402


# Integration functions:
def is_wandb_available():
    # any value of WANDB_DISABLED disables wandb
    if os.getenv("WANDB_DISABLED", "").upper() in ENV_VARS_TRUE_VALUES:
        logger.warning(
            "Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the "
            "--report_to flag to control the integrations used for logging results (for instance --report_to none)."
        )
        return False
    return importlib.util.find_spec("wandb") is not None


def is_comet_available():
    return _has_comet


def is_tensorboard_available():
    return importlib.util.find_spec("tensorboard") is not None or importlib.util.find_spec("tensorboardX") is not None


def is_optuna_available():
    return importlib.util.find_spec("optuna") is not None


def is_ray_available():
    return importlib.util.find_spec("ray") is not None


def is_ray_tune_available():
    if not is_ray_available():
        return False
    return importlib.util.find_spec("ray.tune") is not None


def is_sigopt_available():
    return importlib.util.find_spec("sigopt") is not None


def is_azureml_available():
    if importlib.util.find_spec("azureml") is None:
        return False
    if importlib.util.find_spec("azureml.core") is None:
        return False
    return importlib.util.find_spec("azureml.core.run") is not None


def is_mlflow_available():
    if os.getenv("DISABLE_MLFLOW_INTEGRATION", "FALSE").upper() == "TRUE":
        return False
    return importlib.util.find_spec("mlflow") is not None


def is_fairscale_available():
    return importlib.util.find_spec("fairscale") is not None


def is_neptune_available():
    return importlib.util.find_spec("neptune") is not None


def is_codecarbon_available():
    return importlib.util.find_spec("codecarbon") is not None


def hp_params(trial):
    if is_optuna_available():
        import optuna

        if isinstance(trial, optuna.Trial):
            return trial.params
    if is_ray_tune_available():
        if isinstance(trial, dict):
            return trial

    if is_sigopt_available():
        if isinstance(trial, dict):
            return trial

    if is_wandb_available():
        if isinstance(trial, dict):
            return trial

    raise RuntimeError(f"Unknown type for trial {trial.__class__}")


def default_hp_search_backend():
    if is_optuna_available():
        return "optuna"
    elif is_ray_tune_available():
        return "ray"
    elif is_sigopt_available():
        return "sigopt"


def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    import optuna

    def _objective(trial, checkpoint_dir=None):
        checkpoint = None
        if checkpoint_dir:
            for subdir in os.listdir(checkpoint_dir):
                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
                    checkpoint = os.path.join(checkpoint_dir, subdir)
        trainer.objective = None
        trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
        # If there hasn't been any evaluation during the training loop.
        if getattr(trainer, "objective", None) is None:
            metrics = trainer.evaluate()
            trainer.objective = trainer.compute_objective(metrics)
        return trainer.objective

    timeout = kwargs.pop("timeout", None)
    n_jobs = kwargs.pop("n_jobs", 1)
    study = optuna.create_study(direction=direction, **kwargs)
    study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs)
    best_trial = study.best_trial
    return BestRun(str(best_trial.number), best_trial.value, best_trial.params)
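
# Usage sketch (illustrative, not part of this module's API surface): this function is
# normally reached through `Trainer.hyperparameter_search`. Assuming a `trainer` built with
# a `model_init`, something like the following would route here when optuna is the backend:
#
#     def my_hp_space(trial):
#         return {"learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True)}
#
#     best_run = trainer.hyperparameter_search(
#         hp_space=my_hp_space, backend="optuna", n_trials=20, direction="minimize"
#     )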


def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    import ray

    def _objective(trial, local_trainer, checkpoint_dir=None):
        try:
            from transformers.utils.notebook import NotebookProgressCallback

            if local_trainer.pop_callback(NotebookProgressCallback):
                local_trainer.add_callback(ProgressCallback)
        except ModuleNotFoundError:
            pass

        checkpoint = None
        if checkpoint_dir:
            for subdir in os.listdir(checkpoint_dir):
                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
                    checkpoint = os.path.join(checkpoint_dir, subdir)
        local_trainer.objective = None
        local_trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
        # If there hasn't been any evaluation during the training loop.
        if getattr(local_trainer, "objective", None) is None:
            metrics = local_trainer.evaluate()
            local_trainer.objective = local_trainer.compute_objective(metrics)
            local_trainer._tune_save_checkpoint()
            ray.tune.report(objective=local_trainer.objective, **metrics, done=True)

    if not trainer._memory_tracker.skip_memory_metrics:
        from .trainer_utils import TrainerMemoryTracker

        logger.warning(
            "Memory tracking for your Trainer is currently "
            "enabled. Automatically disabling the memory tracker "
            "since the memory tracker is not serializable."
        )
        trainer._memory_tracker = TrainerMemoryTracker(skip_memory_metrics=True)

    # The model and TensorBoard writer do not pickle so we have to remove them (if they exist)
    # while doing the ray hp search.
    _tb_writer = trainer.pop_callback(TensorBoardCallback)
    trainer.model = None

    # Setup default `resources_per_trial`.
    if "resources_per_trial" not in kwargs:
        # Default to 1 CPU and 1 GPU (if applicable) per trial.
        kwargs["resources_per_trial"] = {"cpu": 1}
        if trainer.args.n_gpu > 0:
            kwargs["resources_per_trial"]["gpu"] = 1
        resource_msg = "1 CPU" + (" and 1 GPU" if trainer.args.n_gpu > 0 else "")
        logger.info(
            "No `resources_per_trial` arg was passed into "
            "`hyperparameter_search`. Setting it to a default value "
            f"of {resource_msg} for each trial."
        )
    # Make sure each trainer only uses GPUs that were allocated per trial.
    gpus_per_trial = kwargs["resources_per_trial"].get("gpu", 0)
    trainer.args._n_gpu = gpus_per_trial

    # Setup default `progress_reporter`.
    if "progress_reporter" not in kwargs:
        from ray.tune import CLIReporter

        kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
    if "keep_checkpoints_num" in kwargs and kwargs["keep_checkpoints_num"] > 0:
        # `keep_checkpoints_num=0` would disable checkpointing
        trainer.use_tune_checkpoints = True
        if kwargs["keep_checkpoints_num"] > 1:
            logger.warning(
                f"Currently keeping {kwargs['keep_checkpoints_num']} checkpoints for each trial. "
                "Checkpoints are usually huge, "
                "consider setting `keep_checkpoints_num=1`."
            )
    if "scheduler" in kwargs:
        from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB, MedianStoppingRule, PopulationBasedTraining

        # Check if checkpointing is enabled for PopulationBasedTraining
        if isinstance(kwargs["scheduler"], PopulationBasedTraining):
            if not trainer.use_tune_checkpoints:
                logger.warning(
                    "You are using PopulationBasedTraining but you haven't enabled checkpointing. "
                    "This means your trials will train from scratch every time they are exploiting "
                    "new configurations. Consider enabling checkpointing by passing "
                    "`keep_checkpoints_num=1` as an additional argument to `Trainer.hyperparameter_search`."
                )

        # Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
        if isinstance(
            kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
        ) and (not trainer.args.do_eval or trainer.args.evaluation_strategy == IntervalStrategy.NO):
            raise RuntimeError(
                "You are using {cls} as a scheduler but you haven't enabled evaluation during training. "
                "This means your trials will not report intermediate results to Ray Tune, and "
                "can thus not be stopped early or used to exploit other trials' parameters. "
                "If this is what you want, do not use {cls}. If you would like to use {cls}, "
                "make sure you pass `do_eval=True` and `evaluation_strategy='steps'` in the "
                "Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__)
            )

    trainable = ray.tune.with_parameters(_objective, local_trainer=trainer)

    @functools.wraps(trainable)
    def dynamic_modules_import_trainable(*args, **kwargs):
        """
        Wrapper around `tune.with_parameters` to ensure datasets_modules are loaded on each Actor.

        Without this, an ImportError will be thrown. See https://github.com/huggingface/transformers/issues/11565.

        Assumes that `_objective`, defined above, is a function.
        """
        if is_datasets_available():
            import datasets.load

            dynamic_modules_path = os.path.join(datasets.load.init_dynamic_modules(), "__init__.py")
            # load dynamic_modules from path
            spec = importlib.util.spec_from_file_location("datasets_modules", dynamic_modules_path)
            datasets_modules = importlib.util.module_from_spec(spec)
            sys.modules[spec.name] = datasets_modules
            spec.loader.exec_module(datasets_modules)
        return trainable(*args, **kwargs)

    # special attr set by tune.with_parameters
    if hasattr(trainable, "__mixins__"):
        dynamic_modules_import_trainable.__mixins__ = trainable.__mixins__

    analysis = ray.tune.run(
        dynamic_modules_import_trainable,
        config=trainer.hp_space(None),
        num_samples=n_trials,
        **kwargs,
    )
    best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
    best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
    if _tb_writer is not None:
        trainer.add_callback(_tb_writer)
    return best_run
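
# Usage sketch (illustrative): additional keyword arguments are forwarded to `ray.tune.run`,
# so a search over this backend might be launched as:
#
#     best_run = trainer.hyperparameter_search(
#         backend="ray",
#         n_trials=10,
#         resources_per_trial={"cpu": 2, "gpu": 1},
#         keep_checkpoints_num=1,
#     )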


def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:

    from sigopt import Connection

    conn = Connection()
    proxies = kwargs.pop("proxies", None)
    if proxies is not None:
        conn.set_proxies(proxies)

    experiment = conn.experiments().create(
        name="huggingface-tune",
        parameters=trainer.hp_space(None),
        metrics=[dict(name="objective", objective=direction, strategy="optimize")],
        parallel_bandwidth=1,
        observation_budget=n_trials,
        project="huggingface",
    )
    logger.info(f"created experiment: https://app.sigopt.com/experiment/{experiment.id}")

    while experiment.progress.observation_count < experiment.observation_budget:
        suggestion = conn.experiments(experiment.id).suggestions().create()
        trainer.objective = None
        trainer.train(resume_from_checkpoint=None, trial=suggestion)
        # If there hasn't been any evaluation during the training loop.
        if getattr(trainer, "objective", None) is None:
            metrics = trainer.evaluate()
            trainer.objective = trainer.compute_objective(metrics)

        values = [dict(name="objective", value=trainer.objective)]
        obs = conn.experiments(experiment.id).observations().create(suggestion=suggestion.id, values=values)
        logger.info(f"[suggestion_id, observation_id]: [{suggestion.id}, {obs.id}]")
        experiment = conn.experiments(experiment.id).fetch()

    best = list(conn.experiments(experiment.id).best_assignments().fetch().iterate_pages())[0]
    best_run = BestRun(best.id, best.value, best.assignments)

    return best_run
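
# Usage sketch (illustrative): the SigOpt client typically reads its credentials from the
# environment (e.g. `SIGOPT_API_TOKEN`), and `n_trials` becomes the experiment's observation budget:
#
#     best_run = trainer.hyperparameter_search(backend="sigopt", n_trials=16, direction="maximize")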


def run_hp_search_wandb(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    from .integrations import is_wandb_available

    if not is_wandb_available():
        raise ImportError("This function needs wandb installed: `pip install wandb`")
    import wandb

    # add WandbCallback if not already added in trainer callbacks
    reporting_to_wandb = False
    for callback in trainer.callback_handler.callbacks:
        if isinstance(callback, WandbCallback):
            reporting_to_wandb = True
            break
    if not reporting_to_wandb:
        trainer.add_callback(WandbCallback())
    trainer.args.report_to = "wandb"
    best_trial = {"run_id": None, "objective": None, "hyperparameters": None}
    sweep_id = kwargs.pop("sweep_id", None)
    project = kwargs.pop("project", None)
    name = kwargs.pop("name", None)
    entity = kwargs.pop("entity", None)
    metric = kwargs.pop("metric", "eval/loss")

    sweep_config = trainer.hp_space(None)
    sweep_config["metric"]["goal"] = direction
    sweep_config["metric"]["name"] = metric
    if name:
        sweep_config["name"] = name

    def _objective():

        run = wandb.run if wandb.run else wandb.init()
        trainer.state.trial_name = run.name
        run.config.update({"assignments": {}, "metric": metric})
        config = wandb.config

        trainer.objective = None

        trainer.train(resume_from_checkpoint=None, trial=vars(config)["_items"])
        # If there hasn't been any evaluation during the training loop.
        if getattr(trainer, "objective", None) is None:
            metrics = trainer.evaluate()
            trainer.objective = trainer.compute_objective(metrics)
            format_metrics = rewrite_logs(metrics)
            if metric not in format_metrics:
                logger.warning(
                    f"Provided metric {metric} not found. This might result in unexpected sweeps charts. The available metrics are {format_metrics.keys()}"
                )
        best_score = False
        if best_trial["run_id"] is not None:
            if direction == "minimize":
                best_score = trainer.objective < best_trial["objective"]
            elif direction == "maximize":
                best_score = trainer.objective > best_trial["objective"]

        if best_score or best_trial["run_id"] is None:
            best_trial["run_id"] = run.id
            best_trial["objective"] = trainer.objective
            best_trial["hyperparameters"] = dict(config)

        return trainer.objective

    sweep_id = wandb.sweep(sweep_config, project=project, entity=entity) if not sweep_id else sweep_id
    logger.info(f"wandb sweep id - {sweep_id}")
    wandb.agent(sweep_id, function=_objective, count=n_trials)

    return BestRun(best_trial["run_id"], best_trial["objective"], best_trial["hyperparameters"])
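
# Usage sketch (illustrative): for this backend `trainer.hp_space(None)` is expected to return
# a W&B sweep configuration dict that contains a `metric` entry, for example:
#
#     def my_hp_space(trial):
#         return {
#             "method": "random",
#             "metric": {"name": "eval/loss", "goal": "minimize"},
#             "parameters": {"learning_rate": {"min": 1e-6, "max": 1e-4}},
#         }
#
#     best_run = trainer.hyperparameter_search(
#         hp_space=my_hp_space, backend="wandb", n_trials=10, metric="eval/loss"
#     )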


def get_available_reporting_integrations():
    integrations = []
    if is_azureml_available():
        integrations.append("azure_ml")
    if is_comet_available():
        integrations.append("comet_ml")
    if is_mlflow_available():
        integrations.append("mlflow")
    if is_tensorboard_available():
        integrations.append("tensorboard")
    if is_wandb_available():
        integrations.append("wandb")
    if is_codecarbon_available():
        integrations.append("codecarbon")
    return integrations


def rewrite_logs(d):
    new_d = {}
    eval_prefix = "eval_"
    eval_prefix_len = len(eval_prefix)
    test_prefix = "test_"
    test_prefix_len = len(test_prefix)
    for k, v in d.items():
        if k.startswith(eval_prefix):
            new_d["eval/" + k[eval_prefix_len:]] = v
        elif k.startswith(test_prefix):
            new_d["test/" + k[test_prefix_len:]] = v
        else:
            new_d["train/" + k] = v
    return new_d
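
# For example, rewrite_logs({"eval_loss": 0.3, "test_accuracy": 0.9, "loss": 1.2}) returns
# {"eval/loss": 0.3, "test/accuracy": 0.9, "train/loss": 1.2}.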


class TensorBoardCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [TensorBoard](https://www.tensorflow.org/tensorboard).

    Args:
        tb_writer (`SummaryWriter`, *optional*):
            The writer to use. Will instantiate one if not set.
    """

    def __init__(self, tb_writer=None):
        has_tensorboard = is_tensorboard_available()
        if not has_tensorboard:
            raise RuntimeError(
                "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
            )
        if has_tensorboard:
            try:
                from torch.utils.tensorboard import SummaryWriter  # noqa: F401

                self._SummaryWriter = SummaryWriter
            except ImportError:
                try:
                    from tensorboardX import SummaryWriter

                    self._SummaryWriter = SummaryWriter
                except ImportError:
                    self._SummaryWriter = None
        else:
            self._SummaryWriter = None
        self.tb_writer = tb_writer

    def _init_summary_writer(self, args, log_dir=None):
        log_dir = log_dir or args.logging_dir
        if self._SummaryWriter is not None:
            self.tb_writer = self._SummaryWriter(log_dir=log_dir)

    def on_train_begin(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return

        log_dir = None

        if state.is_hyper_param_search:
            trial_name = state.trial_name
            if trial_name is not None:
                log_dir = os.path.join(args.logging_dir, trial_name)

        if self.tb_writer is None:
            self._init_summary_writer(args, log_dir)

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", args.to_json_string())
            if "model" in kwargs:
                model = kwargs["model"]
                if hasattr(model, "config") and model.config is not None:
                    model_config_json = model.config.to_json_string()
                    self.tb_writer.add_text("model_config", model_config_json)
            # Version of TensorBoard coming from tensorboardX does not have this method.
            if hasattr(self.tb_writer, "add_hparams"):
                self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not state.is_world_process_zero:
            return

        if self.tb_writer is None:
            self._init_summary_writer(args)

        if self.tb_writer is not None:
            logs = rewrite_logs(logs)
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, state.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute."
                    )
            self.tb_writer.flush()

    def on_train_end(self, args, state, control, **kwargs):
        if self.tb_writer:
            self.tb_writer.close()
            self.tb_writer = None


class WandbCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [Weights & Biases](https://www.wandb.com/).
    """

    def __init__(self):
        has_wandb = is_wandb_available()
        if not has_wandb:
            raise RuntimeError("WandbCallback requires wandb to be installed. Run `pip install wandb`.")
        if has_wandb:
            import wandb

            self._wandb = wandb
        self._initialized = False
        # log outputs
        self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})

    def setup(self, args, state, model, **kwargs):
        """
        Set up the optional Weights & Biases (*wandb*) integration.

        One can subclass and override this method to customize the setup if needed. Find more information
        [here](https://docs.wandb.ai/integrations/huggingface). You can also override the following environment
        variables:

        Environment:
            WANDB_LOG_MODEL (`bool`, *optional*, defaults to `False`):
                Whether or not to log model as artifact at the end of training. Use along with
                *TrainingArguments.load_best_model_at_end* to upload best model.
            WANDB_WATCH (`str`, *optional*, defaults to `"gradients"`):
                Can be `"gradients"`, `"all"` or `"false"`. Set to `"false"` to disable gradient logging or `"all"` to
                log gradients and parameters.
            WANDB_PROJECT (`str`, *optional*, defaults to `"huggingface"`):
                Set this to a custom string to store results in a different project.
            WANDB_DISABLED (`bool`, *optional*, defaults to `False`):
                Whether or not to disable wandb entirely. Set *WANDB_DISABLED=true* to disable.
        """
        if self._wandb is None:
            return
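
        # Configuration sketch (illustrative values): the integration is driven by the
        # environment variables documented above, e.g.
        #     WANDB_PROJECT=my-project WANDB_LOG_MODEL=true WANDB_WATCH=all python train.py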
        self._initialized = True
        if state.is_world_process_zero:
            logger.info(
                'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
            )
            combined_dict = {**args.to_sanitized_dict()}

            if hasattr(model, "config") and model.config is not None:
                model_config = model.config.to_dict()
                combined_dict = {**model_config, **combined_dict}
            trial_name = state.trial_name
            init_args = {}
            if trial_name is not None:
                run_name = trial_name
                init_args["group"] = args.run_name
            else:
                run_name = args.run_name

            if self._wandb.run is None:
                self._wandb.init(
                    project=os.getenv("WANDB_PROJECT", "huggingface"),
                    name=run_name,
                    **init_args,
                )
            # add config parameters (run may have been created manually)
            self._wandb.config.update(combined_dict, allow_val_change=True)

            # define default x-axis (for latest wandb versions)
            if getattr(self._wandb, "define_metric", None):
                self._wandb.define_metric("train/global_step")
                self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True)

            # keep track of model topology and gradients, unsupported on TPU
            if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
                self._wandb.watch(
                    model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps)
                )

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if self._wandb is None:
            return
        hp_search = state.is_hyper_param_search
        if hp_search:
            self._wandb.finish()
            self._initialized = False
            args.run_name = None
        if not self._initialized:
            self.setup(args, state, model, **kwargs)

    def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
        if self._wandb is None:
            return
        if self._log_model and self._initialized and state.is_world_process_zero:
            from .trainer import Trainer

            fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
            with tempfile.TemporaryDirectory() as temp_dir:
                fake_trainer.save_model(temp_dir)
                metadata = (
                    {
                        k: v
                        for k, v in dict(self._wandb.summary).items()
                        if isinstance(v, numbers.Number) and not k.startswith("_")
                    }
                    if not args.load_best_model_at_end
                    else {
                        f"eval/{args.metric_for_best_model}": state.best_metric,
                        "train/total_floss": state.total_flos,
                    }
                )
                artifact = self._wandb.Artifact(name=f"model-{self._wandb.run.id}", type="model", metadata=metadata)
                for f in Path(temp_dir).glob("*"):
                    if f.is_file():
                        with artifact.new_file(f.name, mode="wb") as fa:
                            fa.write(f.read_bytes())
                self._wandb.run.log_artifact(artifact)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        if self._wandb is None:
            return
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            logs = rewrite_logs(logs)
            self._wandb.log({**logs, "train/global_step": state.global_step})


class CometCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [Comet ML](https://www.comet.ml/site/).
    """

    def __init__(self):
        if not _has_comet:
            raise RuntimeError("CometCallback requires comet-ml to be installed. Run `pip install comet-ml`.")
        self._initialized = False
        self._log_assets = False

    def setup(self, args, state, model):
        """
        Set up the optional Comet.ml integration.

        Environment:
            COMET_MODE (`str`, *optional*):
                Whether to create an online or offline experiment, or disable Comet logging. Can be "OFFLINE",
                "ONLINE", or "DISABLED". Defaults to "ONLINE".
            COMET_PROJECT_NAME (`str`, *optional*):
                Comet project name for experiments
            COMET_OFFLINE_DIRECTORY (`str`, *optional*):
                Folder to use for saving offline experiments when `COMET_MODE` is "OFFLINE"
            COMET_LOG_ASSETS (`str`, *optional*):
                Whether or not to log training assets (tf event logs, checkpoints, etc.) to Comet. Can be "TRUE" or
                "FALSE". Defaults to "FALSE".

        For a number of configurable items in the environment, see
        [here](https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables).
        """
        self._initialized = True
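        # Configuration sketch (illustrative values): for an offline experiment one might run
        #     COMET_MODE=OFFLINE COMET_OFFLINE_DIRECTORY=./comet-logs python train.py
        # and upload the results later with `comet upload`.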
        log_assets = os.getenv("COMET_LOG_ASSETS", "FALSE").upper()
        if log_assets in {"TRUE", "1"}:
            self._log_assets = True
        if state.is_world_process_zero:
            comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
            experiment = None
            experiment_kwargs = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
            if comet_mode == "ONLINE":
                experiment = comet_ml.Experiment(**experiment_kwargs)
                experiment.log_other("Created from", "transformers")
                logger.info("Automatic Comet.ml online logging enabled")
            elif comet_mode == "OFFLINE":
                experiment_kwargs["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
                experiment = comet_ml.OfflineExperiment(**experiment_kwargs)
                experiment.log_other("Created from", "transformers")
                logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
            if experiment is not None:
                experiment._set_model_graph(model, framework="transformers")
                experiment._log_parameters(args, prefix="args/", framework="transformers")
                if hasattr(model, "config"):
                    experiment._log_parameters(model.config, prefix="config/", framework="transformers")

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            experiment = comet_ml.config.get_global_experiment()
            if experiment is not None:
                experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers")

    def on_train_end(self, args, state, control, **kwargs):
        if self._initialized and state.is_world_process_zero:
            experiment = comet_ml.config.get_global_experiment()
            if (experiment is not None) and (self._log_assets is True):
                logger.info("Logging checkpoints. This may take time.")
                experiment.log_asset_folder(
                    args.output_dir, recursive=True, log_file_name=True, step=state.global_step
                )
            if experiment is not None:
                experiment.end()


class AzureMLCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [AzureML](https://pypi.org/project/azureml-sdk/).
    """

    def __init__(self, azureml_run=None):
        if not is_azureml_available():
            raise RuntimeError("AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`.")
        self.azureml_run = azureml_run

    def on_init_end(self, args, state, control, **kwargs):
        from azureml.core.run import Run

        if self.azureml_run is None and state.is_world_process_zero:
            self.azureml_run = Run.get_context()

    def on_log(self, args, state, control, logs=None, **kwargs):
        if self.azureml_run and state.is_world_process_zero:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.azureml_run.log(k, v, description=k)


class MLflowCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [MLflow](https://www.mlflow.org/). Can be disabled by setting
    environment variable `DISABLE_MLFLOW_INTEGRATION = TRUE`.
    """

    def __init__(self):
        if not is_mlflow_available():
            raise RuntimeError("MLflowCallback requires mlflow to be installed. Run `pip install mlflow`.")
        import mlflow

        self._MAX_PARAM_VAL_LENGTH = mlflow.utils.validation.MAX_PARAM_VAL_LENGTH
        self._MAX_PARAMS_TAGS_PER_BATCH = mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH

        self._initialized = False
        self._auto_end_run = False
        self._log_artifacts = False
        self._ml_flow = mlflow

    def setup(self, args, state, model):
        """
        Set up the optional MLflow integration.

        Environment:
            HF_MLFLOW_LOG_ARTIFACTS (`str`, *optional*):
                Whether to use MLflow .log_artifact() facility to log artifacts. This only makes sense if logging to a
                remote server, e.g. s3 or GCS. If set to `True` or *1*, will copy whatever is in
                [`TrainingArguments`]'s `output_dir` to the local or remote artifact storage. Using it without a remote
                storage will just copy the files to your artifact location.
            MLFLOW_EXPERIMENT_NAME (`str`, *optional*):
                The name of the MLflow experiment under which to launch the run. Defaults to `None`, which will
                point to the "Default" experiment in MLflow. Otherwise, it is a case-sensitive name of the experiment
                to be activated. If an experiment with this name does not exist, a new experiment with this name is
                created.
            MLFLOW_TAGS (`str`, *optional*):
                A string dump of a dictionary of key/value pair to be added to the MLflow run as tags. Example:
                os.environ['MLFLOW_TAGS']='{"release.candidate": "RC1", "release.version": "2.2.0"}'
            MLFLOW_NESTED_RUN (`str`, *optional*):
                Whether to use MLflow nested runs. If set to `True` or *1*, will create a nested run inside the current
                run.
            MLFLOW_RUN_ID (`str`, *optional*):
                Allows reattaching to an existing run, which can be useful when resuming training from a checkpoint.
                When MLFLOW_RUN_ID environment variable is set, start_run attempts to resume a run with the specified
                run ID and other parameters are ignored.
            MLFLOW_FLATTEN_PARAMS (`str`, *optional*):
                Whether to flatten the parameters dictionary before logging. Defaults to `False`.
        """
        self._log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
        self._nested_run = os.getenv("MLFLOW_NESTED_RUN", "FALSE").upper() in ENV_VARS_TRUE_VALUES
        self._experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None)
        self._flatten_params = os.getenv("MLFLOW_FLATTEN_PARAMS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
        self._run_id = os.getenv("MLFLOW_RUN_ID", None)
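        # Configuration sketch (illustrative values): the integration is configured entirely
        # through the environment variables documented above, e.g.
        #     HF_MLFLOW_LOG_ARTIFACTS=1 MLFLOW_EXPERIMENT_NAME=my-experiment MLFLOW_FLATTEN_PARAMS=1 python train.py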
        logger.debug(
            f"MLflow experiment_name={self._experiment_name}, run_name={args.run_name}, nested={self._nested_run}, tags={os.getenv('MLFLOW_TAGS', None)}"
        )
        if state.is_world_process_zero:
            if self._ml_flow.active_run() is None or self._nested_run or self._run_id:
                if self._experiment_name:
                    # Using set_experiment() ensures that the experiment is created if it does not exist
                    self._ml_flow.set_experiment(self._experiment_name)
                self._ml_flow.start_run(run_name=args.run_name, nested=self._nested_run)
                logger.debug(f"MLflow run started with run_id={self._ml_flow.active_run().info.run_id}")
                self._auto_end_run = True
            combined_dict = args.to_dict()
            if hasattr(model, "config") and model.config is not None:
                model_config = model.config.to_dict()
                combined_dict = {**model_config, **combined_dict}
            combined_dict = flatten_dict(combined_dict) if self._flatten_params else combined_dict
            # remove params that are too long for MLflow
            for name, value in list(combined_dict.items()):
                # internally, all values are converted to str in MLflow
                if len(str(value)) > self._MAX_PARAM_VAL_LENGTH:
                    logger.warning(
                        f'Trainer is attempting to log a value of "{value}" for key "{name}" as a parameter. '
                        f"MLflow's log_param() only accepts values no longer than 250 characters so we dropped this attribute. "
                        f"You can use `MLFLOW_FLATTEN_PARAMS` environment variable to flatten the parameters and avoid this message."
                    )
                    del combined_dict[name]
            # MLflow cannot log more than 100 values in one go, so we have to split it
            combined_dict_items = list(combined_dict.items())
            for i in range(0, len(combined_dict_items), self._MAX_PARAMS_TAGS_PER_BATCH):
                self._ml_flow.log_params(dict(combined_dict_items[i : i + self._MAX_PARAMS_TAGS_PER_BATCH]))
            mlflow_tags = os.getenv("MLFLOW_TAGS", None)
            if mlflow_tags:
                mlflow_tags = json.loads(mlflow_tags)
                self._ml_flow.set_tags(mlflow_tags)
        self._initialized = True

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, logs, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            metrics = {}
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    metrics[k] = v
                else:
                    logger.warning(
                        f'Trainer is attempting to log a value of "{v}" of type {type(v)} for key "{k}" as a metric. '
                        f"MLflow's log_metric() only accepts float and int types so we dropped this attribute."
                    )
            self._ml_flow.log_metrics(metrics=metrics, step=state.global_step)

    def on_train_end(self, args, state, control, **kwargs):
        if self._initialized and state.is_world_process_zero:
            if self._log_artifacts:
                logger.info("Logging artifacts. This may take time.")
                self._ml_flow.log_artifacts(args.output_dir)
            if self._auto_end_run and self._ml_flow.active_run():
                self._ml_flow.end_run()

    def __del__(self):
        # if the previous run is not terminated correctly, the fluent API will
        # not let you start a new run before the previous one is killed
        if self._auto_end_run and self._ml_flow and self._ml_flow.active_run() is not None:
            self._ml_flow.end_run()


class NeptuneCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [Neptune](https://neptune.ai).
    """

    def __init__(self):
        if not is_neptune_available():
            raise ValueError(
                "NeptuneCallback requires neptune-client to be installed. Run `pip install neptune-client`."
            )
        import neptune.new as neptune

        self._neptune = neptune
        self._initialized = False
        self._log_artifacts = False

    def setup(self, args, state, model):
        """
        Set up the Neptune integration.

        Environment:
            NEPTUNE_PROJECT (`str`, *required*):
                The project ID for your neptune.ai account. Should be in the format *workspace_name/project_name*
            NEPTUNE_API_TOKEN (`str`, *required*):
                API token for your neptune.ai account
            NEPTUNE_CONNECTION_MODE (`str`, *optional*):
                Neptune connection mode. *async* by default
            NEPTUNE_RUN_NAME (`str`, *optional*):
                The name of the run displayed on the Neptune dashboard
        """
        if state.is_world_process_zero:
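            # Configuration sketch (illustrative values): for example
            #     NEPTUNE_PROJECT=my-workspace/my-project NEPTUNE_API_TOKEN=<token> python train.py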
            self._neptune_run = self._neptune.init(
                project=os.getenv("NEPTUNE_PROJECT"),
                api_token=os.getenv("NEPTUNE_API_TOKEN"),
                mode=os.getenv("NEPTUNE_CONNECTION_MODE", "async"),
                name=os.getenv("NEPTUNE_RUN_NAME", None),
                run=os.getenv("NEPTUNE_RUN_ID", None),
            )
            combined_dict = args.to_dict()
            if hasattr(model, "config") and model.config is not None:
                model_config = model.config.to_dict()
                combined_dict = {**model_config, **combined_dict}
            self._neptune_run["parameters"] = combined_dict
        self._initialized = True

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, logs, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            for k, v in logs.items():
                self._neptune_run[k].log(v, step=state.global_step)

    def __del__(self):
        """
        Environment:
            NEPTUNE_STOP_TIMEOUT (`int`, *optional*):
                Number of seconds to wait for all Neptune.ai tracking calls to finish before stopping the tracked
                run. If not set, it will wait for all tracking calls to finish.
        """
        try:
            stop_timeout = os.getenv("NEPTUNE_STOP_TIMEOUT")
            stop_timeout = int(stop_timeout) if stop_timeout else None
            self._neptune_run.stop(seconds=stop_timeout)
        except AttributeError:
            pass


class CodeCarbonCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that tracks the CO2 emission of training.
    """

    def __init__(self):
        if not is_codecarbon_available():
            raise RuntimeError(
                "CodeCarbonCallback requires `codecarbon` to be installed. Run `pip install codecarbon`."
            )
        import codecarbon

        self._codecarbon = codecarbon
        self.tracker = None

    def on_init_end(self, args, state, control, **kwargs):
        if self.tracker is None and state.is_local_process_zero:
            # CodeCarbon will automatically handle environment variables for configuration
            self.tracker = self._codecarbon.EmissionsTracker(output_dir=args.output_dir)

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if self.tracker and state.is_local_process_zero:
            self.tracker.start()

    def on_train_end(self, args, state, control, **kwargs):
        if self.tracker and state.is_local_process_zero:
            self.tracker.stop()


INTEGRATION_TO_CALLBACK = {
    "azure_ml": AzureMLCallback,
    "comet_ml": CometCallback,
    "mlflow": MLflowCallback,
    "neptune": NeptuneCallback,
    "tensorboard": TensorBoardCallback,
    "wandb": WandbCallback,
    "codecarbon": CodeCarbonCallback,
}


def get_reporting_integration_callbacks(report_to):
    for integration in report_to:
        if integration not in INTEGRATION_TO_CALLBACK:
            raise ValueError(
                f"{integration} is not supported, only {', '.join(INTEGRATION_TO_CALLBACK.keys())} are supported."
            )
    return [INTEGRATION_TO_CALLBACK[integration] for integration in report_to]
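
# For example (illustrative), `TrainingArguments(..., report_to=["tensorboard", "wandb"])` leads the
# Trainer to call `get_reporting_integration_callbacks(["tensorboard", "wandb"])`, which returns
# [TensorBoardCallback, WandbCallback].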