# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Integrations with other Python libraries.
"""
import importlib.util
import io
import json
import numbers
import os
import sys
import tempfile
from copy import deepcopy
from pathlib import Path

from .utils import logging
from .utils.versions import require_version


logger = logging.get_logger(__name__)


# comet_ml requires to be imported before any ML frameworks
_has_comet = importlib.util.find_spec("comet_ml") is not None and os.getenv("COMET_MODE", "").upper() != "DISABLED"
if _has_comet:
    try:
        import comet_ml  # noqa: F401

        if hasattr(comet_ml, "config") and comet_ml.config.get_config("comet.api_key"):
            _has_comet = True
        else:
            if os.getenv("COMET_MODE", "").upper() != "DISABLED":
                logger.warning("comet_ml is installed but `COMET_API_KEY` is not set.")
            _has_comet = False
    except (ImportError, ValueError):
        _has_comet = False

from .file_utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available  # noqa: E402
from .trainer_callback import TrainerCallback  # noqa: E402
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, IntervalStrategy  # noqa: E402


# Integration functions:
def is_wandb_available():
    # any value of WANDB_DISABLED disables wandb
    if os.getenv("WANDB_DISABLED", "").upper() in ENV_VARS_TRUE_VALUES:
        logger.warning(
            "Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the "
            "--report_to flag to control the integrations used for logging results (for instance --report_to none)."
        )
        return False
    return importlib.util.find_spec("wandb") is not None


def is_comet_available():
    return _has_comet


def is_tensorboard_available():
    return importlib.util.find_spec("tensorboard") is not None or importlib.util.find_spec("tensorboardX") is not None


def is_optuna_available():
    return importlib.util.find_spec("optuna") is not None


def is_ray_available():
    return importlib.util.find_spec("ray") is not None


def is_ray_tune_available():
    if not is_ray_available():
        return False
    return importlib.util.find_spec("ray.tune") is not None


def is_azureml_available():
    if importlib.util.find_spec("azureml") is None:
        return False
    if importlib.util.find_spec("azureml.core") is None:
        return False
    return importlib.util.find_spec("azureml.core.run") is not None


def is_mlflow_available():
    return importlib.util.find_spec("mlflow") is not None


def is_fairscale_available():
    return importlib.util.find_spec("fairscale") is not None


def is_deepspeed_available():
    return importlib.util.find_spec("deepspeed") is not None


def hp_params(trial):
    if is_optuna_available():
        import optuna

        if isinstance(trial, optuna.Trial):
            return trial.params
    if is_ray_tune_available():
        if isinstance(trial, dict):
            return trial

    raise RuntimeError(f"Unknown type for trial {trial.__class__}")


def default_hp_search_backend():
    if is_optuna_available():
        return "optuna"
    elif is_ray_tune_available():
        return "ray"


def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    import optuna

    def _objective(trial, checkpoint_dir=None):
        checkpoint = None
        if checkpoint_dir:
            for subdir in os.listdir(checkpoint_dir):
                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
                    checkpoint = os.path.join(checkpoint_dir, subdir)
        trainer.objective = None
        trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
        # If there hasn't been any evaluation during the training loop.
        if getattr(trainer, "objective", None) is None:
            metrics = trainer.evaluate()
            trainer.objective = trainer.compute_objective(metrics)
        return trainer.objective

    timeout = kwargs.pop("timeout", None)
    n_jobs = kwargs.pop("n_jobs", 1)
    study = optuna.create_study(direction=direction, **kwargs)
    study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs)
    best_trial = study.best_trial
    return BestRun(str(best_trial.number), best_trial.value, best_trial.params)
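# Illustrative note (not part of the original module): this backend is normally reached through
# ``Trainer.hyperparameter_search``, e.g.
#   best_run = trainer.hyperparameter_search(backend="optuna", direction="minimize", n_trials=10)
# and any extra kwargs (such as ``timeout`` or ``n_jobs``) end up in ``optuna.create_study`` / ``study.optimize`` above.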


def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    import ray

    def _objective(trial, local_trainer, checkpoint_dir=None):
        checkpoint = None
        if checkpoint_dir:
            for subdir in os.listdir(checkpoint_dir):
                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
                    checkpoint = os.path.join(checkpoint_dir, subdir)
        local_trainer.objective = None
        local_trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
        # If there hasn't been any evaluation during the training loop.
        if getattr(local_trainer, "objective", None) is None:
            metrics = local_trainer.evaluate()
            local_trainer.objective = local_trainer.compute_objective(metrics)
            local_trainer._tune_save_checkpoint()
            ray.tune.report(objective=local_trainer.objective, **metrics, done=True)

    # The model and TensorBoard writer do not pickle so we have to remove them (if they exist)
    # while doing the ray hp search.
    _tb_writer = trainer.pop_callback(TensorBoardCallback)
    trainer.model = None
    # Setup default `resources_per_trial`.
    if "resources_per_trial" not in kwargs:
        # Default to 1 CPU and 1 GPU (if applicable) per trial.
        kwargs["resources_per_trial"] = {"cpu": 1}
        if trainer.args.n_gpu > 0:
            kwargs["resources_per_trial"]["gpu"] = 1
        resource_msg = "1 CPU" + (" and 1 GPU" if trainer.args.n_gpu > 0 else "")
        logger.info(
            "No `resources_per_trial` arg was passed into "
            "`hyperparameter_search`. Setting it to a default value "
            f"of {resource_msg} for each trial."
        )
    # Make sure each trainer only uses GPUs that were allocated per trial.
    gpus_per_trial = kwargs["resources_per_trial"].get("gpu", 0)
    trainer.args._n_gpu = gpus_per_trial

    # Setup default `progress_reporter`.
    if "progress_reporter" not in kwargs:
        from ray.tune import CLIReporter

        kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
    if "keep_checkpoints_num" in kwargs and kwargs["keep_checkpoints_num"] > 0:
        # `keep_checkpoints_num=0` would disable checkpointing
        trainer.use_tune_checkpoints = True
        if kwargs["keep_checkpoints_num"] > 1:
            logger.warning(
                f"Currently keeping {kwargs['keep_checkpoints_num']} checkpoints for each trial. "
                "Checkpoints are usually huge, "
                "consider setting `keep_checkpoints_num=1`."
            )
    if "scheduler" in kwargs:
        from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB, MedianStoppingRule, PopulationBasedTraining

        # Check if checkpointing is enabled for PopulationBasedTraining
        if isinstance(kwargs["scheduler"], PopulationBasedTraining):
            if not trainer.use_tune_checkpoints:
                logger.warning(
                    "You are using PopulationBasedTraining but you haven't enabled checkpointing. "
                    "This means your trials will train from scratch every time they are exploiting "
                    "new configurations. Consider enabling checkpointing by passing "
                    "`keep_checkpoints_num=1` as an additional argument to `Trainer.hyperparameter_search`."
                )

        # Check for `do_eval` and `eval_during_training` for schedulers that require intermediate reporting.
        if isinstance(
            kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
        ) and (not trainer.args.do_eval or trainer.args.evaluation_strategy == IntervalStrategy.NO):
            raise RuntimeError(
                "You are using {cls} as a scheduler but you haven't enabled evaluation during training. "
                "This means your trials will not report intermediate results to Ray Tune, and "
                "can thus not be stopped early or used to exploit other trials parameters. "
                "If this is what you want, do not use {cls}. If you would like to use {cls}, "
                "make sure you pass `do_eval=True` and `evaluation_strategy='steps'` in the "
                "Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__)
            )

    analysis = ray.tune.run(
        ray.tune.with_parameters(_objective, local_trainer=trainer),
        config=trainer.hp_space(None),
        num_samples=n_trials,
        **kwargs,
    )
    best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
    best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
    if _tb_writer is not None:
        trainer.add_callback(_tb_writer)
    return best_run
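# Illustrative note (not part of the original module): with the Ray backend, the extra kwargs are
# forwarded to ``ray.tune.run``, e.g.
#   trainer.hyperparameter_search(
#       backend="ray", n_trials=8, resources_per_trial={"cpu": 2, "gpu": 1}, keep_checkpoints_num=1
#   )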


def get_available_reporting_integrations():
    integrations = []
    if is_azureml_available():
        integrations.append("azure_ml")
    if is_comet_available():
        integrations.append("comet_ml")
    if is_mlflow_available():
        integrations.append("mlflow")
    if is_tensorboard_available():
        integrations.append("tensorboard")
    if is_wandb_available():
        integrations.append("wandb")
    return integrations


def rewrite_logs(d):
    new_d = {}
    eval_prefix = "eval_"
    eval_prefix_len = len(eval_prefix)
    for k, v in d.items():
        if k.startswith(eval_prefix):
            new_d["eval/" + k[eval_prefix_len:]] = v
        else:
            new_d["train/" + k] = v
    return new_d
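# Illustrative example (not part of the original module): ``rewrite_logs`` namespaces metrics for
# the logging callbacks, e.g. rewrite_logs({"eval_loss": 0.25, "loss": 1.3})
# returns {"eval/loss": 0.25, "train/loss": 1.3}.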


_is_deepspeed_zero3_enabled = None


def is_deepspeed_zero3_enabled():
    """
    This function answers the question of whether DeepSpeed is going to be used and run using ZeRO Stage 3.

    It includes an auto-discovery method, see comments in the code for details.

    Returns: ``True`` if either it was explicitly enabled via ``deepspeed_zero3_enable(True)`` or the auto-detector was
    able to derive that the ``Trainer`` will be running via DeepSpeed ZeRO stage 3.
    """
    global _is_deepspeed_zero3_enabled
    if _is_deepspeed_zero3_enabled is None:
        _is_deepspeed_zero3_enabled = False
        # Try to auto-discover if we are about to use DeepSpeed with ZeRO3 enabled. This will only
        # work for scripts using cli to pass --deepspeed ds_config.json. If cmd args aren't used,
        # then to get the model efficiently loaded across multiple-gpus one has to explicitly call
        # deepspeed_zero3_enable(True) **before** instantiating a model object
        if "--deepspeed" in sys.argv:
            idx = sys.argv.index("--deepspeed")
            ds_config = sys.argv[idx + 1]
            if not os.path.exists(ds_config):
                raise ValueError("--deepspeed requires a valid path to a config file")
            config = deepspeed_parse_config(ds_config)
            if (
                "zero_optimization" in config
                and "stage" in config["zero_optimization"]
                and config["zero_optimization"]["stage"] == 3
            ):
                _is_deepspeed_zero3_enabled = True

    return _is_deepspeed_zero3_enabled


def deepspeed_zero3_enable(enable=True):
    """
    ``is_deepspeed_zero3_enabled()`` tries to derive automatically if DeepSpeed ZeRO 3 is going to be used by looking
    at ``sys.argv``, which may or may not contain information about where to find the DeepSpeed config, if any.

    This function allows for explicit enabling/disabling of this global flag.

    Args:
        enable: if set to ``True`` will make ``is_deepspeed_zero3_enabled()`` return ``True``
    """
    global _is_deepspeed_zero3_enabled
    _is_deepspeed_zero3_enabled = enable
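# Illustrative usage (not part of the original module): when a script does not pass
# ``--deepspeed ds_config.json`` on the command line, ZeRO-3 handling can be requested explicitly
# before the model is instantiated, e.g.
#   deepspeed_zero3_enable(True)
#   model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")  # hypothetical example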


def deepspeed_parse_config(ds_config):
    """
    If ``ds_config`` isn't already a dict, read it from the config file.

    If it's already a dict, return a copy of it, so that we can freely modify it.
    """
    require_version("deepspeed>0.3.13")

    if isinstance(ds_config, dict):
        # Don't modify user's data should they want to reuse it (e.g. in tests), because once we have
        # modified it, it will not be accepted here again, since some config params must not be set by users
        config = deepcopy(ds_config)
    elif isinstance(ds_config, str):
        with io.open(ds_config, "r", encoding="utf-8") as f:
            config = json.load(f)
    else:
        raise ValueError("expecting either a path to a config file or a pre-populated dict")

    return config


def deepspeed_init(trainer, num_training_steps, resume_from_checkpoint=None):
    """
    Init DeepSpeed, after updating the DeepSpeed configuration with any relevant Trainer's args.

    If ``resume_from_checkpoint`` was passed then an attempt to resume from a previously saved checkpoint will be made.

    Args:
        trainer: Trainer object
        num_training_steps: number of training steps per single gpu
        resume_from_checkpoint: path to a checkpoint to resume from, applied after the normal DeepSpeedEngine load

    Returns: model, optimizer, lr_scheduler

    """
    import deepspeed

    args = trainer.args
    model = trainer.model

    config = deepspeed_parse_config(args.deepspeed)

    # The following code translates relevant trainer's cl args into the DS config

    # First, to ensure that there is no mismatch between cl args values and presets in the config
    # file, ask the user not to set these entries in the ds config file:
    # - "train_batch_size",
    # - "train_micro_batch_size_per_gpu",
    # - "gradient_accumulation_steps"
    bs_keys = ["train_batch_size", "train_micro_batch_size_per_gpu"]
    if len([x for x in bs_keys if x in config.keys()]):
        raise ValueError(
            f"Do not include {bs_keys} entries in the ds config file, as they will be set via --per_device_train_batch_size or its default"
        )
    if "gradient_accumulation_steps" in config.keys():
        raise ValueError(
            "Do not include gradient_accumulation_steps entries in the ds config file, as they will be set via --gradient_accumulation_steps or its default"
        )

    # DeepSpeed does:
    #   train_batch_size = n_gpus * train_micro_batch_size_per_gpu * gradient_accumulation_steps
    # therefore we just need to set:
    config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size
    config["gradient_accumulation_steps"] = args.gradient_accumulation_steps
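    # Worked example (illustration, not from the original file): with 8 GPUs,
    # --per_device_train_batch_size 4 and --gradient_accumulation_steps 2,
    # DeepSpeed derives train_batch_size = 8 * 4 * 2 = 64.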

    if "gradient_clipping" in config:
        logger.info("Keeping the `gradient_clipping` config intact, ignoring any gradient clipping-specific cl args")
    else:  # override only if the ds config doesn't already have this section
        config["gradient_clipping"] = args.max_grad_norm

    # Optimizer + Scheduler
    # Currently support combos:
    # 1. DS scheduler + DS optimizer: Yes
    # 2. HF scheduler + HF optimizer: Yes
    # 3. DS scheduler + HF optimizer: Yes
    # 4. HF scheduler + DS optimizer: No
    #
    # Unless Offload is enabled in which case it's:
    # 1. DS scheduler + DS optimizer: Yes
    # 2. HF scheduler + HF optimizer: No
    # 3. DS scheduler + HF optimizer: No
    # 4. HF scheduler + DS optimizer: No

    optimizer = None
    if "optimizer" in config:
        logger.info("Updating the `optimizer` config with other command line arguments")

        # to avoid inconsistent values, the command line args override the corresponding config entries
        params = dict(
            lr=args.learning_rate,
            betas=[args.adam_beta1, args.adam_beta2],
            eps=args.adam_epsilon,
            weight_decay=args.weight_decay,
        )
        for k, v in params.items():
            if k in config["optimizer"]["params"]:
                logger.info(f"setting optimizer.params.{k} to {v}")
                config["optimizer"]["params"][k] = v

    else:  # override only if the ds config doesn't already have this section
        if (
            "zero_optimization" in config
            and "cpu_offload" in config["zero_optimization"]
            and config["zero_optimization"]["cpu_offload"] is True
        ):
            raise ValueError("ZeRO Offload can only work with DeepSpeed optimizers")
        else:
            # ds supports Adam, OneBitAdam, and Lamb optimizers and can import other optimizers from torch.
            # But trainer uses AdamW by default.
            # Using any other optimizer requires voiding the warranty with: `zero_allow_untested_optimizer`
            trainer.create_optimizer()
            optimizer = trainer.optimizer
            # flag that this is non-native optimizer
            config["zero_allow_untested_optimizer"] = True

    # DS schedulers (deepspeed/runtime/lr_schedules.py):
    #
    # DS name      | --lr_scheduler_type  | HF func                           | Notes
    # -------------| ---------------------|-----------------------------------|--------------------
    # LRRangeTest  | na                   | na                                | LRRT
    # OneCycle     | na                   | na                                | 1CLR
    # WarmupLR     | constant_with_warmup | get_constant_schedule_with_warmup | w/ warmup_min_lr=0
    # WarmupDecayLR| linear               | get_linear_schedule_with_warmup   |
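    # For illustration (placeholder values, not from the original file), a matching ds config entry
    # could look like:
    #   "scheduler": {"type": "WarmupDecayLR",
    #                 "params": {"warmup_min_lr": 0, "warmup_max_lr": 3e-5,
    #                            "warmup_num_steps": 500, "total_num_steps": 10000}}
    # total_num_steps, warmup_max_lr and warmup_num_steps are overridden below from the Trainer arguments.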
    lr_scheduler = None
    if "scheduler" in config:
        logger.info("Updating the `scheduler` config with other command line arguments")
        # the user won't easily know the correct num_training_steps should they use WarmupDecayLR,
        # so let's set it to the correct value
        if config["scheduler"]["type"] == "WarmupDecayLR":
            logger.info(f"setting scheduler.params.total_num_steps to {num_training_steps}")
            config["scheduler"]["params"]["total_num_steps"] = num_training_steps

        # to avoid inconsistent values of lr and warmup steps the command line args override config
        params = dict(
            warmup_max_lr=args.learning_rate,
            warmup_num_steps=args.warmup_steps,
        )
        for k, v in params.items():
            if k in config["scheduler"]["params"]:
                logger.info(f"setting scheduler.params.{k} to {v}")
                config["scheduler"]["params"][k] = v

    else:  # override only if the ds config doesn't already have this section
        if "optimizer" in config:
            # to make this option work, we need to init DS optimizer first, then init HF scheduler,
            # then pass the HF scheduler to DS init, which is not possible at the moment
            raise ValueError("At the moment HF scheduler + DeepSpeed optimizer combination is not possible")
        else:
            trainer.create_scheduler(num_training_steps=num_training_steps)
            lr_scheduler = trainer.lr_scheduler

    # fp16
    if trainer.fp16_backend is not None:
        # Deepspeed has 2 possible fp16 config entries:
        # - `fp16`: for the native amp - it has a bunch of optional params but we won't set any here unless the user did the work
        # - `amp`: which delegates amp work to apex (which needs to be available), but it cannot be used with any ZeRO features, so probably best to be avoided.
        if trainer.fp16_backend == "apex":
            if "amp" in config:
                logger.info("Keeping the `amp` config intact, ignoring any amp-specific cl args")
            else:
                config["amp"] = {
                    "enabled": True,
                    "opt_level": args.fp16_opt_level,
                }
        elif trainer.fp16_backend == "amp":
            if "fp16" in config:
                logger.info("Keeping the `fp16` config intact, ignoring any fp16-specific cl args")
            else:
                config["fp16"] = {
                    "enabled": True,
                }

    # zero
    if "zero_optimization" in config:
        zero = config["zero_optimization"]

        # now we know for sure if zero3 is enabled
        deepspeed_zero3_enable(zero.get("stage") == 3)

        # automatically assign the optimal config values based on model config
        hidden_size = model.config.hidden_size
        if zero.get("reduce_bucket_size") == 0:
            zero["reduce_bucket_size"] = hidden_size * hidden_size
        if zero.get("stage3_prefetch_bucket_size") == 0:
            zero["stage3_prefetch_bucket_size"] = 0.9 * hidden_size * hidden_size
        if zero.get("stage3_param_persistence_threshold") == 0:
            zero["stage3_param_persistence_threshold"] = 10 * hidden_size
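        # Illustration (not part of the original code): with hidden_size = 1024 the auto-filled
        # values become 1024 * 1024 = 1_048_576 (reduce_bucket_size),
        # 0.9 * 1024 * 1024 = 943_718.4 (stage3_prefetch_bucket_size) and
        # 10 * 1024 = 10_240 (stage3_param_persistence_threshold).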

    # keep for quick debug:
    # from pprint import pprint; pprint(config)

    model_parameters = filter(lambda p: p.requires_grad, model.parameters())

    model, optimizer, _, lr_scheduler = deepspeed.initialize(
        model=model,
        model_parameters=model_parameters,
        config_params=config,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
    )

    if resume_from_checkpoint is not None:

        # it's possible that the user is trying to resume from model_path, which doesn't necessarily
        # contain a deepspeed checkpoint. e.g. examples just check if the dir exists and assume it's
        # a resume from a checkpoint and not just a local pretrained weight. So we check here if the
        # path contains what looks like a deepspeed checkpoint
        import glob

        deepspeed_checkpoint_dirs = sorted(glob.glob(f"{resume_from_checkpoint}/global_step*"))

        if len(deepspeed_checkpoint_dirs) > 0:
            logger.info(f"Attempting to resume from {resume_from_checkpoint}")
            # this magically updates self.optimizer and self.lr_scheduler
            load_path, _ = model.load_checkpoint(
                resume_from_checkpoint, load_optimizer_states=True, load_lr_scheduler_states=True
            )
            if load_path is None:
                raise ValueError(f"[deepspeed] failed to resume from checkpoint {resume_from_checkpoint}")
        else:
            logger.info(f"{resume_from_checkpoint} doesn't have deepspeed checkpoints, doing nothing")

    return model, optimizer, lr_scheduler


class TensorBoardCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that sends the logs to `TensorBoard
    <https://www.tensorflow.org/tensorboard>`__.

    Args:
        tb_writer (:obj:`SummaryWriter`, `optional`):
            The writer to use. Will instantiate one if not set.
    """

    def __init__(self, tb_writer=None):
        has_tensorboard = is_tensorboard_available()
        assert (
            has_tensorboard
        ), "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
        if has_tensorboard:
            try:
                from torch.utils.tensorboard import SummaryWriter  # noqa: F401

                self._SummaryWriter = SummaryWriter
            except ImportError:
                try:
                    from tensorboardX import SummaryWriter

                    self._SummaryWriter = SummaryWriter
                except ImportError:
                    self._SummaryWriter = None
        else:
            self._SummaryWriter = None
        self.tb_writer = tb_writer

    def _init_summary_writer(self, args, log_dir=None):
        log_dir = log_dir or args.logging_dir
        if self._SummaryWriter is not None:
            self.tb_writer = self._SummaryWriter(log_dir=log_dir)

    def on_train_begin(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return

        log_dir = None

        if state.is_hyper_param_search:
            trial_name = state.trial_name
            if trial_name is not None:
                log_dir = os.path.join(args.logging_dir, trial_name)

        self._init_summary_writer(args, log_dir)

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", args.to_json_string())
            if "model" in kwargs:
                model = kwargs["model"]
                if hasattr(model, "config") and model.config is not None:
                    model_config_json = model.config.to_json_string()
                    self.tb_writer.add_text("model_config", model_config_json)
            # Version of TensorBoard coming from tensorboardX does not have this method.
            if hasattr(self.tb_writer, "add_hparams"):
                self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})

    def on_log(self, args, state, control, logs=None, **kwargs):
        if state.is_world_process_zero:
            if self.tb_writer is None:
                self._init_summary_writer(args)

        if self.tb_writer is not None:
            logs = rewrite_logs(logs)
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, state.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute."
                    )
            self.tb_writer.flush()

    def on_train_end(self, args, state, control, **kwargs):
        if self.tb_writer:
            self.tb_writer.close()


class WandbCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that sends the logs to `Weights & Biases <https://www.wandb.com/>`__.
    """

    def __init__(self):
        has_wandb = is_wandb_available()
        assert has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`."
        if has_wandb:
            import wandb

            self._wandb = wandb
        self._initialized = False
        # log outputs
        self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})

    def setup(self, args, state, model, **kwargs):
        """
        Setup the optional Weights & Biases (`wandb`) integration.

        One can subclass and override this method to customize the setup if needed. Find more information `here
        <https://docs.wandb.ai/integrations/huggingface>`__. You can also override the following environment variables:

        Environment:
            WANDB_LOG_MODEL (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to log model as artifact at the end of training. Use along with
                `TrainingArguments.load_best_model_at_end` to upload best model.
            WANDB_WATCH (:obj:`str`, `optional` defaults to :obj:`"gradients"`):
                Can be :obj:`"gradients"`, :obj:`"all"` or :obj:`"false"`. Set to :obj:`"false"` to disable gradient
                logging or :obj:`"all"` to log gradients and parameters.
            WANDB_PROJECT (:obj:`str`, `optional`, defaults to :obj:`"huggingface"`):
                Set this to a custom string to store results in a different project.
            WANDB_DISABLED (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to disable wandb entirely. Set `WANDB_DISABLED=true` to disable.
        """
        if self._wandb is None:
            return
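        # Illustrative environment (placeholder values, not from the original file):
        #   WANDB_PROJECT=my-project WANDB_LOG_MODEL=true WANDB_WATCH=all python run_glue.py ...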
        self._initialized = True
        if state.is_world_process_zero:
            logger.info(
                'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
            )
            combined_dict = {**args.to_sanitized_dict()}

            if hasattr(model, "config") and model.config is not None:
                model_config = model.config.to_dict()
                combined_dict = {**model_config, **combined_dict}
            trial_name = state.trial_name
            init_args = {}
            if trial_name is not None:
                run_name = trial_name
                init_args["group"] = args.run_name
            else:
                run_name = args.run_name

            if self._wandb.run is None:
                self._wandb.init(
                    project=os.getenv("WANDB_PROJECT", "huggingface"),
                    name=run_name,
                    **init_args,
                )
            # add config parameters (run may have been created manually)
            self._wandb.config.update(combined_dict, allow_val_change=True)

            # define default x-axis (for latest wandb versions)
            if getattr(self._wandb, "define_metric", None):
                self._wandb.define_metric("train/global_step")
                self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True)

            # keep track of model topology and gradients, unsupported on TPU
            if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
                self._wandb.watch(
                    model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps)
                )

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if self._wandb is None:
            return
        hp_search = state.is_hyper_param_search
        if hp_search:
            self._wandb.finish()
        if not self._initialized:
            self.setup(args, state, model, **kwargs)

    def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
        if self._wandb is None:
            return
        if self._log_model and self._initialized and state.is_world_process_zero:
            from .trainer import Trainer

            fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
            with tempfile.TemporaryDirectory() as temp_dir:
                fake_trainer.save_model(temp_dir)
                metadata = (
                    {
                        k: v
                        for k, v in dict(self._wandb.summary).items()
                        if isinstance(v, numbers.Number) and not k.startswith("_")
                    }
                    if not args.load_best_model_at_end
                    else {
                        f"eval/{args.metric_for_best_model}": state.best_metric,
                        "train/total_flos": state.total_flos,
                    }
                )
                artifact = self._wandb.Artifact(name=f"model-{self._wandb.run.id}", type="model", metadata=metadata)
                for f in Path(temp_dir).glob("*"):
                    if f.is_file():
                        with artifact.new_file(f.name, mode="wb") as fa:
                            fa.write(f.read_bytes())
                self._wandb.run.log_artifact(artifact)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        if self._wandb is None:
            return
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            logs = rewrite_logs(logs)
            self._wandb.log({**logs, "train/global_step": state.global_step})


class CometCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that sends the logs to `Comet ML <https://www.comet.ml/site/>`__.
    """

    def __init__(self):
        assert _has_comet, "CometCallback requires comet-ml to be installed. Run `pip install comet-ml`."
        self._initialized = False

    def setup(self, args, state, model):
        """
        Setup the optional Comet.ml integration.

        Environment:
            COMET_MODE (:obj:`str`, `optional`):
                "OFFLINE", "ONLINE", or "DISABLED"
            COMET_PROJECT_NAME (:obj:`str`, `optional`):
                Comet.ml project name for experiments
            COMET_OFFLINE_DIRECTORY (:obj:`str`, `optional`):
                Folder to use for saving offline experiments when :obj:`COMET_MODE` is "OFFLINE"

        For a number of configurable items in the environment, see `here
        <https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables>`__.
        """
        self._initialized = True
        if state.is_world_process_zero:
            comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
            # don't shadow `args` (the TrainingArguments), so it can be logged under "args/" below
            experiment_kwargs = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
            experiment = None
            if comet_mode == "ONLINE":
                experiment = comet_ml.Experiment(**experiment_kwargs)
                logger.info("Automatic Comet.ml online logging enabled")
            elif comet_mode == "OFFLINE":
                experiment_kwargs["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
                experiment = comet_ml.OfflineExperiment(**experiment_kwargs)
                logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
            if experiment is not None:
                experiment._set_model_graph(model, framework="transformers")
                experiment._log_parameters(args, prefix="args/", framework="transformers")
                if hasattr(model, "config"):
                    experiment._log_parameters(model.config, prefix="config/", framework="transformers")

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            experiment = comet_ml.config.get_global_experiment()
            if experiment is not None:
                experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers")


class AzureMLCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that sends the logs to `AzureML
    <https://pypi.org/project/azureml-sdk/>`__.
    """

    def __init__(self, azureml_run=None):
        assert (
            is_azureml_available()
        ), "AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`."
        self.azureml_run = azureml_run

    def on_init_end(self, args, state, control, **kwargs):
        from azureml.core.run import Run

        if self.azureml_run is None and state.is_world_process_zero:
            self.azureml_run = Run.get_context()

    def on_log(self, args, state, control, logs=None, **kwargs):
        if self.azureml_run:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.azureml_run.log(k, v, description=k)


class MLflowCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that sends the logs to `MLflow <https://www.mlflow.org/>`__.
    """

    def __init__(self):
        assert is_mlflow_available(), "MLflowCallback requires mlflow to be installed. Run `pip install mlflow`."
        import mlflow

        self._MAX_PARAM_VAL_LENGTH = mlflow.utils.validation.MAX_PARAM_VAL_LENGTH
        self._MAX_PARAMS_TAGS_PER_BATCH = mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH

        self._initialized = False
        self._log_artifacts = False
        self._ml_flow = mlflow

    def setup(self, args, state, model):
        """
        Setup the optional MLflow integration.

        Environment:
            HF_MLFLOW_LOG_ARTIFACTS (:obj:`str`, `optional`):
                Whether to use MLflow .log_artifact() facility to log artifacts.

                This only makes sense if logging to a remote server, e.g. s3 or GCS. If set to `True` or `1`, will copy
                whatever is in the TrainingArguments' output_dir to the local or remote artifact storage. Using it without a
                remote storage will just copy the files to your artifact location.
        """
        log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper()
        if log_artifacts in {"TRUE", "1"}:
            self._log_artifacts = True
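        # Illustrative usage (not from the original file): exporting HF_MLFLOW_LOG_ARTIFACTS=1 before
        # training makes on_train_end() upload the contents of args.output_dir via mlflow.log_artifacts().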
        if state.is_world_process_zero:
            self._ml_flow.start_run()
            combined_dict = args.to_dict()
            if hasattr(model, "config") and model.config is not None:
                model_config = model.config.to_dict()
                combined_dict = {**model_config, **combined_dict}
            # remove params that are too long for MLflow
            for name, value in list(combined_dict.items()):
                # internally, all values are converted to str in MLflow
                if len(str(value)) > self._MAX_PARAM_VAL_LENGTH:
                    logger.warning(
                        f"Trainer is attempting to log a value of "
                        f'"{value}" for key "{name}" as a parameter. '
                        f"MLflow's log_param() only accepts values no longer than "
                        f"250 characters so we dropped this attribute."
                    )
                    del combined_dict[name]
            # MLflow cannot log more than 100 values in one go, so we have to split it
            combined_dict_items = list(combined_dict.items())
            for i in range(0, len(combined_dict_items), self._MAX_PARAMS_TAGS_PER_BATCH):
                self._ml_flow.log_params(dict(combined_dict_items[i : i + self._MAX_PARAMS_TAGS_PER_BATCH]))
        self._initialized = True

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, logs, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self._ml_flow.log_metric(k, v, step=state.global_step)
                else:
                    logger.warning(
                        f"Trainer is attempting to log a value of "
                        f'"{v}" of type {type(v)} for key "{k}" as a metric. '
                        f"MLflow's log_metric() only accepts float and "
                        f"int types so we dropped this attribute."
                    )

    def on_train_end(self, args, state, control, **kwargs):
        if self._initialized and state.is_world_process_zero:
            if self._log_artifacts:
                logger.info("Logging artifacts. This may take time.")
                self._ml_flow.log_artifacts(args.output_dir)

    def __del__(self):
        # if the previous run is not terminated correctly, the fluent API will
        # not let you start a new run before the previous one is killed
        if self._ml_flow.active_run() is not None:
            self._ml_flow.end_run()


INTEGRATION_TO_CALLBACK = {
    "azure_ml": AzureMLCallback,
    "comet_ml": CometCallback,
    "mlflow": MLflowCallback,
    "tensorboard": TensorBoardCallback,
    "wandb": WandbCallback,
}


def get_reporting_integration_callbacks(report_to):
    for integration in report_to:
        if integration not in INTEGRATION_TO_CALLBACK:
            raise ValueError(
                f"{integration} is not supported, only {', '.join(INTEGRATION_TO_CALLBACK.keys())} are supported."
            )
    return [INTEGRATION_TO_CALLBACK[integration] for integration in report_to]
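# Illustrative note (not part of the original module): the Trainer maps ``--report_to`` values through
# this table, e.g. get_reporting_integration_callbacks(["wandb", "tensorboard"]) returns
# [WandbCallback, TensorBoardCallback], which are then instantiated as callbacks.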