import torch

def loaddata(prompts="Hello World", txt='tiny_shakespeare.txt'):
    """Read and return the full text of the training corpus.

    The corpus is the Tiny Shakespeare dataset, used to build a
    character-level tokenizer. Parts of that tokenizer follow Andrej
    Karpathy's nanoGPT preparation script
    (https://github.com/karpathy/nanoGPT/blob/master/data/shakespeare_char/prepare.py).
    A copy of the data file is available at
    https://github.com/tamangmilan/llama3/blob/main/tiny_shakespeare.txt.

    Args:
        prompts: Unused; kept only so existing callers that pass it keep working.
        txt: Path of the text file to read.

    Returns:
        str: The entire contents of the file.
    """
    # Decode explicitly as UTF-8 so the character set (and therefore the
    # vocabulary and the token ids derived from it) does not depend on the
    # platform's default locale encoding.
    with open(txt, 'r', encoding='utf-8') as f:
        return f.read()

def vocabulary(data=None):
    """Build the character-level vocabulary.

    The vocabulary is every distinct character of the text, in sorted order,
    followed by the special tokens '<|begin_of_text|>', '<|end_of_text|>'
    and '<|pad_id|>'. A token's index in this list is its token id.

    Args:
        data: Text to build the vocabulary from. When None (the default) the
            dataset is read with ``loaddata()``.

    Returns:
        list[str]: The ordered vocabulary.
    """
    if data is None:
        data = loaddata()
    # Sorting makes the character -> id assignment deterministic across runs.
    vocab = sorted(set(data))

    # Training a Llama 3 model requires additional tokens such as
    # <|begin_of_text|>, <|end_of_text|> and <|pad_id|>; append them after
    # the plain characters.
    vocab.extend(['<|begin_of_text|>', '<|end_of_text|>', '<|pad_id|>'])
    return vocab

def tokenlizer(device=None, vocab=None):
    """Build the token lookup tables and the special-token tensors.

    Args:
        device: Device on which to create the special-token tensors. When
            None, ``"cuda"`` is used if CUDA is available, otherwise
            ``"cpu"``, so the tokenizer also works on CPU-only machines.
        vocab: Vocabulary list to index. When None it is built with
            ``vocabulary()``.

    Returns:
        tuple: ``(stoi, itos, token_bos, token_eos, token_pad)`` where
        ``stoi`` maps token -> id, ``itos`` maps id -> token, and each
        ``token_*`` is a one-element ``torch.int`` tensor holding the id of
        '<|begin_of_text|>', '<|end_of_text|>' and '<|pad_id|>' respectively.
    """
    if vocab is None:
        vocab = vocabulary()
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Mappings between tokens and their integer indexes in the vocabulary;
    # these back the tokenizer's encode and decode functions.
    itos = dict(enumerate(vocab))
    stoi = {ch: i for i, ch in enumerate(vocab)}

    def _special(token):
        # One-element id tensor, used later during model training.
        return torch.tensor([stoi[token]], dtype=torch.int, device=device)

    token_bos = _special('<|begin_of_text|>')
    token_eos = _special('<|end_of_text|>')
    token_pad = _special('<|pad_id|>')
    return stoi, itos, token_bos, token_eos, token_pad

# Tokenizer's encode function: take a string, output a list of integers
def encode(s, stoi=None):
    """Encode a string into a list of vocabulary indices, one per character.

    Args:
        s: Text to encode; every character must be in the vocabulary.
        stoi: Optional precomputed token -> id mapping (for example the first
            value returned by ``tokenlizer()``). When None the mapping is
            rebuilt from ``vocabulary()``, which re-reads the dataset file,
            so pass it explicitly when encoding repeatedly.

    Returns:
        list[int]: The token ids of ``s``.

    Raises:
        KeyError: If a character of ``s`` is not in the mapping.
    """
    if stoi is None:
        stoi = {ch: i for i, ch in enumerate(vocabulary())}
    return [stoi[ch] for ch in s]

# Tokenizer's decode function: take a list of integers, output a string
def decode(l, itos=None):
    """Decode a sequence of token ids back into a string.

    Args:
        l: Iterable of token ids. Elements are converted with ``int()``, so
            a 1-D tensor of ids can be passed directly.
        itos: Optional precomputed id -> token mapping (for example the
            second value returned by ``tokenlizer()``). When None the mapping
            is rebuilt from ``vocabulary()``. The special-token tensors are
            not needed here, so no tensors or CUDA device are involved.

    Returns:
        str: The concatenated tokens.

    Raises:
        KeyError: If an id is not in the mapping.
    """
    if itos is None:
        itos = dict(enumerate(vocabulary()))
    return ''.join(itos[int(i)] for i in l)

# prompts = "Hello World"
# encoded_tokens = encode(prompts)
# decoded_text = decode(encoded_tokens)