# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Integrations with other Python libraries.
"""
import functools
import importlib.util
import numbers
import os
import sys
import tempfile
from pathlib import Path

from .file_utils import is_datasets_available
from .utils import logging


logger = logging.get_logger(__name__)


# comet_ml requires to be imported before any ML frameworks
_has_comet = importlib.util.find_spec("comet_ml") is not None and os.getenv("COMET_MODE", "").upper() != "DISABLED"
if _has_comet:
    try:
        import comet_ml  # noqa: F401

        if hasattr(comet_ml, "config") and comet_ml.config.get_config("comet.api_key"):
            _has_comet = True
        else:
            if os.getenv("COMET_MODE", "").upper() != "DISABLED":
                logger.warning("comet_ml is installed but `COMET_API_KEY` is not set.")
            _has_comet = False
    except (ImportError, ValueError):
        _has_comet = False

from .file_utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available  # noqa: E402
from .trainer_callback import ProgressCallback, TrainerCallback  # noqa: E402
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy  # noqa: E402


# Integration functions:
def is_wandb_available():
    # any value of WANDB_DISABLED disables wandb
    if os.getenv("WANDB_DISABLED", "").upper() in ENV_VARS_TRUE_VALUES:
        logger.warning(
            "Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the "
            "--report_to flag to control the integrations used for logging results (for instance --report_to none)."
        )
        return False
    return importlib.util.find_spec("wandb") is not None


def is_comet_available():
    return _has_comet


def is_tensorboard_available():
    return importlib.util.find_spec("tensorboard") is not None or importlib.util.find_spec("tensorboardX") is not None


def is_optuna_available():
    return importlib.util.find_spec("optuna") is not None


def is_ray_available():
    return importlib.util.find_spec("ray") is not None


def is_ray_tune_available():
    if not is_ray_available():
        return False
    return importlib.util.find_spec("ray.tune") is not None


def is_sigopt_available():
    return importlib.util.find_spec("sigopt") is not None


def is_azureml_available():
    if importlib.util.find_spec("azureml") is None:
        return False
    if importlib.util.find_spec("azureml.core") is None:
        return False
    return importlib.util.find_spec("azureml.core.run") is not None


def is_mlflow_available():
    if os.getenv("DISABLE_MLFLOW_INTEGRATION", "FALSE").upper() == "TRUE":
        return False
    return importlib.util.find_spec("mlflow") is not None


def is_fairscale_available():
    return importlib.util.find_spec("fairscale") is not None


def is_neptune_available():
    return importlib.util.find_spec("neptune") is not None


def is_codecarbon_available():
    return importlib.util.find_spec("codecarbon") is not None


def hp_params(trial):
    if is_optuna_available():
        import optuna

        if isinstance(trial, optuna.Trial):
            return trial.params
    if is_ray_tune_available():
        if isinstance(trial, dict):
            return trial

    if is_sigopt_available():
        if isinstance(trial, dict):
            return trial

    if is_wandb_available():
        if isinstance(trial, dict):
            return trial

    raise RuntimeError(f"Unknown type for trial {trial.__class__}")


def default_hp_search_backend():
    if is_optuna_available():
        return "optuna"
    elif is_ray_tune_available():
        return "ray"
    elif is_sigopt_available():
        return "sigopt"
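
# Illustrative sketch (hypothetical, simplified; not part of the original module): these helpers back
# `Trainer.hyperparameter_search`, which roughly performs the following dispatch, so each
# `run_hp_search_*` function below returns a `BestRun(run_id, objective, hyperparameters)`:
#
#   backend = backend or default_hp_search_backend()   # "optuna", "ray" or "sigopt"
#   run_hp_search = {"optuna": run_hp_search_optuna, "ray": run_hp_search_ray, "sigopt": run_hp_search_sigopt}[backend]
#   best_run = run_hp_search(trainer, n_trials, direction, **kwargs)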


def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    import optuna

    def _objective(trial, checkpoint_dir=None):
        checkpoint = None
        if checkpoint_dir:
            for subdir in os.listdir(checkpoint_dir):
                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
                    checkpoint = os.path.join(checkpoint_dir, subdir)
        trainer.objective = None
        trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
        # If there hasn't been any evaluation during the training loop.
        if getattr(trainer, "objective", None) is None:
            metrics = trainer.evaluate()
            trainer.objective = trainer.compute_objective(metrics)
        return trainer.objective

    timeout = kwargs.pop("timeout", None)
    n_jobs = kwargs.pop("n_jobs", 1)
    study = optuna.create_study(direction=direction, **kwargs)
    study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs)
    best_trial = study.best_trial
    return BestRun(str(best_trial.number), best_trial.value, best_trial.params)


def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    import ray

    def _objective(trial, local_trainer, checkpoint_dir=None):
        try:
            from transformers.utils.notebook import NotebookProgressCallback

            if local_trainer.pop_callback(NotebookProgressCallback):
                local_trainer.add_callback(ProgressCallback)
        except ModuleNotFoundError:
            pass

        checkpoint = None
        if checkpoint_dir:
            for subdir in os.listdir(checkpoint_dir):
                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
                    checkpoint = os.path.join(checkpoint_dir, subdir)
        local_trainer.objective = None
        local_trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
        # If there hasn't been any evaluation during the training loop.
        if getattr(local_trainer, "objective", None) is None:
            metrics = local_trainer.evaluate()
            local_trainer.objective = local_trainer.compute_objective(metrics)
            local_trainer._tune_save_checkpoint()
            ray.tune.report(objective=local_trainer.objective, **metrics, done=True)

    if not trainer._memory_tracker.skip_memory_metrics:
        from .trainer_utils import TrainerMemoryTracker

        logger.warning(
            "Memory tracking for your Trainer is currently "
            "enabled. Automatically disabling the memory tracker "
            "since the memory tracker is not serializable."
        )
        trainer._memory_tracker = TrainerMemoryTracker(skip_memory_metrics=True)

    # The model and TensorBoard writer do not pickle so we have to remove them (if they exist)
    # while doing the ray hp search.
    _tb_writer = trainer.pop_callback(TensorBoardCallback)
    trainer.model = None

    # Set up default `resources_per_trial`.
    if "resources_per_trial" not in kwargs:
        # Default to 1 CPU and 1 GPU (if applicable) per trial.
        kwargs["resources_per_trial"] = {"cpu": 1}
        if trainer.args.n_gpu > 0:
            kwargs["resources_per_trial"]["gpu"] = 1
        resource_msg = "1 CPU" + (" and 1 GPU" if trainer.args.n_gpu > 0 else "")
        logger.info(
            "No `resources_per_trial` arg was passed into "
            "`hyperparameter_search`. Setting it to a default value "
            f"of {resource_msg} for each trial."
        )
    # Make sure each trainer only uses GPUs that were allocated per trial.
    gpus_per_trial = kwargs["resources_per_trial"].get("gpu", 0)
    trainer.args._n_gpu = gpus_per_trial

    # Set up default `progress_reporter`.
    if "progress_reporter" not in kwargs:
        from ray.tune import CLIReporter

        kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
    if "keep_checkpoints_num" in kwargs and kwargs["keep_checkpoints_num"] > 0:
        # `keep_checkpoints_num=0` would disable checkpointing
        trainer.use_tune_checkpoints = True
        if kwargs["keep_checkpoints_num"] > 1:
            logger.warning(
                f"Currently keeping {kwargs['keep_checkpoints_num']} checkpoints for each trial. "
                "Checkpoints are usually huge, so consider setting `keep_checkpoints_num=1`."
            )
    if "scheduler" in kwargs:
        from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB, MedianStoppingRule, PopulationBasedTraining

        # Check if checkpointing is enabled for PopulationBasedTraining
        if isinstance(kwargs["scheduler"], PopulationBasedTraining):
            if not trainer.use_tune_checkpoints:
                logger.warning(
                    "You are using PopulationBasedTraining but you haven't enabled checkpointing. "
                    "This means your trials will train from scratch every time they are exploiting "
                    "new configurations. Consider enabling checkpointing by passing "
                    "`keep_checkpoints_num=1` as an additional argument to `Trainer.hyperparameter_search`."
                )

        # Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
        if isinstance(
            kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
        ) and (not trainer.args.do_eval or trainer.args.evaluation_strategy == IntervalStrategy.NO):
            raise RuntimeError(
                "You are using {cls} as a scheduler but you haven't enabled evaluation during training. "
                "This means your trials will not report intermediate results to Ray Tune, and "
                "can thus not be stopped early or used to exploit other trials' parameters. "
                "If this is what you want, do not use {cls}. If you would like to use {cls}, "
                "make sure you pass `do_eval=True` and `evaluation_strategy='steps'` in the "
                "Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__)
            )

    trainable = ray.tune.with_parameters(_objective, local_trainer=trainer)

    @functools.wraps(trainable)
    def dynamic_modules_import_trainable(*args, **kwargs):
        """
        Wrapper around `tune.with_parameters` to ensure datasets_modules are loaded on each Actor.

        Without this, an ImportError will be thrown. See https://github.com/huggingface/transformers/issues/11565.

        Assumes that `_objective`, defined above, is a function.
        """
        if is_datasets_available():
            import datasets.load

            dynamic_modules_path = os.path.join(datasets.load.init_dynamic_modules(), "__init__.py")
            # load dynamic_modules from path
            spec = importlib.util.spec_from_file_location("datasets_modules", dynamic_modules_path)
            datasets_modules = importlib.util.module_from_spec(spec)
            sys.modules[spec.name] = datasets_modules
            spec.loader.exec_module(datasets_modules)
        return trainable(*args, **kwargs)

    # special attr set by tune.with_parameters
    if hasattr(trainable, "__mixins__"):
        dynamic_modules_import_trainable.__mixins__ = trainable.__mixins__

    analysis = ray.tune.run(
        dynamic_modules_import_trainable,
        config=trainer.hp_space(None),
        num_samples=n_trials,
        **kwargs,
    )
    best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
    best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
    if _tb_writer is not None:
        trainer.add_callback(_tb_writer)
    return best_run


def run_hp_search_sigopt(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:

    from sigopt import Connection

    conn = Connection()
    proxies = kwargs.pop("proxies", None)
    if proxies is not None:
        conn.set_proxies(proxies)

    experiment = conn.experiments().create(
        name="huggingface-tune",
        parameters=trainer.hp_space(None),
        metrics=[dict(name="objective", objective=direction, strategy="optimize")],
        parallel_bandwidth=1,
        observation_budget=n_trials,
        project="huggingface",
    )
    logger.info(f"created experiment: https://app.sigopt.com/experiment/{experiment.id}")

    while experiment.progress.observation_count < experiment.observation_budget:
        suggestion = conn.experiments(experiment.id).suggestions().create()
        trainer.objective = None
        trainer.train(resume_from_checkpoint=None, trial=suggestion)
        # If there hasn't been any evaluation during the training loop.
        if getattr(trainer, "objective", None) is None:
            metrics = trainer.evaluate()
            trainer.objective = trainer.compute_objective(metrics)

        values = [dict(name="objective", value=trainer.objective)]
        obs = conn.experiments(experiment.id).observations().create(suggestion=suggestion.id, values=values)
        logger.info(f"[suggestion_id, observation_id]: [{suggestion.id}, {obs.id}]")
        experiment = conn.experiments(experiment.id).fetch()

    best = list(conn.experiments(experiment.id).best_assignments().fetch().iterate_pages())[0]
    best_run = BestRun(best.id, best.value, best.assignments)

    return best_run


def run_hp_search_wandb(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    from .integrations import is_wandb_available

    if not is_wandb_available():
        raise ImportError("This function needs wandb installed: `pip install wandb`")
    import wandb

    # add WandbCallback if not already added in trainer callbacks
    reporting_to_wandb = False
    for callback in trainer.callback_handler.callbacks:
        if isinstance(callback, WandbCallback):
            reporting_to_wandb = True
            break
    if not reporting_to_wandb:
        trainer.add_callback(WandbCallback())
    trainer.args.report_to = "wandb"
    best_trial = {"run_id": None, "objective": None, "hyperparameters": None}
    sweep_id = kwargs.pop("sweep_id", None)
    project = kwargs.pop("project", None)
    name = kwargs.pop("name", None)
    entity = kwargs.pop("entity", None)
    metric = kwargs.pop("metric", "eval/loss")

    sweep_config = trainer.hp_space(None)
    sweep_config["metric"]["goal"] = direction
    sweep_config["metric"]["name"] = metric
    if name:
        sweep_config["name"] = name

    def _objective():
        run = wandb.run if wandb.run else wandb.init()
        trainer.state.trial_name = run.name
        run.config.update({"assignments": {}, "metric": metric})
        config = wandb.config

        trainer.objective = None

        trainer.train(resume_from_checkpoint=None, trial=vars(config)["_items"])
        # If there hasn't been any evaluation during the training loop.
        if getattr(trainer, "objective", None) is None:
            metrics = trainer.evaluate()
            trainer.objective = trainer.compute_objective(metrics)
            format_metrics = rewrite_logs(metrics)
            if metric not in format_metrics:
                logger.warning(
                    f"Provided metric {metric} not found. This might result in unexpected sweeps charts. The available metrics are {format_metrics.keys()}"
                )
        best_score = False
        if best_trial["run_id"] is not None:
            if direction == "minimize":
                best_score = trainer.objective < best_trial["objective"]
            elif direction == "maximize":
                best_score = trainer.objective > best_trial["objective"]

        if best_score or best_trial["run_id"] is None:
            best_trial["run_id"] = run.id
            best_trial["objective"] = trainer.objective
            best_trial["hyperparameters"] = dict(config)

        return trainer.objective

    sweep_id = wandb.sweep(sweep_config, project=project, entity=entity) if not sweep_id else sweep_id
    logger.info(f"wandb sweep id - {sweep_id}")
    wandb.agent(sweep_id, function=_objective, count=n_trials)

    return BestRun(best_trial["run_id"], best_trial["objective"], best_trial["hyperparameters"])


def get_available_reporting_integrations():
    integrations = []
    if is_azureml_available():
        integrations.append("azure_ml")
    if is_comet_available():
        integrations.append("comet_ml")
    if is_mlflow_available():
        integrations.append("mlflow")
    if is_tensorboard_available():
        integrations.append("tensorboard")
    if is_wandb_available():
        integrations.append("wandb")
    if is_codecarbon_available():
        integrations.append("codecarbon")
    return integrations


def rewrite_logs(d):
    new_d = {}
    eval_prefix = "eval_"
    eval_prefix_len = len(eval_prefix)
    test_prefix = "test_"
    test_prefix_len = len(test_prefix)
    for k, v in d.items():
        if k.startswith(eval_prefix):
            new_d["eval/" + k[eval_prefix_len:]] = v
        elif k.startswith(test_prefix):
            new_d["test/" + k[test_prefix_len:]] = v
        else:
            new_d["train/" + k] = v
    return new_d
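
# Illustrative example (hypothetical values, not part of the original module): `rewrite_logs` maps Trainer
# metric names onto the "train/", "eval/" and "test/" namespaces used by the logging callbacks below, e.g.
#
#   rewrite_logs({"eval_loss": 0.21, "test_accuracy": 0.9, "loss": 1.3})
#   # -> {"eval/loss": 0.21, "test/accuracy": 0.9, "train/loss": 1.3}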


class TensorBoardCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [TensorBoard](https://www.tensorflow.org/tensorboard).

    Args:
        tb_writer (`SummaryWriter`, *optional*):
            The writer to use. Will instantiate one if not set.
    """

    def __init__(self, tb_writer=None):
        has_tensorboard = is_tensorboard_available()
        if not has_tensorboard:
            raise RuntimeError(
                "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
            )
        if has_tensorboard:
            try:
                from torch.utils.tensorboard import SummaryWriter  # noqa: F401

                self._SummaryWriter = SummaryWriter
            except ImportError:
                try:
                    from tensorboardX import SummaryWriter

                    self._SummaryWriter = SummaryWriter
                except ImportError:
                    self._SummaryWriter = None
        else:
            self._SummaryWriter = None
        self.tb_writer = tb_writer

    def _init_summary_writer(self, args, log_dir=None):
        log_dir = log_dir or args.logging_dir
        if self._SummaryWriter is not None:
            self.tb_writer = self._SummaryWriter(log_dir=log_dir)

    def on_train_begin(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return

        log_dir = None

        if state.is_hyper_param_search:
            trial_name = state.trial_name
            if trial_name is not None:
                log_dir = os.path.join(args.logging_dir, trial_name)

        if self.tb_writer is None:
            self._init_summary_writer(args, log_dir)

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", args.to_json_string())
            if "model" in kwargs:
                model = kwargs["model"]
                if hasattr(model, "config") and model.config is not None:
                    model_config_json = model.config.to_json_string()
                    self.tb_writer.add_text("model_config", model_config_json)
            # Version of TensorBoard coming from tensorboardX does not have this method.
            if hasattr(self.tb_writer, "add_hparams"):
                self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not state.is_world_process_zero:
            return

        if self.tb_writer is None:
            self._init_summary_writer(args)

        if self.tb_writer is not None:
            logs = rewrite_logs(logs)
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, state.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute."
                    )
            self.tb_writer.flush()

    def on_train_end(self, args, state, control, **kwargs):
        if self.tb_writer:
            self.tb_writer.close()
            self.tb_writer = None
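
# Illustrative usage sketch (hypothetical, simplified; not part of the original module): report-to callbacks
# such as TensorBoardCallback are normally added automatically from `TrainingArguments.report_to`, but they
# can also be attached explicitly:
#
#   trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
#   trainer.add_callback(TensorBoardCallback())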


class WandbCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [Weights & Biases](https://www.wandb.com/).
    """

    def __init__(self):
        has_wandb = is_wandb_available()
        if not has_wandb:
            raise RuntimeError("WandbCallback requires wandb to be installed. Run `pip install wandb`.")
        if has_wandb:
            import wandb

            self._wandb = wandb
        self._initialized = False
        # log outputs
        self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})

    def setup(self, args, state, model, **kwargs):
        """
        Set up the optional Weights & Biases (*wandb*) integration.

        One can subclass and override this method to customize the setup if needed. Find more information
        [here](https://docs.wandb.ai/integrations/huggingface). You can also override the following environment
        variables:

        Environment:
            WANDB_LOG_MODEL (`bool`, *optional*, defaults to `False`):
                Whether or not to log the model as an artifact at the end of training. Use along with
                *TrainingArguments.load_best_model_at_end* to upload the best model.
            WANDB_WATCH (`str`, *optional*, defaults to `"gradients"`):
                Can be `"gradients"`, `"all"` or `"false"`. Set to `"false"` to disable gradient logging or `"all"` to
                log gradients and parameters.
            WANDB_PROJECT (`str`, *optional*, defaults to `"huggingface"`):
                Set this to a custom string to store results in a different project.
            WANDB_DISABLED (`bool`, *optional*, defaults to `False`):
                Whether or not to disable wandb entirely. Set *WANDB_DISABLED=true* to disable.
        """
        if self._wandb is None:
            return
        self._initialized = True
        if state.is_world_process_zero:
            logger.info(
                'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
            )
            combined_dict = {**args.to_sanitized_dict()}

            if hasattr(model, "config") and model.config is not None:
                model_config = model.config.to_dict()
                combined_dict = {**model_config, **combined_dict}
            trial_name = state.trial_name
            init_args = {}
            if trial_name is not None:
                run_name = trial_name
                init_args["group"] = args.run_name
            else:
                run_name = args.run_name

            if self._wandb.run is None:
                self._wandb.init(
                    project=os.getenv("WANDB_PROJECT", "huggingface"),
                    name=run_name,
                    **init_args,
                )
            # add config parameters (run may have been created manually)
            self._wandb.config.update(combined_dict, allow_val_change=True)

            # define default x-axis (for latest wandb versions)
            if getattr(self._wandb, "define_metric", None):
                self._wandb.define_metric("train/global_step")
                self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True)

            # keep track of model topology and gradients, unsupported on TPU
            if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
                self._wandb.watch(
                    model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps)
                )

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if self._wandb is None:
            return
        hp_search = state.is_hyper_param_search
        if hp_search:
            self._wandb.finish()
            self._initialized = False
            args.run_name = None
        if not self._initialized:
            self.setup(args, state, model, **kwargs)

    def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
        if self._wandb is None:
            return
        if self._log_model and self._initialized and state.is_world_process_zero:
            from .trainer import Trainer

            fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
            with tempfile.TemporaryDirectory() as temp_dir:
                fake_trainer.save_model(temp_dir)
                metadata = (
                    {
                        k: v
                        for k, v in dict(self._wandb.summary).items()
                        if isinstance(v, numbers.Number) and not k.startswith("_")
                    }
                    if not args.load_best_model_at_end
                    else {
                        f"eval/{args.metric_for_best_model}": state.best_metric,
                        "train/total_floss": state.total_flos,
                    }
                )
                artifact = self._wandb.Artifact(name=f"model-{self._wandb.run.id}", type="model", metadata=metadata)
                for f in Path(temp_dir).glob("*"):
                    if f.is_file():
                        with artifact.new_file(f.name, mode="wb") as fa:
                            fa.write(f.read_bytes())
                self._wandb.run.log_artifact(artifact)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        if self._wandb is None:
            return
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            logs = rewrite_logs(logs)
            self._wandb.log({**logs, "train/global_step": state.global_step})
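
# Illustrative sketch (hypothetical values, not part of the original module): the wandb integration is
# configured through the environment variables documented in `WandbCallback.setup` above, e.g.
#
#   os.environ["WANDB_PROJECT"] = "my-project"   # wandb project to log to
#   os.environ["WANDB_LOG_MODEL"] = "true"       # upload the (best) model as an artifact at the end of training
#   os.environ["WANDB_WATCH"] = "false"          # skip gradient/parameter logging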


class CometCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [Comet ML](https://www.comet.ml/site/).
    """

    def __init__(self):
        if not _has_comet:
            raise RuntimeError("CometCallback requires comet-ml to be installed. Run `pip install comet-ml`.")
        self._initialized = False
        self._log_assets = False

    def setup(self, args, state, model):
        """
        Set up the optional Comet.ml integration.

        Environment:
            COMET_MODE (`str`, *optional*):
                Whether to create an online or offline experiment, or disable Comet logging entirely. Can be
                "OFFLINE", "ONLINE", or "DISABLED". Defaults to "ONLINE".
            COMET_PROJECT_NAME (`str`, *optional*):
                Comet project name for experiments.
            COMET_OFFLINE_DIRECTORY (`str`, *optional*):
                Folder to use for saving offline experiments when `COMET_MODE` is "OFFLINE".
            COMET_LOG_ASSETS (`str`, *optional*):
                Whether or not to log training assets (tf event logs, checkpoints, etc.) to Comet. Can be "TRUE" or
                "FALSE". Defaults to "FALSE".

        For a number of configurable items in the environment, see
        [here](https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables).
        """
        self._initialized = True
        log_assets = os.getenv("COMET_LOG_ASSETS", "FALSE").upper()
        if log_assets in {"TRUE", "1"}:
            self._log_assets = True
        if state.is_world_process_zero:
            comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
            experiment = None
            experiment_kwargs = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
            if comet_mode == "ONLINE":
                experiment = comet_ml.Experiment(**experiment_kwargs)
                experiment.log_other("Created from", "transformers")
                logger.info("Automatic Comet.ml online logging enabled")
            elif comet_mode == "OFFLINE":
                experiment_kwargs["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
                experiment = comet_ml.OfflineExperiment(**experiment_kwargs)
                experiment.log_other("Created from", "transformers")
                logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
            if experiment is not None:
                experiment._set_model_graph(model, framework="transformers")
                experiment._log_parameters(args, prefix="args/", framework="transformers")
                if hasattr(model, "config"):
                    experiment._log_parameters(model.config, prefix="config/", framework="transformers")

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            experiment = comet_ml.config.get_global_experiment()
            if experiment is not None:
                experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers")

    def on_train_end(self, args, state, control, **kwargs):
        if self._initialized and state.is_world_process_zero:
            experiment = comet_ml.config.get_global_experiment()
            if experiment is not None:
                if self._log_assets is True:
                    logger.info("Logging checkpoints. This may take time.")
                    experiment.log_asset_folder(
                        args.output_dir, recursive=True, log_file_name=True, step=state.global_step
                    )
                experiment.end()
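
# Illustrative sketch (hypothetical values, not part of the original module): the Comet integration is
# likewise driven by environment variables read in `CometCallback.setup`, e.g.
#
#   os.environ["COMET_MODE"] = "OFFLINE"                     # write an offline experiment instead of logging online
#   os.environ["COMET_OFFLINE_DIRECTORY"] = "./comet-logs"   # where offline experiments are saved
#   os.environ["COMET_PROJECT_NAME"] = "my-project"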

class AzureMLCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [AzureML](https://pypi.org/project/azureml-sdk/).
    """

    def __init__(self, azureml_run=None):
        if not is_azureml_available():
            raise RuntimeError("AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`.")
        self.azureml_run = azureml_run

    def on_init_end(self, args, state, control, **kwargs):
        from azureml.core.run import Run

        if self.azureml_run is None and state.is_world_process_zero:
            self.azureml_run = Run.get_context()

    def on_log(self, args, state, control, logs=None, **kwargs):
        if self.azureml_run and state.is_world_process_zero:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.azureml_run.log(k, v, description=k)


class MLflowCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [MLflow](https://www.mlflow.org/). Can be disabled by setting
    environment variable `DISABLE_MLFLOW_INTEGRATION = TRUE`.
    """

    def __init__(self):
        if not is_mlflow_available():
            raise RuntimeError("MLflowCallback requires mlflow to be installed. Run `pip install mlflow`.")
        import mlflow

        self._MAX_PARAM_VAL_LENGTH = mlflow.utils.validation.MAX_PARAM_VAL_LENGTH
        self._MAX_PARAMS_TAGS_PER_BATCH = mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH

        self._initialized = False
        self._log_artifacts = False
        self._ml_flow = mlflow

    def setup(self, args, state, model):
        """
        Set up the optional MLflow integration.

        Environment:
            HF_MLFLOW_LOG_ARTIFACTS (`str`, *optional*):
                Whether to use MLflow's `.log_artifact()` facility to log artifacts.

                This only makes sense if logging to a remote server, e.g. s3 or GCS. If set to `True` or *1*, will copy
                whatever is in [`TrainingArguments`]'s `output_dir` to the local or remote artifact storage. Using it
                without a remote storage will just copy the files to your artifact location.
        """
        log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper()
        if log_artifacts in {"TRUE", "1"}:
            self._log_artifacts = True
        if state.is_world_process_zero:
            if self._ml_flow.active_run() is None:
                self._ml_flow.start_run(run_name=args.run_name)
            combined_dict = args.to_dict()
            if hasattr(model, "config") and model.config is not None:
                model_config = model.config.to_dict()
                combined_dict = {**model_config, **combined_dict}
            # remove params that are too long for MLflow
            for name, value in list(combined_dict.items()):
                # internally, all values are converted to str in MLflow
                if len(str(value)) > self._MAX_PARAM_VAL_LENGTH:
                    logger.warning(
                        f"Trainer is attempting to log a value of "
                        f'"{value}" for key "{name}" as a parameter. '
                        f"MLflow's log_param() only accepts values no longer than "
                        f"250 characters so we dropped this attribute."
                    )
                    del combined_dict[name]
            # MLflow cannot log more than 100 values in one go, so we have to split it
            combined_dict_items = list(combined_dict.items())
            for i in range(0, len(combined_dict_items), self._MAX_PARAMS_TAGS_PER_BATCH):
                self._ml_flow.log_params(dict(combined_dict_items[i : i + self._MAX_PARAMS_TAGS_PER_BATCH]))
        self._initialized = True

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, logs, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            metrics = {}
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    metrics[k] = v
                else:
                    logger.warning(
                        f"Trainer is attempting to log a value of "
                        f'"{v}" of type {type(v)} for key "{k}" as a metric. '
                        f"MLflow's log_metric() only accepts float and "
                        f"int types so we dropped this attribute."
                    )
            self._ml_flow.log_metrics(metrics=metrics, step=state.global_step)

    def on_train_end(self, args, state, control, **kwargs):
        if self._initialized and state.is_world_process_zero:
            if self._log_artifacts:
                logger.info("Logging artifacts. This may take time.")
                self._ml_flow.log_artifacts(args.output_dir)

    def __del__(self):
        # if the previous run is not terminated correctly, the fluent API will
        # not let you start a new run before the previous one is killed
        if self._ml_flow.active_run() is not None:
            self._ml_flow.end_run()
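
# Illustrative sketch (hypothetical value, not part of the original module): artifact logging for MLflow is
# opt-in via an environment variable read in `MLflowCallback.setup`, e.g.
#
#   os.environ["HF_MLFLOW_LOG_ARTIFACTS"] = "TRUE"   # copy `output_dir` to the MLflow artifact store at train end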


class NeptuneCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that sends the logs to [Neptune](https://neptune.ai).
    """

    def __init__(self):
        if not is_neptune_available():
            raise ValueError(
                "NeptuneCallback requires neptune-client to be installed. Run `pip install neptune-client`."
            )
        import neptune.new as neptune

        self._neptune = neptune
        self._initialized = False
        self._log_artifacts = False

    def setup(self, args, state, model):
        """
        Set up the Neptune integration.

        Environment:
            NEPTUNE_PROJECT (`str`, *required*):
                The project ID for the neptune.ai account. Should be in the format *workspace_name/project_name*.
            NEPTUNE_API_TOKEN (`str`, *required*):
                API token for the neptune.ai account.
            NEPTUNE_CONNECTION_MODE (`str`, *optional*):
                Neptune connection mode. *async* by default.
            NEPTUNE_RUN_NAME (`str`, *optional*):
                The name of the run process on the Neptune dashboard.
        """
        if state.is_world_process_zero:
            self._neptune_run = self._neptune.init(
                project=os.getenv("NEPTUNE_PROJECT"),
                api_token=os.getenv("NEPTUNE_API_TOKEN"),
                mode=os.getenv("NEPTUNE_CONNECTION_MODE", "async"),
                name=os.getenv("NEPTUNE_RUN_NAME", None),
                run=os.getenv("NEPTUNE_RUN_ID", None),
            )
            combined_dict = args.to_dict()
            if hasattr(model, "config") and model.config is not None:
                model_config = model.config.to_dict()
                combined_dict = {**model_config, **combined_dict}
            self._neptune_run["parameters"] = combined_dict
        self._initialized = True

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, logs, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            for k, v in logs.items():
                self._neptune_run[k].log(v, step=state.global_step)

    def __del__(self):
        """
        Environment:
            NEPTUNE_STOP_TIMEOUT (`int`, *optional*):
                Number of seconds to wait for all Neptune.ai tracking calls to finish before stopping the tracked
                run. If not set, it will wait for all tracking calls to finish.
        """
        try:
            stop_timeout = os.getenv("NEPTUNE_STOP_TIMEOUT")
            stop_timeout = int(stop_timeout) if stop_timeout else None
            self._neptune_run.stop(seconds=stop_timeout)
        except AttributeError:
            pass


class CodeCarbonCallback(TrainerCallback):
    """
    A [`TrainerCallback`] that tracks the CO2 emission of training.
    """

    def __init__(self):
        if not is_codecarbon_available():
            raise RuntimeError(
                "CodeCarbonCallback requires `codecarbon` to be installed. Run `pip install codecarbon`."
            )
        import codecarbon

        self._codecarbon = codecarbon
        self.tracker = None

    def on_init_end(self, args, state, control, **kwargs):
        if self.tracker is None and state.is_local_process_zero:
            # CodeCarbon will automatically handle environment variables for configuration
            self.tracker = self._codecarbon.EmissionsTracker(output_dir=args.output_dir)

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if self.tracker and state.is_local_process_zero:
            self.tracker.start()

    def on_train_end(self, args, state, control, **kwargs):
        if self.tracker and state.is_local_process_zero:
            self.tracker.stop()


INTEGRATION_TO_CALLBACK = {
    "azure_ml": AzureMLCallback,
    "comet_ml": CometCallback,
    "mlflow": MLflowCallback,
    "neptune": NeptuneCallback,
    "tensorboard": TensorBoardCallback,
    "wandb": WandbCallback,
    "codecarbon": CodeCarbonCallback,
}


def get_reporting_integration_callbacks(report_to):
    for integration in report_to:
        if integration not in INTEGRATION_TO_CALLBACK:
            raise ValueError(
                f"{integration} is not supported, only {', '.join(INTEGRATION_TO_CALLBACK.keys())} are supported."
            )
    return [INTEGRATION_TO_CALLBACK[integration] for integration in report_to]
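

# Illustrative usage sketch (hypothetical, simplified; not part of the original module):
# `get_reporting_integration_callbacks` is what the Trainer uses to turn `TrainingArguments.report_to`
# into callback classes, e.g.
#
#   callbacks = get_reporting_integration_callbacks(["tensorboard", "wandb"])
#   # -> [TensorBoardCallback, WandbCallback]; the Trainer then instantiates and registers these.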