# train.py
import os
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import time
import pandas as pd
from matplotlib import pyplot as plt
from dataset import loaddata, tokenlizer, encode, decode
from model import Transformer
from config import ModelArgs


## Train Llama 3 Model
# Define a function to generate input/target batches from the given dataset
def get_dataset_batch(data, split, args: ModelArgs):
    seq_len = args.max_seq_len
    batch_size = args.max_batch_size
    device = args.device

    # Split the corpus 80/10/10 into train / validation / test segments
    train = data[:int(0.8 * len(data))]
    val = data[int(0.8 * len(data)): int(0.9 * len(data))]
    test = data[int(0.9 * len(data)):]

    batch_data = train
    if split == "val":
        batch_data = val
    elif split == "test":
        batch_data = test

    # Pick random starting offsets so each call returns random samples for training, validation, and testing
    stoi, itos, token_bos, token_eos, token_pad = tokenlizer()
    ix = torch.randint(0, len(batch_data) - seq_len - 3, (batch_size,)).to(device)
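    # x: BOS token followed by the first seq_len-1 tokens of each sample;
    # y: the same tokens shifted one position to the right, ending with EOS (next-token targets)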
    x = torch.stack([torch.cat([token_bos, batch_data[i:i+seq_len-1]]) for i in ix]).long().to(device)
    y = torch.stack([torch.cat([batch_data[i+1:i+seq_len], token_eos]) for i in ix]).long().to(device)
    # Both x and y have shape (batch_size, seq_len), assuming token_bos/token_eos are single-token tensors
    return x, y

# Define an evaluation function to compute the average training and validation loss for logging and plotting
@torch.no_grad()
def evaluate_loss(model, dataset, args: ModelArgs):
    out = {}
    model.eval()

    for split in ["train", "val"]:
        losses = []
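        # Estimate the loss for this split by averaging over 10 random batches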
        for _ in range(10):
            xb, yb = get_dataset_batch(dataset, split, args)
            _, loss = model(x=xb, targets=yb)
            losses.append(loss.item())
        out[split] = np.mean(losses)

    model.train()
    return out

# Define the training loop
def train(model, optimizer, args: ModelArgs):
    print("model: ", model)
    data = loaddata()
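    # Encode the raw text corpus into a single 1-D tensor of token ids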
    dataset = torch.tensor(encode(data), dtype=torch.int).to(args.device)
    print(f"dataset-shape: {dataset.shape}")
    epochs = args.epochs
    log_interval = args.log_interval
    device = args.device
    losses = []
    start_time = time.time()

    for epoch in range(epochs):
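        # Each iteration draws one fresh random batch, so an "epoch" here is a single optimizer step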
        optimizer.zero_grad()

        xs, ys = get_dataset_batch(dataset, 'train', args)
        xs = xs.to(device)
        ys = ys.to(device)
        logits, loss = model(x=xs, targets=ys)
        loss.backward()
        optimizer.step()

        if epoch % log_interval == 0:
            batch_time = time.time() - start_time
            x = evaluate_loss(model, dataset, args)
            losses += [x]
            print(f"Epoch {epoch} | val loss {x['val']:.3f} | Time {batch_time:.3f}")
            start_time = time.time()

    # Print the final validation loss
    print("validation loss: ", losses[-1]['val'])

    # Save a checkpoint with the model and optimizer state plus training metadata
    os.makedirs("checkpoints", exist_ok=True)  # make sure the checkpoint directory exists
    save_file = {"model": model.state_dict(),
                 "optimizer": optimizer.state_dict(),
                 "epoch": epoch,
                 "args": args}
    torch.save(save_file, "checkpoints/model_{}.pth".format(epoch))

    # Plot the logged interval losses and return the matplotlib Axes
    return pd.DataFrame(losses).plot()

## Start training our Llama 3 model
if __name__ == "__main__":
    model = Transformer(ModelArgs).to(ModelArgs.device)
    optimizer = torch.optim.Adam(model.parameters())
    train(model, optimizer, ModelArgs)
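    # Show the interval-loss curve drawn by train(); pandas .plot() renders onto the current
    # matplotlib figure, which is otherwise never displayed when running as a plain script
    plt.show()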
 
# Example: resume training from a saved checkpoint (lr_scheduler and args.start_epoch apply only
# if your training setup defines them; this script does not)
# checkpoint = torch.load(path, map_location='cpu')
# model.load_state_dict(checkpoint['model'])
# optimizer.load_state_dict(checkpoint['optimizer'])
# lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
# args.start_epoch = checkpoint['epoch'] + 1