import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import time
import pandas as pd
from matplotlib import pyplot as plt

from dataset import loaddata, tokenlizer, encode, decode
from model import Transformer
from config import ModelArgs

## Train Llama 3 Model

# Define a function to generate batches from the given dataset split
def get_dataset_batch(data, split, args: ModelArgs):
    seq_len = args.max_seq_len
    batch_size = args.max_batch_size
    device = args.device

    # Split the data 80/10/10 into train, validation and test sets
    train = data[:int(0.8 * len(data))]
    val = data[int(0.8 * len(data)):int(0.9 * len(data))]
    test = data[int(0.9 * len(data)):]

    batch_data = train
    if split == "val":
        batch_data = val
    if split == "test":
        batch_data = test

    # Pick random starting points from the dataset to draw random samples
    # for training, validation and testing.
    stoi, itos, token_bos, token_eos, token_pad = tokenlizer()
    ix = torch.randint(0, len(batch_data) - seq_len - 3, (batch_size,)).to(device)
    # Inputs start with BOS; targets are shifted one token to the right and end with EOS.
    x = torch.stack([torch.cat([token_bos, batch_data[i:i + seq_len - 1]]) for i in ix]).long().to(device)
    y = torch.stack([torch.cat([batch_data[i + 1:i + seq_len], token_eos]) for i in ix]).long().to(device)
    return x, y


# Define an evaluation function to calculate and store training and validation loss
# for logging and plotting
@torch.no_grad()
def evaluate_loss(model, dataset, args: ModelArgs):
    out = {}
    model.eval()
    for split in ["train", "val"]:
        losses = []
        # Average the loss over 10 random batches per split
        for _ in range(10):
            xb, yb = get_dataset_batch(dataset, split, args)
            _, loss = model(x=xb, targets=yb)
            losses.append(loss.item())
        out[split] = np.mean(losses)
    model.train()
    return out


# Define a training function to perform model training
def train(model, optimizer, args: ModelArgs):
    print("model: ", model)
    data = loaddata()
    dataset = torch.tensor(encode(data), dtype=torch.int).to(args.device)
    print(f"dataset-shape: {dataset.shape}")

    epochs = args.epochs
    log_interval = args.log_interval
    device = args.device
    losses = []
    start_time = time.time()

    for epoch in range(epochs):
        optimizer.zero_grad()

        xs, ys = get_dataset_batch(dataset, 'train', args)
        xs = xs.to(device)
        ys = ys.to(device)
        logits, loss = model(x=xs, targets=ys)
        loss.backward()
        optimizer.step()

        if epoch % log_interval == 0:
            batch_time = time.time() - start_time
            x = evaluate_loss(model, dataset, args)
            losses += [x]
            print(f"Epoch {epoch} | val loss {x['val']:.3f} | Time {batch_time:.3f}")
            start_time = time.time()

    # Print the final validation loss
    print("validation loss: ", losses[-1]['val'])

    # Save a checkpoint with the model, optimizer and training state
    save_file = {"model": model.state_dict(),
                 "optimizer": optimizer.state_dict(),
                 "epoch": epoch,
                 "args": args}
    torch.save(save_file, "checkpoints/model_{}.pth".format(epoch))

    # Display the interval losses in a plot
    return pd.DataFrame(losses).plot()


## Start training our Llama 3 model
model = Transformer(ModelArgs).to(ModelArgs.device)
optimizer = torch.optim.Adam(model.parameters())
train(model, optimizer, ModelArgs)

# Load a checkpoint (note: this script saves no lr_scheduler state)
# checkpoint = torch.load(path, map_location='cpu')
# model.load_state_dict(checkpoint['model'])
# optimizer.load_state_dict(checkpoint['optimizer'])
# args.start_epoch = checkpoint['epoch'] + 1
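
# A minimal resume-from-checkpoint sketch, assuming the file was written by train()
# above (keys "model", "optimizer", "epoch", "args"). The path used in the usage
# example is only a placeholder for whichever checkpoint actually exists, and the
# helper name load_checkpoint is not part of the original script. On recent PyTorch
# versions you may need torch.load(..., weights_only=False) because the checkpoint
# stores the non-tensor args object.
def load_checkpoint(path, model, optimizer, device):
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    model.to(device)
    # Resume counting epochs from the one after the saved checkpoint
    return checkpoint['epoch'] + 1

# Example usage (hypothetical checkpoint path):
# start_epoch = load_checkpoint("checkpoints/model_999.pth", model, optimizer, ModelArgs.device)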