# finetune.py
# -*- coding: utf-8 -*-
import dataclasses as dc
import functools
import os
from collections.abc import Callable, Mapping, Sequence
from pathlib import Path
from typing import Annotated, Any, Optional, Union

import jieba
import numpy as np
import ruamel.yaml as yaml
import torch
import typer
from datasets import Dataset, DatasetDict, NamedSplit, Split, load_dataset
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu
from peft import PeftConfig, get_peft_config, get_peft_model
from rouge_chinese import Rouge
from torch import nn
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    EvalPrediction,
    GenerationConfig,
    PreTrainedTokenizer,
    Seq2SeqTrainingArguments,
)
from transformers import DataCollatorForSeq2Seq as _DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer as _Seq2SeqTrainer


# For Ascend NPU, please add this
# import torch_npu
# from torch_npu.contrib import transfer_to_npu

app = typer.Typer(pretty_exceptions_show_locals=False)


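# Collator extending the standard Seq2Seq collator: when eval features carry "output_ids"
# (reference answers for generation-based evaluation), it right-pads them to a common
# length with the tokenizer's pad token before delegating to the usual input padding.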
class DataCollatorForSeq2Seq(_DataCollatorForSeq2Seq):
    def __call__(self, features, return_tensors=None):
        output_ids = [feature["output_ids"] for feature in features] if "output_ids" in features[0].keys() else None
        if output_ids is not None:
            max_output_length = max(len(out) for out in output_ids)
            if self.pad_to_multiple_of is not None:
                max_output_length = (
                    (max_output_length + self.pad_to_multiple_of - 1)
                    // self.pad_to_multiple_of
                    * self.pad_to_multiple_of
                )
            for feature in features:
                remainder = [self.tokenizer.pad_token_id] * (max_output_length - len(feature["output_ids"]))
                if isinstance(feature["output_ids"], list):
                    feature["output_ids"] = feature["output_ids"] + remainder
                else:
                    feature["output_ids"] = np.concatenate([feature["output_ids"], remainder]).astype(np.int64)
        return super().__call__(features, return_tensors)


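# Trainer whose prediction_step strips the prompt tokens from the generated sequence and
# returns the reference "output_ids" as labels, so compute_metrics compares the generated
# text against the expected assistant reply rather than against shifted input labels.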
class Seq2SeqTrainer(_Seq2SeqTrainer):
    def prediction_step(
        self,
        model: nn.Module,
        inputs: dict[str, Any],
        prediction_loss_only: bool,
        ignore_keys=None,
        **gen_kwargs,
    ) -> tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        with torch.no_grad():  # Ensure no gradient computation
            if self.args.predict_with_generate:
                output_ids = inputs.pop("output_ids")
            input_ids = inputs["input_ids"]

            del inputs["labels"]
            loss, generated_tokens, labels = super().prediction_step(
                model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs
            )

            generated_tokens = generated_tokens[:, input_ids.size()[1] :]
            labels = output_ids

            del inputs, input_ids, output_ids
            torch.cuda.empty_cache()

        return loss, generated_tokens, labels


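# Paths to the train/validation/test files plus the number of preprocessing workers.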
@dc.dataclass
class DataConfig(object):
    train_file: Optional[str] = None
    val_file: Optional[str] = None
    test_file: Optional[str] = None
    num_proc: Optional[int] = None

    @property
    def data_format(self) -> str:
        return Path(self.train_file).suffix

    @property
    def data_files(self) -> dict[NamedSplit, str]:
        return {
            split: data_file
            for split, data_file in zip(
                [Split.TRAIN, Split.VALIDATION, Split.TEST],
                [self.train_file, self.val_file, self.test_file],
            )
            if data_file is not None
        }


@dc.dataclass
class FinetuningConfig(object):
    data_config: DataConfig

    max_input_length: int
    max_output_length: int
    combine: bool
    freezeV: bool

    training_args: Seq2SeqTrainingArguments = dc.field(
        default_factory=lambda: Seq2SeqTrainingArguments(output_dir="./output")
    )
    peft_config: Optional[PeftConfig] = None
    swanlab: Optional[str] = "cloud"

    def __post_init__(self):
        if not self.training_args.do_eval or self.data_config.val_file is None:
            self.training_args.do_eval = False
            self.training_args.evaluation_strategy = "no"
            self.data_config.val_file = None
        else:
            self.training_args.per_device_eval_batch_size = (
                self.training_args.per_device_eval_batch_size or self.training_args.per_device_train_batch_size
            )
        if self.swanlab != "disabled":
            os.environ["SWANLAB_PROJ_NAME"] = "GLM4-Finetune"
        if self.swanlab == "local":
            os.environ["SWANLAB_MODE"] = "local"

    @classmethod
    def from_dict(cls, **kwargs) -> "FinetuningConfig":
        training_args = kwargs.get("training_args", None)
        if training_args is not None and not isinstance(training_args, Seq2SeqTrainingArguments):
            gen_config = training_args.get("generation_config")
            if not isinstance(gen_config, GenerationConfig):
                training_args["generation_config"] = GenerationConfig(**gen_config)
            kwargs["training_args"] = Seq2SeqTrainingArguments(**training_args)

        data_config = kwargs.get("data_config")
        if not isinstance(data_config, DataConfig):
            kwargs["data_config"] = DataConfig(**data_config)

        peft_config = kwargs.get("peft_config", None)
        if peft_config is not None and not isinstance(peft_config, PeftConfig):
            kwargs["peft_config"] = get_peft_config(config_dict=peft_config)
        return cls(**kwargs)

    @classmethod
    def from_file(cls, path: Union[str, Path]) -> "FinetuningConfig":
        path = Path(path)
        parser = yaml.YAML(typ="safe", pure=True)
        parser.indent(mapping=2, offset=2, sequence=4)
        parser.default_flow_style = False
        kwargs = parser.load(path)
        return cls.from_dict(**kwargs)


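# Only .jsonl data files are supported here; any other format raises NotImplementedError.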
def _load_datasets(
    data_dir: str,
    data_format: str,
    data_files: dict[NamedSplit, str],
    num_proc: Optional[int],
) -> DatasetDict:
    if data_format == ".jsonl":
        dataset_dct = load_dataset(
            data_dir,
            data_files=data_files,
            split=None,
            num_proc=num_proc,
        )
    else:
        raise NotImplementedError(f"Cannot load dataset in the '{data_format}' format.")
    return dataset_dct


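# Thin wrapper around the loaded DatasetDict: applies a preprocessing function to a
# requested split and optionally drops the original columns.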
class DataManager(object):
    def __init__(self, data_dir: str, data_config: DataConfig):
        self._num_proc = data_config.num_proc

        self._dataset_dct = _load_datasets(
            data_dir,
            data_config.data_format,
            data_config.data_files,
            self._num_proc,
        )

    def _get_dataset(self, split: NamedSplit) -> Optional[Dataset]:
        return self._dataset_dct.get(split, None)

    def get_dataset(
        self,
        split: NamedSplit,
        process_fn: Callable[[dict[str, Any]], dict[str, Any]],
        batched: bool = True,
        remove_orig_columns: bool = True,
    ) -> Optional[Dataset]:
        orig_dataset = self._get_dataset(split)
        if orig_dataset is None:
            return

        if remove_orig_columns:
            remove_columns = orig_dataset.column_names
        else:
            remove_columns = None
        return orig_dataset.map(
            process_fn,
            batched=batched,
            remove_columns=remove_columns,
            num_proc=self._num_proc,
        )


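# Keep tool definitions only on system messages, pruning parameter properties that are
# None; a "tools" field on any non-system message is dropped.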
def process_message(message):
    if "tools" in message and message["role"] == "system":
        for tool in message["tools"]:
            parameters = tool["function"]["parameters"]["properties"]
            tool["function"]["parameters"]["properties"] = {k: v for k, v in parameters.items() if v is not None}
    elif "tools" in message:
        del message["tools"]
    return message


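# Tokenize training conversations into input_ids and labels. Loss is computed only on
# assistant tokens: every other position is masked with -100. The hard-coded ids below
# are assumed to be GLM-4 special tokens.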
def process_batch(
    batch: Mapping[str, Sequence],
    tokenizer: PreTrainedTokenizer,
    max_input_length: int,
    max_output_length: int,
    combine: bool,
) -> dict[str, list]:
    batched_conv = batch["messages"]
    batched_input_ids = []
    batched_labels = []
    for conv in batched_conv:
        input_ids = [151331, 151333]  # prompt prefix ids (assumed to be GLM-4's [gMASK] and <sop> tokens)
        loss_masks = [False, False]  # no loss on the prefix tokens
        if combine:
            new_input_ids = tokenizer.apply_chat_template(conv, tokenize=True, return_dict=False)
            input_ids = new_input_ids
            loss_masks = [False] * len(input_ids)
            last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1  # 151337: assumed <|assistant|> token id
            for j in range(last_assistant_index + 1, len(input_ids)):
                loss_masks[j] = True
        else:
            for message in conv:
                message = process_message(message)
                loss_mask_val = False if message["role"] in ("system", "user", "observation") else True
                new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
                input_ids += new_input_ids
                loss_masks += [loss_mask_val] * len(new_input_ids)

        input_ids.append(151336)  # EOS for chat
        loss_masks = [False, *loss_masks]
        labels = []
        for input_id, mask in zip(input_ids, loss_masks):
            if mask:
                labels.append(input_id)
            else:
                labels.append(-100)
        max_length = max_input_length + max_output_length + 1
        batched_input_ids.append(input_ids[:max_length])
        batched_labels.append(labels[:max_length])

    del batched_conv, conv, input_ids, loss_masks, new_input_ids, labels
    torch.cuda.empty_cache()

    return {"input_ids": batched_input_ids, "labels": batched_labels}


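# Tokenize evaluation conversations: the prompt up to each assistant turn becomes
# input_ids, and the expected assistant reply becomes output_ids for generation-based
# metric computation.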
def process_batch_eval(
    batch: Mapping[str, Sequence],
    tokenizer: PreTrainedTokenizer,
    max_input_length: int,
    max_output_length: int,
    combine: bool,
) -> dict[str, list]:
    batched_conv = batch["messages"]
    batched_input_ids = []
    batched_output_ids = []

    for conv in batched_conv:
        if combine:
            new_input_ids = tokenizer.apply_chat_template(conv, tokenize=True, return_dict=False)
            input_ids = new_input_ids
            last_assistant_index = len(input_ids) - input_ids[::-1].index(151337) - 1
            output_prompt, output_ids = (
                input_ids[:1],
                input_ids[last_assistant_index:],
            )
            output_ids.append(151336)
            batched_input_ids.append(input_ids[:max_input_length] + output_prompt[:1])
            batched_output_ids.append(output_ids[:max_output_length])
        else:
            input_ids = [151331, 151333]
            for message in conv:
                if len(input_ids) >= max_input_length:
                    break
                else:
                    message = process_message(message)
                    new_input_ids = tokenizer.apply_chat_template([message], tokenize=True, return_dict=False)[2:]
                    if message["role"] == "assistant":
                        output_prompt, output_ids = (
                            new_input_ids[:1],
                            new_input_ids[1:],
                        )
                        output_ids.append(151336)
                        batched_input_ids.append(input_ids[:max_input_length] + output_prompt[:1])
                        batched_output_ids.append(output_ids[:max_output_length])
                    input_ids += new_input_ids

    del batched_conv, conv, input_ids, new_input_ids, output_prompt, output_ids
    torch.cuda.empty_cache()

    return {"input_ids": batched_input_ids, "output_ids": batched_output_ids}


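# Load the tokenizer and the causal LM in bfloat16; when a PEFT config is given, wrap the
# base model with get_peft_model so only the adapter parameters are trainable.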
def load_tokenizer_and_model(
    model_dir: str,
    peft_config: Optional[PeftConfig] = None,
):
    tokenizer = AutoTokenizer.from_pretrained(model_dir, padding_side="left", trust_remote_code=True)
    if peft_config is not None:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            use_cache=False,
            torch_dtype=torch.bfloat16,  # Must use BFloat 16
        )
        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            use_cache=False,
            torch_dtype=torch.bfloat16,
        )
    return tokenizer, model


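# Decode predictions and references, segment them with jieba (the data is assumed to be
# Chinese), and report ROUGE-1/2/L F-scores and sentence-level BLEU-4.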
def compute_metrics(eval_preds: EvalPrediction, tokenizer):
    batched_pred_ids, batched_label_ids = eval_preds
    batched_pred_ids[batched_pred_ids == -100] = tokenizer.pad_token_id
    batched_label_ids[batched_label_ids == -100] = tokenizer.pad_token_id
    metrics_dct = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
    for pred_ids, label_ids in zip(batched_pred_ids, batched_label_ids):
        pred_txt = tokenizer.decode(pred_ids).strip()
        label_txt = tokenizer.decode(label_ids).strip()
        pred_tokens = list(jieba.cut(pred_txt))
        label_tokens = list(jieba.cut(label_txt))
        rouge = Rouge()
        scores = rouge.get_scores(" ".join(pred_tokens), " ".join(label_tokens))
        for k, v in scores[0].items():
            metrics_dct[k].append(round(v["f"] * 100, 4))
        metrics_dct["bleu-4"].append(
            sentence_bleu(
                [label_tokens],
                pred_tokens,
                smoothing_function=SmoothingFunction().method3,
            )
        )
    return {k: np.mean(v) for k, v in metrics_dct.items()}


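# CLI entrypoint: load the config, build the datasets, and run training, optionally
# resuming from a checkpoint and running prediction on the test split afterwards.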
@app.command()
def main(
    data_dir: Annotated[str, typer.Argument(help="Path to the directory containing the train/validation/test data files.")],
    model_dir: Annotated[
        str,
        typer.Argument(
            help="A string that specifies the model id of a pretrained model configuration hosted on huggingface.co, or a path to a directory containing a model configuration file."
        ),
    ],
    config_file: Annotated[str, typer.Argument(help="Path to the YAML fine-tuning configuration file.")],
    auto_resume_from_checkpoint: str = typer.Argument(
        default="",
        help="If entered as yes, automatically use the latest save checkpoint. If it is a numerical example 12 15, use the corresponding save checkpoint. If the input is no, restart training",
    ),
):
    ft_config = FinetuningConfig.from_file(config_file)
    tokenizer, model = load_tokenizer_and_model(model_dir, peft_config=ft_config.peft_config)
    data_manager = DataManager(data_dir, ft_config.data_config)

    train_dataset = data_manager.get_dataset(
        Split.TRAIN,
        functools.partial(
            process_batch,
            tokenizer=tokenizer,
            combine=ft_config.combine,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    print("train_dataset:", train_dataset)
    val_dataset = data_manager.get_dataset(
        Split.VALIDATION,
        functools.partial(
            process_batch_eval,
            tokenizer=tokenizer,
            combine=ft_config.combine,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    if val_dataset is not None:
        print("val_dataset:", val_dataset)
    test_dataset = data_manager.get_dataset(
        Split.TEST,
        functools.partial(
            process_batch_eval,
            tokenizer=tokenizer,
            combine=ft_config.combine,
            max_input_length=ft_config.max_input_length,
            max_output_length=ft_config.max_output_length,
        ),
        batched=True,
    )
    if test_dataset is not None:
        print("test_dataset:", test_dataset)

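    # These hard-coded ids are assumed to be GLM-4 special tokens: 151329 <|endoftext|>
    # (used for padding) and 151336 / 151338 <|user|> / <|observation|>, which terminate
    # an assistant turn during generation.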
    ft_config.training_args.generation_config.pad_token_id = 151329
    ft_config.training_args.generation_config.eos_token_id = [151329, 151336, 151338]

    trainer = Seq2SeqTrainer(
        model=model,
        args=ft_config.training_args,
        data_collator=DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            padding="longest",
            return_tensors="pt",
        ),
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=functools.partial(compute_metrics, tokenizer=tokenizer),
    )

    # Check for None before calling .upper() to avoid an AttributeError
    if auto_resume_from_checkpoint is None or auto_resume_from_checkpoint.upper() == "":
        trainer.train()
    else:
        output_dir = ft_config.training_args.output_dir
        dirlist = os.listdir(output_dir)
        checkpoint_sn = 0
        for checkpoint_str in dirlist:
            if checkpoint_str.find("eckpoint") > 0 and checkpoint_str.find("tmp") == -1:
                checkpoint = int(checkpoint_str.replace("checkpoint-", ""))
                if checkpoint > checkpoint_sn:
                    checkpoint_sn = checkpoint
        if auto_resume_from_checkpoint.upper() == "YES":
            if checkpoint_sn > 0:
                model.gradient_checkpointing_enable()
                model.enable_input_require_grads()
                checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
                print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
                trainer.train(resume_from_checkpoint=checkpoint_directory)
            else:
                trainer.train()
        else:
            if auto_resume_from_checkpoint.isdigit():
                if int(auto_resume_from_checkpoint) > 0:
                    checkpoint_sn = int(auto_resume_from_checkpoint)
                    model.gradient_checkpointing_enable()
                    model.enable_input_require_grads()
                    checkpoint_directory = os.path.join(output_dir, "checkpoint-" + str(checkpoint_sn))
                    print("resume checkpoint from checkpoint-" + str(checkpoint_sn))
                    trainer.train(resume_from_checkpoint=checkpoint_directory)
            else:
                print(
                    auto_resume_from_checkpoint,
                    "The specified checkpoint ("
                    + auto_resume_from_checkpoint
                    + ") has not been saved. Please look for a valid checkpoint in the model output directory.",
                )

    if test_dataset is not None:
        trainer.predict(test_dataset)


if __name__ == "__main__":
    app()