mixin.py 29 KB
Newer Older
wanglch's avatar
wanglch committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from huggingface/transformers.
import importlib
import os
import re
import shutil
import time
from pathlib import Path
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import json
import numpy as np
import peft
import safetensors
import torch
import transformers
from datasets import Dataset as HfDataset
from packaging import version
from peft import PeftModel
from torch.nn import Module
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.data.data_collator import DataCollator
from transformers.modeling_utils import unwrap_model
from transformers.trainer import (ADAPTER_CONFIG_NAME, ADAPTER_SAFE_WEIGHTS_NAME, ADAPTER_WEIGHTS_NAME, CONFIG_NAME,
                                  PREFIX_CHECKPOINT_DIR, SAFE_WEIGHTS_NAME, TRAINER_STATE_NAME, TRAINING_ARGS_NAME,
                                  WEIGHTS_NAME, IntervalStrategy, Trainer, TrainerCallback, is_peft_available)
from transformers.trainer_utils import EvalPrediction
from transformers.training_args import TrainingArguments
from transformers.utils import is_torch_npu_available

from swift.hub import Repository
from swift.hub.check_model import check_local_model_is_latest
from swift.torchacc_utils import (save_ta_ddp_checkpoint, save_ta_fsdp_checkpoint, ta_load_optimizer_and_scheduler,
                                  ta_save_optimizer_and_scheduler, ta_trim_graph)
from swift.tuners import SwiftModel
from swift.utils import check_json_format, create_ms_repo, get_logger, use_torchacc
from swift.utils.constants import Invoke
from .optimizers.galore import create_optimizer_and_scheduler
from .utils import can_return_loss, find_labels, get_function, is_instance_of_ms_model

logger = get_logger()


def _push_to_hub(self: Repository, commit_message: str = 'Commit files to Modelscope Hub', **kwargs):
    blocking = kwargs.get('blocking', True)
    self.push(commit_message)
    if not blocking:
        # Compatible with transformers
        return None, None
    else:
        return None


class PushToMsHubMixin:
    repo: Repository

    def _add_patterns_to_file(self, file_name: str, patterns: List[str], commit_message: Optional[str] = None) -> None:
        # Make sure we only do this on the main process
        if not self.is_world_process_zero():
            return
        if isinstance(patterns, str):
            patterns = [patterns]
        if commit_message is None:
            commit_message = f'Add `{patterns[0]}` patterns to {file_name}'

        # Get current file content
        repo_dir = self.repo.model_dir
        file_path = os.path.join(repo_dir, file_name)
        if os.path.exists(file_path):
            with open(file_path, 'r', encoding='utf-8') as f:
                current_content = f.read()
        else:
            current_content = ''
        # Add the patterns to file
        content = current_content
        for pattern in patterns:
            if pattern not in content:
                if len(content) > 0 and not content.endswith('\n'):
                    content += '\n'
                content += f'{pattern}\n'

        # Write the file if it has changed
        if content != current_content:
            with open(file_path, 'w', encoding='utf-8') as f:
                logger.debug(f'Writing {file_name} file. Content: {content}')
                f.write(content)
        self.repo.push(commit_message)

    def _add_patterns_to_gitignore(self, patterns: List[str], commit_message: Optional[str] = None) -> None:
        self._add_patterns_to_file('.gitignore', patterns, commit_message)

    def _add_patterns_to_gitattributes(self, patterns: List[str], commit_message: Optional[str] = None) -> None:
        new_patterns = []
        suffix = 'filter=lfs diff=lfs merge=lfs -text'
        for pattern in patterns:
            if suffix not in pattern:
                pattern = f'{pattern} {suffix}'
            new_patterns.append(pattern)
        file_name = '.gitattributes'
        if commit_message is None:
            commit_message = f'Add `{patterns[0]}` patterns to {file_name}'
        self._add_patterns_to_file(file_name, new_patterns, commit_message)

    def init_hf_repo(self) -> None:
        """init ms repo. Compatible with transformers>=4.34"""
        self.init_git_repo(at_init=True)

    def init_git_repo(self, at_init: bool = False) -> None:
        if not self.is_world_process_zero():
            return
        if (os.path.exists(self.args.output_dir) and os.listdir(self.args.output_dir) and self.args.overwrite_output_dir
                and at_init):
            # directory not empty.
            shutil.rmtree(self.args.output_dir)
        self.args.hub_model_id = create_ms_repo(self.args.hub_model_id, self.args.hub_token, self.args.hub_private_repo)
        self.repo = Repository(self.args.output_dir, self.args.hub_model_id)
        self._add_patterns_to_gitattributes(['*.safetensors', '*.bin', '*.pt'])
        self.repo.push_to_hub = MethodType(_push_to_hub, self.repo)
        self.repo.local_dir = self.repo.model_dir  # hf compatibility

        # By default, ignore the checkpoint folders
        if self.args.push_hub_strategy != 'all_checkpoints':
            self._add_patterns_to_gitignore(['checkpoint-*/', 'tmp-checkpoint-*/'])

        # Add 'runs/' to .gitignore, ignore tensorboard files
        self._add_patterns_to_gitignore(['runs/'])

        # Add '*.sagemaker' to .gitignore if using SageMaker
        if os.environ.get('SM_TRAINING_ENV'):
            self._add_patterns_to_gitignore(['*.sagemaker-uploading', '*.sagemaker-uploaded'],
                                            'Add `*.sagemaker` patterns to .gitignore')

        self.push_in_progress = None

    def push_to_hub(self, commit_message: str = 'End of training', **kwargs) -> None:
        # user calls manually `push_to_hub` with `self.args.push_to_hub = False`
        create_model_card = kwargs.pop('create_model_card', None)
        if not hasattr(self, 'repo'):
            self.init_git_repo()
        self.save_model(_internal_call=True)

        if not self.is_world_process_zero():
            return

        self.repo.push_to_hub(commit_message, **kwargs)
        # push separately the model card to be independent from the rest of the model
        readme_path = os.path.join(self.args.output_dir, 'README.md')
        if create_model_card is None:
            create_model_card = not os.path.exists(readme_path)
        if create_model_card and self.args.should_save:
            model_name = kwargs.pop('model_name', None)
            if model_name is None and self.args.should_save:
                if self.args.hub_model_id is not None:
                    model_name = self.args.hub_model_id.split('/')[-1]
                else:
                    model_name = os.path.basename(self.args.output_dir)
            self.create_model_card(model_name=model_name, **kwargs)
            self.repo.push_to_hub('update model card README.md', **kwargs)

    def _push_from_checkpoint(self, checkpoint_folder: str) -> None:
        """Compatible with transformers>=4.32"""
        # Only push from one node.
        if not self.is_world_process_zero() or self.args.push_hub_strategy == 'end':
            return
        output_dir = self.args.output_dir
        # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder
        modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME]
        if is_peft_available():
            modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME])
        for modeling_file in modeling_files:
            if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)):
                shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file))
        # Saving the tokenizer is fast and we don't know how many files it may have spawned, so we resave it to be sure.
        if self.tokenizer is not None:
            self.tokenizer.save_pretrained(output_dir)
        # Same for the training arguments
        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

        try:
            if self.args.push_hub_strategy == 'checkpoint':
                # Temporarily move the checkpoint just saved for the push
                tmp_checkpoint = os.path.join(output_dir, 'last-checkpoint')
                # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a
                # subfolder.
                if os.path.isdir(tmp_checkpoint):
                    shutil.rmtree(tmp_checkpoint)
                shutil.move(checkpoint_folder, tmp_checkpoint)

            if self.args.save_strategy == IntervalStrategy.STEPS:
                commit_message = f'Training in progress, step {self.state.global_step}'
            else:
                commit_message = f'Training in progress, epoch {int(self.state.epoch)}'
            if self.args.push_hub_strategy == 'push_best':
                folder, checkpoint_name = os.path.split(checkpoint_folder)
                checkpoint_name = checkpoint_name.replace('tmp-checkpoint-', 'checkpoint-')
                last_model_checkpoint = os.path.join(folder, checkpoint_name)
                if last_model_checkpoint == self.state.best_model_checkpoint:
                    self.repo.push_to_hub(commit_message=commit_message, blocking=False, auto_lfs_prune=True)
            else:
                self.repo.push_to_hub(commit_message=commit_message, blocking=False, auto_lfs_prune=True)
        except Exception as e:
            logger.error(f'Error when pushing to hub: {e}')
        finally:
            if self.args.push_hub_strategy == 'checkpoint':
                # Move back the checkpoint to its place
                shutil.move(tmp_checkpoint, checkpoint_folder)


class SwiftMixin:

    def __init__(self,
                 model: Union[PreTrainedModel, Module] = None,
                 args: TrainingArguments = None,
                 data_collator: Optional[DataCollator] = None,
                 train_dataset: Optional[HfDataset] = None,
                 eval_dataset: Optional[Union[HfDataset, Dict[str, HfDataset]]] = None,
                 tokenizer: Optional[PreTrainedTokenizerBase] = None,
                 model_init: Optional[Callable[[], PreTrainedModel]] = None,
                 compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
                 callbacks: Optional[List[TrainerCallback]] = None,
                 optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
                 preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
                 **kwargs) -> None:
        check_model = kwargs.pop('check_model', True)
        if check_model and hasattr(model, 'model_dir'):
            check_local_model_is_latest(
                model.model_dir,
                user_agent={
                    Invoke.KEY: Invoke.LOCAL_TRAINER,
                    Invoke.THIRD_PARTY: kwargs.pop(Invoke.THIRD_PARTY, Invoke.SWIFT),
                })

        # Compatible with transformers>=4.34
        from swift.tuners import SwiftModel
        is_quantized = getattr(model, 'is_quantized', False)
        _hf_peft_config_loaded = getattr(model, '_hf_peft_config_loaded', False)
        use_swift = isinstance(model, SwiftModel)
        if is_quantized and use_swift:
            model._hf_peft_config_loaded = True
        # mro
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
            **kwargs)
        if not self.label_names:
            self.label_names = ['labels']
        if is_quantized and use_swift:
            model._hf_peft_config_loaded = _hf_peft_config_loaded

        if get_function(model.__class__.forward) is not get_function(model.forward):
            self.label_names = find_labels(model)
            self.can_return_loss = can_return_loss(model)
        self.max_memory = 0.0
        self.start_time = time.time()

    @staticmethod
    def _create_configuration_file(model: Module, output_dir: str) -> None:
        cfg = getattr(model, 'cfg', {})
        configuration_path = os.path.join(output_dir, 'configuration.json')
        new_cfg = {}
        if os.path.exists(configuration_path):
            with open(configuration_path, 'r', encoding='utf-8') as f:
                new_cfg = json.load(f)

        if 'framework' not in new_cfg:
            new_cfg['framework'] = cfg.get('framework', 'pytorch')
        if 'task' not in new_cfg:
            new_cfg['task'] = cfg.get('task', 'text-generation')
        with open(configuration_path, 'w', encoding='utf-8') as f:
            json.dump(new_cfg, f, ensure_ascii=False, indent=4)

    def _add_adapter_cfg(self, output_dir: str) -> None:
        if not hasattr(self, 'sft_args'):
            return
        sft_args = self.sft_args
        if sft_args.sft_type == 'full':
            return
        configuration_path = os.path.join(output_dir, 'configuration.json')
        new_cfg = {}
        if os.path.exists(configuration_path):
            with open(configuration_path, 'r', encoding='utf-8') as f:
                new_cfg = json.load(f)

        need_to_save = [
            'model_id_or_path', 'model_revision', 'sft_type', 'tuner_backend', 'template_type', 'dtype', 'system'
        ]
        quantization_bit = sft_args.quantization_bit
        if quantization_bit > 0:
            need_to_save += [
                'quantization_bit', 'bnb_4bit_comp_dtype', 'bnb_4bit_quant_type', 'bnb_4bit_use_double_quant'
            ]
        adapter_cfg = {}
        for k in need_to_save:
            adapter_cfg[k] = getattr(sft_args, k)
        new_cfg['adapter_cfg'] = adapter_cfg
        with open(configuration_path, 'w', encoding='utf-8') as f:
            json.dump(new_cfg, f, ensure_ascii=False, indent=4)

    def _save_sft_args(self, output_dir: str) -> None:
        sft_args = getattr(self, 'sft_args', None)
        if sft_args is None:
            return
        fpath = os.path.join(output_dir, 'sft_args.json')
        with open(fpath, 'w', encoding='utf-8') as f:
            json.dump(check_json_format(self.sft_args.__dict__), f, ensure_ascii=False, indent=2)
        return

    def _save_optimizer_and_scheduler(self, output_dir):
        if not (use_torchacc() and self.sft_args.fsdp_num > 1):
            return super()._save_optimizer_and_scheduler(output_dir)

        ta_save_optimizer_and_scheduler(self.optimizer, self.lr_scheduler, output_dir)

    def _load_optimizer_and_scheduler(self, checkpoint):
        if not (use_torchacc() and self.sft_args.fsdp_num > 1):
            return super()._load_optimizer_and_scheduler(checkpoint)

        if checkpoint is None or self.args.save_only_model:
            return

        self.optimizer, self.lr_scheduler = ta_load_optimizer_and_scheduler(self.optimizer, self.lr_scheduler,
                                                                            checkpoint, self.args.device)

    def _save_tpu(self, output_dir: Optional[str] = None):
        if not use_torchacc():
            return super()._save_tpu(output_dir)

        output_dir = output_dir if output_dir is not None else self.args.output_dir
        if self.sft_args.fsdp_num > 1:
            save_ta_fsdp_checkpoint(self.model, self.tokenizer, self.args, output_dir)
        else:
            save_ta_ddp_checkpoint(self.model, self.tokenizer, self.args, output_dir)

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Compatible with swift and peft"""
        # If we are executing this function, we are the process zero, so we don't check for that.
        output_dir = output_dir if output_dir is not None else self.args.output_dir
        os.makedirs(output_dir, exist_ok=True)
        # configuration.json
        model_dir = getattr(self.model, 'model_dir', None)
        if model_dir is not None:
            src_path = os.path.join(model_dir, 'configuration.json')
            dst_path = os.path.join(output_dir, 'configuration.json')
            if os.path.exists(src_path):
                shutil.copy(src_path, dst_path)
        else:
            self._create_configuration_file(self.model, output_dir)
        self._add_adapter_cfg(output_dir)
        self._save_sft_args(output_dir)
        # generation_config
        generation_config = getattr(self.args, 'generation_config', None)
        if generation_config is not None:
            generation_config.save_pretrained(output_dir)
        # model
        supported_classes = (SwiftModel, PreTrainedModel, PeftModel)
        save_safetensors = self.args.save_safetensors
        if not isinstance(self.model, supported_classes):
            if state_dict is None:
                state_dict = self.model.state_dict()

            _unwrap_model = unwrap_model(self.model)
            if isinstance(_unwrap_model, supported_classes):
                _unwrap_model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=save_safetensors)
            else:
                logger.info('Trainer.model is not a `PreTrainedModel`, only saving its state dict.')
                if save_safetensors:
                    safetensors.torch.save_file(state_dict, os.path.join(output_dir, 'model.safetensors'))
                else:
                    torch.save(state_dict, os.path.join(output_dir, 'pytorch_model.bin'))
        elif is_instance_of_ms_model(self.model):
            PreTrainedModel.save_pretrained(
                self.model, output_dir, state_dict=state_dict, safe_serialization=save_safetensors)
        else:
            self.model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=save_safetensors)
        sft_args = getattr(self, 'sft_args', None)
        # tokenizer
        if self.tokenizer is not None and sft_args is not None and sft_args.sft_type == 'full':
            if hasattr(self.tokenizer, 'processor'):
                self.tokenizer.processor.save_pretrained(output_dir)
            self.tokenizer.save_pretrained(output_dir)
        # training_args.bin
        torch.save(self.args, os.path.join(output_dir, 'training_args.bin'))
        # additional files
        if sft_args is not None and sft_args.sft_type == 'full':
            additional_files = getattr(self.args, 'additional_saved_files', []) + ['preprocessor_config.json']
            if model_dir is not None:
                for file in additional_files:
                    src_path = os.path.join(model_dir, file)
                    dst_path = os.path.join(output_dir, file)
                    if os.path.isfile(src_path):
                        shutil.copy(src_path, dst_path)
                    elif os.path.isdir(src_path):
                        shutil.copytree(src_path, dst_path)

    def _save_checkpoint(self, model, trial, metrics=None):
        self.state.last_model_checkpoint = os.path.join(self.args.output_dir, f'checkpoint-{self.state.global_step}')
        logger.info(f'Saving model checkpoint to {self.state.last_model_checkpoint}')
        if version.parse(transformers.__version__) >= version.parse('4.36') or not self.args.save_only_model:
            return super()._save_checkpoint(model, trial, metrics)
        else:
            return self._save_only_model(model, trial, metrics)

    def _save_only_model(self, model, trial, metrics=None):
        # Save model checkpoint
        checkpoint_folder = f'{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}'

        if self.hp_search_backend is None and trial is None:
            self.store_flos()

        run_dir = self._get_output_dir(trial=trial)
        output_dir = os.path.join(run_dir, checkpoint_folder)
        self.save_model(output_dir, _internal_call=True)

        # Determine the new best metric / best model checkpoint
        if metrics is not None and self.args.metric_for_best_model is not None:
            metric_to_check = self.args.metric_for_best_model
            if not metric_to_check.startswith('eval_'):
                metric_to_check = f'eval_{metric_to_check}'
            metric_value = metrics[metric_to_check]

            operator = np.greater if self.args.greater_is_better else np.less
            if (self.state.best_metric is None or self.state.best_model_checkpoint is None
                    or operator(metric_value, self.state.best_metric)):
                self.state.best_metric = metric_value
                self.state.best_model_checkpoint = output_dir

        # Save the Trainer state
        if self.args.should_save:
            self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

        # push to hub
        if self.args.push_to_hub:
            self._push_from_checkpoint(output_dir)

        # Maybe delete some older checkpoints.
        if self.args.should_save:
            self._rotate_checkpoints(use_mtime=True, output_dir=run_dir)

    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        train_sampler_random = self.args.train_sampler_random
        if train_sampler_random:
            return super()._get_train_sampler()
        else:
            return self._get_eval_sampler(self.train_dataset)

    def _load_from_checkpoint(self, resume_from_checkpoint: str, model=None) -> None:
        if model is None:
            model = self.model
        if use_torchacc():
            # Loading checkpoint of TorchAcc has been done in tuner.py when
            # sft_type is 'full'.
            if self.sft_args.fsdp_num > 1:
                model = model._get_underlay_model().module.module
            if isinstance(model, PreTrainedModel):
                return
        elif not isinstance(model, SwiftModel):
            # Avoid throwing exceptions
            return super()._load_from_checkpoint(resume_from_checkpoint, model)

    def _sorted_checkpoints(self,
                            output_dir=None,
                            checkpoint_prefix=PREFIX_CHECKPOINT_DIR,
                            use_mtime=False) -> List[str]:
        ordering_and_checkpoint_path = []

        glob_checkpoints = [str(x) for x in Path(output_dir).glob(f'{checkpoint_prefix}-*') if os.path.isdir(x)]

        for path in glob_checkpoints:
            if use_mtime:
                ordering_and_checkpoint_path.append((os.path.getmtime(path), path))
            else:
                regex_match = re.match(f'.*{checkpoint_prefix}-([0-9]+)', path)
                if regex_match is not None and regex_match.groups() is not None:
                    ordering_and_checkpoint_path.append((int(regex_match.groups()[0]), path))

        checkpoints_sorted = sorted(ordering_and_checkpoint_path)
        checkpoints_sorted = [checkpoint[1] for checkpoint in checkpoints_sorted]
        # Make sure we don't delete the best model.
        if (self.state.best_model_checkpoint is not None
                and str(Path(self.state.best_model_checkpoint)) in checkpoints_sorted):
            best_model_index = checkpoints_sorted.index(str(Path(self.state.best_model_checkpoint)))
            for i in range(best_model_index, len(checkpoints_sorted) - 2):
                checkpoints_sorted[i], checkpoints_sorted[i + 1] = checkpoints_sorted[i + 1], checkpoints_sorted[i]
        return checkpoints_sorted

    def _load_best_model(self):
        # Compatible with transformers>=4.35 (deepspeed)
        try:
            model = self.model
            if isinstance(model, SwiftModel):
                logger.info(
                    f'Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).')
                adapters = model.adapters
                for adapter_name in adapters.keys():
                    sub_folder = os.path.join(self.state.best_model_checkpoint, adapter_name)
                    state_dict = SwiftModel.load_state_file(sub_folder, device='cpu')
                    if state_dict is not None:
                        self.model.load_state_dict(state_dict, strict=False, adapter_name=adapter_name)
                state_dict = SwiftModel.load_state_file(self.state.best_model_checkpoint, device='cpu')
                if state_dict is not None:
                    self.model.load_state_dict(state_dict, strict=False, adapter_name='default')
            else:
                super()._load_best_model()
        except ValueError as e:
            logger.warning(e)

    def get_max_cuda_memory(self, device: Optional[Union[torch.device, int]] = None) -> float:
        if device is None:
            mems = [torch.cuda.max_memory_reserved(device=device) for device in range(torch.cuda.device_count())]
        else:
            mems = [torch.cuda.max_memory_reserved(device=device)]
        mem = sum([float(mem) / 1024 / 1024 / 1024 for mem in mems])
        if self.max_memory < mem:
            self.max_memory = mem
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
        return mem

    def _maybe_log_save_evaluate(self, tr_loss, *args, **kwargs):
        if self.control.should_log:
            if use_torchacc():
                ta_trim_graph()
            self.control.should_log = False
            logs: Dict[str, float] = {}
            metrics_log = {'loss': tr_loss}  # loss first
            if hasattr(self, '_custom_metrics'):
                metrics_log.update(self._custom_metrics)
                self._custom_metrics = {}
            for k, v in metrics_log.items():
                # all_gather + mean() to get average loss over all processes
                v_scalar = self._nested_gather(v).mean().item()
                if k == 'loss':
                    self._total_loss_scalar += v_scalar
                logs[k] = round(v_scalar / (self.state.global_step - self._globalstep_last_logged), 8)
            if version.parse(transformers.__version__) >= version.parse('4.38'):
                grad_norm = args[0]
                if isinstance(grad_norm, torch.Tensor):
                    grad_norm = grad_norm.item()
                if grad_norm is not None:
                    logs['grad_norm'] = grad_norm
            logs['learning_rate'] = self._get_learning_rate()
            if not is_torch_npu_available():
                logs['memory(GiB)'] = round(self.get_max_cuda_memory(), 2)
            import time
            time_now = time.time()
            elapse_time = time_now - self.start_time
            logs['train_speed(iter/s)'] = round(self.state.global_step / elapse_time, 6)
            tr_loss -= tr_loss
            self._globalstep_last_logged = self.state.global_step
            self.store_flos()
            self.log(logs)
        super()._maybe_log_save_evaluate(tr_loss, *args, **kwargs)

    def create_optimizer_and_scheduler(self, num_training_steps: int):
        if hasattr(self.args, 'galore_config'):
            optimizer, lr_scheduler = create_optimizer_and_scheduler(
                self.model,
                self.args,
                self.args.galore_config,
                num_training_steps,
                lr=self.args.learning_rate,
                weight_decay=self.args.weight_decay)
            self.optimizer = optimizer
            self.lr_scheduler = lr_scheduler
        else:
            self.create_optimizer()
            self.create_scheduler(num_training_steps=num_training_steps, optimizer=self.optimizer)

    def create_optimizer(self):
        opt_model = self.model

        if self.optimizer is None:
            if version.parse(transformers.__version__) < version.parse('4.34.0'):
                logger.warning(f'If you are using lora+, please remember using transformers>=4.34.0, '
                               f'but now is {transformers.__version__}')
                return super().create_optimizer()

            optimizer_grouped_parameters = None
            if hasattr(self.model, 'create_optimizer_param_groups'):
                # Lora+ parameter groups
                optimizer_grouped_parameters = self.model.create_optimizer_param_groups(
                    lr=self.args.learning_rate, weight_decay=self.args.weight_decay)

            if optimizer_grouped_parameters is None:
                # Default parameter groups
                decay_parameters = self.get_decay_parameter_names(opt_model)
                optimizer_grouped_parameters = [
                    {
                        'params':
                        [p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad)],
                        'weight_decay':
                        self.args.weight_decay,
                    },
                    {
                        'params':
                        [p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad)],
                        'weight_decay':
                        0.0,
                    },
                ]

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
        return self.optimizer