# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Integrations with other Python libraries.
"""
import importlib.util
import io
import json
import math
import numbers
import os
import re
import tempfile
from pathlib import Path
from types import SimpleNamespace

from .trainer_utils import SchedulerType
from .utils import logging


logger = logging.get_logger(__name__)


# comet_ml requires to be imported before any ML frameworks
_has_comet = importlib.util.find_spec("comet_ml") is not None and os.getenv("COMET_MODE", "").upper() != "DISABLED"
if _has_comet:
    try:
        import comet_ml  # noqa: F401

        if hasattr(comet_ml, "config") and comet_ml.config.get_config("comet.api_key"):
            _has_comet = True
        else:
            if os.getenv("COMET_MODE", "").upper() != "DISABLED":
                logger.warning("comet_ml is installed but `COMET_API_KEY` is not set.")
            _has_comet = False
    except (ImportError, ValueError):
        _has_comet = False

from .file_utils import ENV_VARS_TRUE_VALUES, is_torch_tpu_available  # noqa: E402
from .trainer_callback import TrainerCallback  # noqa: E402
from .trainer_utils import PREFIX_CHECKPOINT_DIR, BestRun, EvaluationStrategy  # noqa: E402


# Integration functions:
def is_wandb_available():
    if os.getenv("WANDB_DISABLED", "").upper() in ENV_VARS_TRUE_VALUES:
        return False
    return importlib.util.find_spec("wandb") is not None


def is_comet_available():
    return _has_comet


def is_tensorboard_available():
    return importlib.util.find_spec("tensorboard") is not None or importlib.util.find_spec("tensorboardX") is not None


def is_optuna_available():
    return importlib.util.find_spec("optuna") is not None


def is_ray_available():
    return importlib.util.find_spec("ray") is not None


def is_ray_tune_available():
    if not is_ray_available():
        return False
    return importlib.util.find_spec("ray.tune") is not None


def is_azureml_available():
    if importlib.util.find_spec("azureml") is None:
        return False
    if importlib.util.find_spec("azureml.core") is None:
        return False
    return importlib.util.find_spec("azureml.core.run") is not None


def is_mlflow_available():
    return importlib.util.find_spec("mlflow") is not None


def is_fairscale_available():
    return importlib.util.find_spec("fairscale") is not None


def is_deepspeed_available():
    return importlib.util.find_spec("deepspeed") is not None


def hp_params(trial):
    if is_optuna_available():
        import optuna

        if isinstance(trial, optuna.Trial):
            return trial.params
    if is_ray_tune_available():
        if isinstance(trial, dict):
            return trial

    raise RuntimeError(f"Unknown type for trial {trial.__class__}")


def default_hp_search_backend():
    if is_optuna_available():
        return "optuna"
    elif is_ray_tune_available():
        return "ray"


def run_hp_search_optuna(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    import optuna

    def _objective(trial, checkpoint_dir=None):
        checkpoint = None
        if checkpoint_dir:
            for subdir in os.listdir(checkpoint_dir):
                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
                    checkpoint = os.path.join(checkpoint_dir, subdir)
        trainer.objective = None
        trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
        # If there hasn't been any evaluation during the training loop.
        if getattr(trainer, "objective", None) is None:
            metrics = trainer.evaluate()
            trainer.objective = trainer.compute_objective(metrics)
        return trainer.objective

    timeout = kwargs.pop("timeout", None)
    n_jobs = kwargs.pop("n_jobs", 1)
    study = optuna.create_study(direction=direction, **kwargs)
    study.optimize(_objective, n_trials=n_trials, timeout=timeout, n_jobs=n_jobs)
    best_trial = study.best_trial
    return BestRun(str(best_trial.number), best_trial.value, best_trial.params)
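
# A minimal usage sketch (illustrative, not part of this module): this function is normally
# reached through `Trainer.hyperparameter_search`, e.g.:
#
#     def my_hp_space(trial):
#         return {"learning_rate": trial.suggest_float("learning_rate", 1e-5, 5e-5, log=True)}
#
#     best_run = trainer.hyperparameter_search(
#         hp_space=my_hp_space, backend="optuna", n_trials=10, direction="minimize"
#     )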


def run_hp_search_ray(trainer, n_trials: int, direction: str, **kwargs) -> BestRun:
    import ray

    def _objective(trial, local_trainer, checkpoint_dir=None):
        checkpoint = None
        if checkpoint_dir:
            for subdir in os.listdir(checkpoint_dir):
                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
                    checkpoint = os.path.join(checkpoint_dir, subdir)
        local_trainer.objective = None
        local_trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
        # If there hasn't been any evaluation during the training loop.
        if getattr(local_trainer, "objective", None) is None:
            metrics = local_trainer.evaluate()
            local_trainer.objective = local_trainer.compute_objective(metrics)
            local_trainer._tune_save_checkpoint()
            ray.tune.report(objective=local_trainer.objective, **metrics, done=True)

    # The model and TensorBoard writer do not pickle so we have to remove them (if they exist)
    # while doing the ray hp search.
    _tb_writer = trainer.pop_callback(TensorBoardCallback)
    trainer.model = None
    # Setup default `resources_per_trial` and `reporter`.
    if "resources_per_trial" not in kwargs and trainer.args.n_gpu > 0:
        # `args.n_gpu` is considered the total number of GPUs that will be split
        # among the `n_jobs`
        n_jobs = int(kwargs.pop("n_jobs", 1))
        num_gpus_per_trial = trainer.args.n_gpu
        if num_gpus_per_trial / n_jobs >= 1:
            num_gpus_per_trial = int(math.ceil(num_gpus_per_trial / n_jobs))
        kwargs["resources_per_trial"] = {"gpu": num_gpus_per_trial}

    if "progress_reporter" not in kwargs:
        from ray.tune import CLIReporter

        kwargs["progress_reporter"] = CLIReporter(metric_columns=["objective"])
    if "keep_checkpoints_num" in kwargs and kwargs["keep_checkpoints_num"] > 0:
        # `keep_checkpoints_num=0` would disable checkpointing
        trainer.use_tune_checkpoints = True
        if kwargs["keep_checkpoints_num"] > 1:
            logger.warning(
                "Currently keeping {} checkpoints for each trial. Checkpoints are usually huge, "
                "consider setting `keep_checkpoints_num=1`.".format(kwargs["keep_checkpoints_num"])
            )
    if "scheduler" in kwargs:
        from ray.tune.schedulers import ASHAScheduler, HyperBandForBOHB, MedianStoppingRule, PopulationBasedTraining

        # Check if checkpointing is enabled for PopulationBasedTraining
        if isinstance(kwargs["scheduler"], PopulationBasedTraining):
            if not trainer.use_tune_checkpoints:
                logger.warning(
                    "You are using PopulationBasedTraining but you haven't enabled checkpointing. "
                    "This means your trials will train from scratch everytime they are exploiting "
                    "new configurations. Consider enabling checkpointing by passing "
                    "`keep_checkpoints_num=1` as an additional argument to `Trainer.hyperparameter_search`."
                )

        # Check for `do_eval` and `evaluation_strategy` for schedulers that require intermediate reporting.
        if isinstance(
            kwargs["scheduler"], (ASHAScheduler, MedianStoppingRule, HyperBandForBOHB, PopulationBasedTraining)
        ) and (not trainer.args.do_eval or trainer.args.evaluation_strategy == EvaluationStrategy.NO):
            raise RuntimeError(
                "You are using {cls} as a scheduler but you haven't enabled evaluation during training. "
                "This means your trials will not report intermediate results to Ray Tune, and "
                "can thus not be stopped early or used to exploit other trials parameters. "
                "If this is what you want, do not use {cls}. If you would like to use {cls}, "
                "make sure you pass `do_eval=True` and `evaluation_strategy='steps'` in the "
                "Trainer `args`.".format(cls=type(kwargs["scheduler"]).__name__)
            )

    analysis = ray.tune.run(
        ray.tune.with_parameters(_objective, local_trainer=trainer),
        config=trainer.hp_space(None),
        num_samples=n_trials,
        **kwargs,
    )
    best_trial = analysis.get_best_trial(metric="objective", mode=direction[:3])
    best_run = BestRun(best_trial.trial_id, best_trial.last_result["objective"], best_trial.config)
    if _tb_writer is not None:
        trainer.add_callback(_tb_writer)
    return best_run
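
# A minimal usage sketch (illustrative), again via `Trainer.hyperparameter_search`; extra keyword
# arguments such as `keep_checkpoints_num` or a Ray Tune `scheduler` are forwarded to `ray.tune.run`:
#
#     best_run = trainer.hyperparameter_search(backend="ray", n_trials=8, keep_checkpoints_num=1)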


def get_available_reporting_integrations():
    integrations = []
    if is_azureml_available():
        integrations.append("azure_ml")
    if is_comet_available():
        integrations.append("comet_ml")
    if is_mlflow_available():
        integrations.append("mlflow")
    if is_tensorboard_available():
        integrations.append("tensorboard")
    if is_wandb_available():
        integrations.append("wandb")
    return integrations


def rewrite_logs(d):
    new_d = {}
    eval_prefix = "eval_"
    eval_prefix_len = len(eval_prefix)
    for k, v in d.items():
        if k.startswith(eval_prefix):
            new_d["eval/" + k[eval_prefix_len:]] = v
        else:
            new_d["train/" + k] = v
    return new_d
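
# For example, rewrite_logs({"eval_loss": 0.5, "loss": 1.2}) returns
# {"eval/loss": 0.5, "train/loss": 1.2}.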


def init_deepspeed(trainer, num_training_steps):
    """
    Init DeepSpeed, after converting the relevant Trainer arguments into a DeepSpeed configuration

    Args:
        trainer: Trainer object
        num_training_steps: the number of training steps (per single GPU)

    Returns: model, optimizer, lr_scheduler
    """
    import deepspeed

    args = trainer.args
    ds_config_file = args.deepspeed
    model = trainer.model

    with io.open(ds_config_file, "r", encoding="utf-8") as f:
        config = json.load(f)

    # The following code translates relevant trainer's cl args into the DS config

    # First to ensure that there is no mismatch between cl args values and presets in the config
    # file, ask to not set in ds config file:
    # - "train_batch_size",
    # - "train_micro_batch_size_per_gpu",
    # - "gradient_accumulation_steps"
    bs_keys = ["train_batch_size", "train_micro_batch_size_per_gpu"]
    if len([x for x in bs_keys if x in config.keys()]):
        raise ValueError(
            f"Do not include {bs_keys} entries in the ds config file, as they will be set via --per_device_train_batch_size or its default"
        )
    if "gradient_accumulation_steps" in config.keys():
        raise ValueError(
            "Do not include gradient_accumulation_steps entries in the ds config file, as they will be set via --gradient_accumulation_steps or its default"
        )

    # DeepSpeed does:
    #   train_batch_size = n_gpus * train_micro_batch_size_per_gpu * gradient_accumulation_steps
    # therefore we just need to set:
    config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size
    config["gradient_accumulation_steps"] = args.gradient_accumulation_steps

    if "gradient_clipping" in config:
        logger.info(
            f"Keeping the `gradient_clipping` config from {ds_config_file} intact, ignoring any gradient clipping-specific cl args"
        )
    else:  # override only if the ds config doesn't already have this section
        config["gradient_clipping"] = args.max_grad_norm

    if "optimizer" in config:
        logger.info(
            f"Keeping the `optimizer` config from {ds_config_file} intact, ignoring any optimizer-specific cl args"
        )
    else:  # override only if the ds config doesn't already have this section
        # DS supports Adam, OneBitAdam, and Lamb optimizers and can import other optimizers from torch.
        # But the Trainer uses AdamW by default, so using it (or any other untested optimizer)
        # requires voiding the warranty with: `zero_allow_untested_optimizer`

        optimizer_configs = {
            "AdamW": {
                "lr": args.learning_rate,
                "betas": [args.adam_beta1, args.adam_beta2],
                "eps": args.adam_epsilon,
                "weight_decay": args.weight_decay,
            }
        }
        optimizer = "AdamW"

        config["zero_allow_untested_optimizer"] = True
        config["optimizer"] = {
            "type": optimizer,
            "params": optimizer_configs[optimizer],
        }

    # DS schedulers (deepspeed/runtime/lr_schedules.py):
    #
    # DS name      | --lr_scheduler_type  | HF func                           | Notes
    # -------------| ---------------------|-----------------------------------|--------------------
    # LRRangeTest  | na                   | na                                | LRRT
    # OneCycle     | na                   | na                                | 1CLR
    # WarmupLR     | constant_with_warmup | get_constant_schedule_with_warmup | w/ warmup_min_lr=0
    # WarmupDecayLR| linear               | get_linear_schedule_with_warmup   |
    if "scheduler" in config:
        logger.info(
            f"Keeping the `scheduler` config from {ds_config_file} intact, ignoring any scheduler-specific cl args"
        )
    else:  # override only if the ds config doesn't already have this section
        if args.lr_scheduler_type == SchedulerType.LINEAR:
            scheduler = "WarmupDecayLR"
            params = {
                "last_batch_iteration": -1,
                "total_num_steps": num_training_steps,
                "warmup_min_lr": 0,
                "warmup_max_lr": args.learning_rate,
                "warmup_num_steps": args.warmup_steps,
            }
        elif args.lr_scheduler_type == SchedulerType.CONSTANT_WITH_WARMUP:
            scheduler = "WarmupLR"
            params = {
                "warmup_min_lr": 0,
                "warmup_max_lr": args.learning_rate,
                "warmup_num_steps": args.warmup_steps,
            }
        else:
            raise ValueError(f"{args.lr_scheduler_type} scheduler type is not supported by DeepSpeed")

        config["scheduler"] = {
            "type": scheduler,
            "params": params,
        }
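        # e.g. with `--lr_scheduler_type linear` (and hypothetical step counts) this yields:
        #   {"type": "WarmupDecayLR",
        #    "params": {"last_batch_iteration": -1, "total_num_steps": 1000, "warmup_min_lr": 0,
        #               "warmup_max_lr": 5e-5, "warmup_num_steps": 100}}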

    # fp16
    if trainer.fp16_backend is not None:
        # Deepspeed has 2 possible fp16 config entries:
        # - `fp16`: for the native amp - it has a bunch of optional params but we won't set any here unless the user did the work
        # - `amp`: which delegates amp work to apex (which needs to be available), but it cannot be used with any ZeRO features, so probably best to be avoided.
        if trainer.fp16_backend == "apex":
            if "amp" in config:
                logger.info(
                    f"Keeping the `amp` config from {ds_config_file} intact, ignoring any amp-specific cl args"
                )
            else:
                config["amp"] = {
                    "enabled": True,
                    "opt_level": args.fp16_opt_level,
                }
        elif trainer.fp16_backend == "amp":
            if "fp16" in config:
                logger.info(
                    f"Keeping the `fp16` config from {ds_config_file} intact, ignoring any fp16-specific cl args"
                )
            else:
                config["fp16"] = {
                    "enabled": True,
                }

    # for clarity extract the specific cl args that are being passed to deepspeed
    ds_args = dict(local_rank=args.local_rank)

    # init that takes part of the config via `args`, and the bulk of it via `config_params`
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    model, optimizer, _, lr_scheduler = deepspeed.initialize(
        args=SimpleNamespace(**ds_args),  # expects an obj
        model=model,
        model_parameters=model_parameters,
        config_params=config,
    )

    return model, optimizer, lr_scheduler


class TensorBoardCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that sends the logs to `TensorBoard
    <https://www.tensorflow.org/tensorboard>`__.

    Args:
        tb_writer (:obj:`SummaryWriter`, `optional`):
            The writer to use. Will instantiate one if not set.
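
    Example (a minimal sketch; the ``log_dir`` value is a placeholder)::

        from torch.utils.tensorboard import SummaryWriter

        tb_callback = TensorBoardCallback(SummaryWriter(log_dir="runs/my-run"))
        # then pass `callbacks=[tb_callback]` when creating the Trainer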
    """

    def __init__(self, tb_writer=None):
        has_tensorboard = is_tensorboard_available()
        assert (
            has_tensorboard
        ), "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or install tensorboardX."
        if has_tensorboard:
            try:
                from torch.utils.tensorboard import SummaryWriter  # noqa: F401

                self._SummaryWriter = SummaryWriter
            except ImportError:
                try:
                    from tensorboardX import SummaryWriter

                    self._SummaryWriter = SummaryWriter
                except ImportError:
                    self._SummaryWriter = None
        else:
            self._SummaryWriter = None
        self.tb_writer = tb_writer

    def _init_summary_writer(self, args, log_dir=None):
        log_dir = log_dir or args.logging_dir
        if self._SummaryWriter is not None:
            self.tb_writer = self._SummaryWriter(log_dir=log_dir)

    def on_train_begin(self, args, state, control, **kwargs):
        if not state.is_world_process_zero:
            return

        log_dir = None

        if state.is_hyper_param_search:
            trial_name = state.trial_name
            if trial_name is not None:
                log_dir = os.path.join(args.logging_dir, trial_name)

        self._init_summary_writer(args, log_dir)

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", args.to_json_string())
            if "model" in kwargs:
                model = kwargs["model"]
                if hasattr(model, "config") and model.config is not None:
                    model_config_json = model.config.to_json_string()
                    self.tb_writer.add_text("model_config", model_config_json)
            # Version of TensorBoard coming from tensorboardX does not have this method.
            if hasattr(self.tb_writer, "add_hparams"):
                self.tb_writer.add_hparams(args.to_sanitized_dict(), metric_dict={})

    def on_log(self, args, state, control, logs=None, **kwargs):
        if state.is_world_process_zero:
            if self.tb_writer is None:
                self._init_summary_writer(args)

        if self.tb_writer is not None:
            logs = rewrite_logs(logs)
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.tb_writer.add_scalar(k, v, state.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        '"%s" of type %s for key "%s" as a scalar. '
                        "This invocation of Tensorboard's writer.add_scalar() "
                        "is incorrect so we dropped this attribute.",
                        v,
                        type(v),
                        k,
                    )
            self.tb_writer.flush()

    def on_train_end(self, args, state, control, **kwargs):
        if self.tb_writer:
            self.tb_writer.close()


class WandbCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that sends the logs to `Weights & Biases <https://www.wandb.com/>`__.
    """

    def __init__(self):
        has_wandb = is_wandb_available()
        assert has_wandb, "WandbCallback requires wandb to be installed. Run `pip install wandb`."
        if has_wandb:
            import wandb

            wandb.ensure_configured()
            if wandb.api.api_key is None:
                has_wandb = False
                logger.warning(
                    "W&B installed but not logged in. Run `wandb login` or set the WANDB_API_KEY env variable."
                )
                self._wandb = None
            else:
                self._wandb = wandb
        self._initialized = False
        # log outputs
        self._log_model = os.getenv("WANDB_LOG_MODEL", "FALSE").upper() in ENV_VARS_TRUE_VALUES.union({"TRUE"})

    def setup(self, args, state, model, reinit, **kwargs):
        """
        Setup the optional Weights & Biases (`wandb`) integration.

        One can subclass and override this method to customize the setup if needed. Find more information `here
        <https://docs.wandb.com/huggingface>`__. You can also override the following environment variables:

        Environment:
            WANDB_LOG_MODEL (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to log model as artifact at the end of training.
            WANDB_WATCH (:obj:`str`, `optional`, defaults to :obj:`"gradients"`):
                Can be :obj:`"gradients"`, :obj:`"all"` or :obj:`"false"`. Set to :obj:`"false"` to disable gradient
                logging or :obj:`"all"` to log gradients and parameters.
            WANDB_PROJECT (:obj:`str`, `optional`, defaults to :obj:`"huggingface"`):
                Set this to a custom string to store results in a different project.
            WANDB_DISABLED (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not to disable wandb entirely.
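
        Example (illustrative only; the project name is a placeholder)::

            import os

            os.environ["WANDB_PROJECT"] = "my-project"
            os.environ["WANDB_LOG_MODEL"] = "true"
            os.environ["WANDB_WATCH"] = "all"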
        """
        if self._wandb is None:
            return
        self._initialized = True
        if state.is_world_process_zero:
            logger.info(
                'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
            )
            combined_dict = {**args.to_sanitized_dict()}

            if hasattr(model, "config") and model.config is not None:
                model_config = model.config.to_dict()
                combined_dict = {**model_config, **combined_dict}
            trial_name = state.trial_name
            init_args = {}
            if trial_name is not None:
                run_name = trial_name
                init_args["group"] = args.run_name
            else:
                run_name = args.run_name

            self._wandb.init(
                project=os.getenv("WANDB_PROJECT", "huggingface"),
                config=combined_dict,
                name=run_name,
                reinit=reinit,
                **init_args,
            )

            # keep track of model topology and gradients, unsupported on TPU
            if not is_torch_tpu_available() and os.getenv("WANDB_WATCH") != "false":
                self._wandb.watch(
                    model, log=os.getenv("WANDB_WATCH", "gradients"), log_freq=max(100, args.logging_steps)
                )

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if self._wandb is None:
            return
        hp_search = state.is_hyper_param_search
        if not self._initialized or hp_search:
            self.setup(args, state, model, reinit=hp_search, **kwargs)

    def on_train_end(self, args, state, control, model=None, tokenizer=None, **kwargs):
        if self._wandb is None:
            return
        # commit last step
        if state.is_world_process_zero:
            self._wandb.log({})
        if self._log_model and self._initialized and state.is_world_process_zero:
            from .trainer import Trainer

            fake_trainer = Trainer(args=args, model=model, tokenizer=tokenizer)
            with tempfile.TemporaryDirectory() as temp_dir:
                fake_trainer.save_model(temp_dir)
                # use run name and ensure it's a valid Artifact name
                artifact_name = re.sub(r"[^a-zA-Z0-9_\.\-]", "", self._wandb.run.name)
                metadata = (
                    {
                        k: v
                        for k, v in dict(self._wandb.summary).items()
                        if isinstance(v, numbers.Number) and not k.startswith("_")
                    }
                    if not args.load_best_model_at_end
                    else {
                        f"eval/{args.metric_for_best_model}": state.best_metric,
                        "train/total_floss": state.total_flos,
                    }
                )
                artifact = self._wandb.Artifact(name=f"run-{artifact_name}", type="model", metadata=metadata)
                for f in Path(temp_dir).glob("*"):
                    if f.is_file():
                        with artifact.new_file(f.name, mode="wb") as fa:
                            fa.write(f.read_bytes())
                self._wandb.run.log_artifact(artifact)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        if self._wandb is None:
            return
        if not self._initialized:
            self.setup(args, state, model, reinit=False)
        if state.is_world_process_zero:
            logs = rewrite_logs(logs)
            self._wandb.log(logs, step=state.global_step)


class CometCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that sends the logs to `Comet ML <https://www.comet.ml/site/>`__.
    """

    def __init__(self):
        assert _has_comet, "CometCallback requires comet-ml to be installed. Run `pip install comet-ml`."
        self._initialized = False

    def setup(self, args, state, model):
        """
        Setup the optional Comet.ml integration.

        Environment:
            COMET_MODE (:obj:`str`, `optional`):
                "OFFLINE", "ONLINE", or "DISABLED"
            COMET_PROJECT_NAME (:obj:`str`, `optional`):
                Comet.ml project name for experiments
            COMET_OFFLINE_DIRECTORY (:obj:`str`, `optional`):
                Folder to use for saving offline experiments when :obj:`COMET_MODE` is "OFFLINE"

        For a number of configurable items in the environment, see `here
        <https://www.comet.ml/docs/python-sdk/advanced/#comet-configuration-variables>`__.
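
        Example (illustrative only; the project name and directory are placeholders)::

            import os

            os.environ["COMET_MODE"] = "OFFLINE"
            os.environ["COMET_PROJECT_NAME"] = "my-project"
            os.environ["COMET_OFFLINE_DIRECTORY"] = "./comet-runs"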
        """
        self._initialized = True
        if state.is_world_process_zero:
            comet_mode = os.getenv("COMET_MODE", "ONLINE").upper()
            args = {"project_name": os.getenv("COMET_PROJECT_NAME", "huggingface")}
            experiment = None
            if comet_mode == "ONLINE":
                experiment = comet_ml.Experiment(**args)
                logger.info("Automatic Comet.ml online logging enabled")
            elif comet_mode == "OFFLINE":
                args["offline_directory"] = os.getenv("COMET_OFFLINE_DIRECTORY", "./")
                experiment = comet_ml.OfflineExperiment(**args)
                logger.info("Automatic Comet.ml offline logging enabled; use `comet upload` when finished")
            if experiment is not None:
                experiment._set_model_graph(model, framework="transformers")
                experiment._log_parameters(args, prefix="args/", framework="transformers")
                if hasattr(model, "config"):
                    experiment._log_parameters(model.config, prefix="config/", framework="transformers")

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            experiment = comet_ml.config.get_global_experiment()
            if experiment is not None:
                experiment._log_metrics(logs, step=state.global_step, epoch=state.epoch, framework="transformers")


class AzureMLCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that sends the logs to `AzureML
    <https://pypi.org/project/azureml-sdk/>`__.
    """

    def __init__(self, azureml_run=None):
        assert (
            is_azureml_available()
        ), "AzureMLCallback requires azureml to be installed. Run `pip install azureml-sdk`."
        self.azureml_run = azureml_run

    def on_init_end(self, args, state, control, **kwargs):
        from azureml.core.run import Run

        if self.azureml_run is None and state.is_world_process_zero:
            self.azureml_run = Run.get_context()

    def on_log(self, args, state, control, logs=None, **kwargs):
        if self.azureml_run:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self.azureml_run.log(k, v, description=k)


class MLflowCallback(TrainerCallback):
    """
    A :class:`~transformers.TrainerCallback` that sends the logs to `MLflow <https://www.mlflow.org/>`__.
    """

    MAX_LOG_SIZE = 100

    def __init__(self):
        assert is_mlflow_available(), "MLflowCallback requires mlflow to be installed. Run `pip install mlflow`."
        import mlflow

        self._initialized = False
        self._log_artifacts = False
        self._ml_flow = mlflow

    def setup(self, args, state, model):
        """
        Setup the optional MLflow integration.

        Environment:
            HF_MLFLOW_LOG_ARTIFACTS (:obj:`str`, `optional`):
                Whether to use MLflow .log_artifact() facility to log artifacts.

                This only makes sense if logging to a remote server, e.g. s3 or GCS. If set to `True` or `1`, will copy
                whatever is in TrainingArguments's output_dir to the local or remote artifact storage. Using it without a
                remote storage will just copy the files to your artifact location.
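
        Example (illustrative)::

            import os

            os.environ["HF_MLFLOW_LOG_ARTIFACTS"] = "TRUE"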
        """
        log_artifacts = os.getenv("HF_MLFLOW_LOG_ARTIFACTS", "FALSE").upper()
        if log_artifacts in {"TRUE", "1"}:
            self._log_artifacts = True
        if state.is_world_process_zero:
            self._ml_flow.start_run()
            combined_dict = args.to_dict()
            if hasattr(model, "config") and model.config is not None:
                model_config = model.config.to_dict()
                combined_dict = {**model_config, **combined_dict}
            # MLflow cannot log more than 100 values in one go, so we have to split it
            combined_dict_items = list(combined_dict.items())
            for i in range(0, len(combined_dict_items), MLflowCallback.MAX_LOG_SIZE):
                self._ml_flow.log_params(dict(combined_dict_items[i : i + MLflowCallback.MAX_LOG_SIZE]))
        self._initialized = True

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)

    def on_log(self, args, state, control, logs, model=None, **kwargs):
        if not self._initialized:
            self.setup(args, state, model)
        if state.is_world_process_zero:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    self._ml_flow.log_metric(k, v, step=state.global_step)
                else:
                    logger.warning(
                        "Trainer is attempting to log a value of "
                        '"%s" of type %s for key "%s" as a metric. '
                        "MLflow's log_metric() only accepts float and "
                        "int types so we dropped this attribute.",
                        v,
                        type(v),
                        k,
                    )

    def on_train_end(self, args, state, control, **kwargs):
        if self._initialized and state.is_world_process_zero:
            if self._log_artifacts:
                logger.info("Logging artifacts. This may take time.")
                self._ml_flow.log_artifacts(args.output_dir)
            self._ml_flow.end_run()

    def __del__(self):
        # if the previous run is not terminated correctly, the fluent API will
        # not let you start a new run before the previous one is killed
        if self._ml_flow.active_run() is not None:
            self._ml_flow.end_run(status="KILLED")


INTEGRATION_TO_CALLBACK = {
    "azure_ml": AzureMLCallback,
    "comet_ml": CometCallback,
    "mlflow": MLflowCallback,
    "tensorboard": TensorBoardCallback,
    "wandb": WandbCallback,
}


def get_reporting_integration_callbacks(report_to):
    for integration in report_to:
        if integration not in INTEGRATION_TO_CALLBACK:
            raise ValueError(
                f"{integration} is not supported, only {', '.join(INTEGRATION_TO_CALLBACK.keys())} are supported."
            )
    return [INTEGRATION_TO_CALLBACK[integration] for integration in report_to]
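

# For illustration: get_reporting_integration_callbacks(["wandb", "tensorboard"]) returns
# [WandbCallback, TensorBoardCallback]; an unknown name such as "foo" raises a ValueError.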