import torch
from torch import nn
from torch.nn import functional as F

from model import Transformer
from dataset import loaddata, tokenlizer, encode, decode
from config import ModelArgs


## Inference Llama 3 Model
def generate(model, prompts: str, params: ModelArgs, max_gen_len: int = 500,
             temperature: float = 0.6, top_p: float = 0.9):
    """Generate a text continuation for a single user prompt.

    Args:
        model: trained Transformer; called as ``model(x=..., start_pos=...)``
            and expected to return ``(logits, _)``.
        prompts: the raw prompt text (one prompt — inference uses batch size 1).
        params: model configuration (``max_seq_len``, ``device``).
        max_gen_len: maximum number of tokens to generate after the prompt.
        temperature: softmax temperature; 0 switches to greedy (argmax) decoding.
        top_p: nucleus-sampling probability threshold.

    Returns:
        (output_tokens, output_texts): token-id lists and their decoded texts,
        one entry per batch row (here always one).
    """
    bsz = 1  # single-prompt inference -> batch of 1
    stoi, itos, token_bos, token_eos, token_pad = tokenlizer()
    prompt_tokens = token_bos.tolist() + encode(prompts)
    assert len(prompt_tokens) <= params.max_seq_len, \
        "prompt token length should be smaller than max_seq_len"
    total_len = min(len(prompt_tokens) + max_gen_len, params.max_seq_len)

    # Token matrix holding the prompt followed by generated output; padded with
    # token_pad. Decoded at the end to produce the text result.
    tokens = torch.full((bsz, total_len), fill_value=token_pad.item(),
                        dtype=torch.long, device=params.device)
    tokens[:, :len(prompt_tokens)] = torch.tensor(prompt_tokens, dtype=torch.long,
                                                  device=params.device)

    # True where the position holds a prompt token, False where it is padding
    # (i.e. a position the model must fill in).
    input_text_mask = tokens != token_pad.item()

    # Decode one token at a time, starting right after the first position.
    prev_pos = 0
    for cur_pos in range(1, total_len):
        with torch.no_grad():
            logits, _ = model(x=tokens[:, prev_pos:cur_pos], start_pos=prev_pos)
        if temperature > 0:
            probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
            next_token = sample_top_p(probs, top_p)
        else:
            next_token = torch.argmax(logits[:, -1], dim=-1)
        next_token = next_token.reshape(-1)
        # Keep the original prompt token at prompt positions; only padding
        # positions receive the sampled token.
        next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos],
                                 next_token)
        tokens[:, cur_pos] = next_token
        prev_pos = cur_pos
        # BUG FIX: the original tested tokens[:, cur_pos] == token_pad, but that
        # slot was just overwritten with next_token, so the condition could
        # never hold together with next_token == eos and generation always ran
        # to total_len. Stop once a *generated* (non-prompt) position emits EOS.
        if (not input_text_mask[0, cur_pos]) and next_token[0].item() == token_eos.item():
            break

    output_tokens, output_texts = [], []
    for i, toks in enumerate(tokens.tolist()):
        # Truncate at the first EOS so padding after it is not decoded.
        if token_eos.item() in toks:
            eos_idx = toks.index(token_eos.item())
            toks = toks[:eos_idx]
        output_tokens.append(toks)
        output_texts.append(decode(toks))
    return output_tokens, output_texts
def sample_top_p(probs, p):
    """Top-p (nucleus) sampling over a probability distribution.

    Selects the smallest set of tokens whose cumulative probability mass
    exceeds the threshold ``p``, renormalizes the distribution over that set,
    and draws one sample from it.

    Args:
        probs: probability tensor derived from the logits, shape
            ``(bsz, vocab_size)``, rows summing to 1.
        p: cumulative-probability threshold in (0, 1].

    Returns:
        ``(bsz, 1)`` tensor of sampled token indices into the vocabulary.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Zero out every token whose *preceding* cumulative mass already exceeds p,
    # keeping the smallest prefix of the sorted distribution with mass > p.
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))  # renormalize
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map the position in sorted order back to the real vocabulary index.
    return torch.gather(probs_idx, -1, next_token)


def load_ckpt(model, args: "ModelArgs", epoch: int = 2499,
              ckpt_dir: str = "checkpoints"):
    """Load trained weights into ``model`` and prepare it for inference.

    The checkpoint epoch and directory were hard-coded; they are now
    parameters whose defaults reproduce the old behavior, so existing
    callers are unaffected.

    Args:
        model: the Transformer instance to populate.
        args: model configuration providing ``device``.
        epoch: checkpoint epoch number to load.
        ckpt_dir: directory containing ``model_<epoch>.pth`` files.

    Returns:
        The model, moved to ``args.device`` and set to eval mode.
    """
    path = f"{ckpt_dir}/model_{epoch}.pth"
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    model.to(args.device)
    model.eval()
    return model


if __name__ == "__main__":
    # Perform inference on a user input prompt. Guarded so importing this
    # module (e.g. to reuse generate/sample_top_p) does not trigger inference.
    prompts = "Would you proceed especially against Caius Marcius?"
    model = Transformer(ModelArgs).to(ModelArgs.device)
    model = load_ckpt(model, ModelArgs)
    output_tokens, output_texts = generate(model, prompts, ModelArgs)
    output_texts = output_texts[0].replace("<|begin_of_text|>", "")
    print("output: ", output_texts)