"package/vscode:/vscode.git/clone" did not exist on "40dc810c5e3e790278ece6b36bb9e687b0bc13ff"
sample.py 5.56 KB
Newer Older
chenzk's avatar
v1.0  
chenzk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""
Use this file to sample from a trained model.
"""
import dataclasses
import json
import os
import time
import torch
from allamo.logging import configure_logger, logger
from allamo.configuration import AllamoConfiguration
from allamo.model.model import AllamoTransformerConfig, AllamoTransformer
from allamo.torch_utils import configure_torch
from allamo.train_utils import remove_unwanted_prefix_from_model_state_dict

class AllamoSampler:
    """Load a trained model checkpoint plus a tokenizer and sample from it.

    Construction reads ``config_<name>.json`` and ``model_<name>.pt`` from the
    checkpoint directory, builds an ``AllamoTransformer`` from them, and loads
    either a HuggingFace or a Tiktoken tokenizer. The resulting instance
    exposes ``self.config``, ``self.model`` and ``self.tokenizer``.
    """

    def __init__(self, config: AllamoConfiguration):
        """Load the model and tokenizer described by *config*.

        Args:
            config: Run configuration. Attributes that share a name with an
                ``AllamoTransformerConfig`` field are overwritten in place
                with the values stored in the checkpoint's ``model_args``.

        Raises:
            ValueError: If neither a HuggingFace nor a Tiktoken tokenizer is
                configured.
        """
        configure_logger()
        self.config = config
        configure_torch(config)

        if config.init_from == 'resume_last':
            checkpoint_name = 'last_eval_ckpt'
        else:
            checkpoint_name = 'ckpt'
        # An explicit checkpoint path wins; otherwise fall back to the output dir.
        ckpt_dir = config.checkpoint_path if config.checkpoint_path else config.out_dir
        model_config_fields = [f.name for f in dataclasses.fields(AllamoTransformerConfig)]
        logger.info(f"Loading '{checkpoint_name}' checkpoint files from {ckpt_dir}...")
        with open(os.path.join(ckpt_dir, f'config_{checkpoint_name}.json'), "r", encoding="utf-8") as f:
            config_checkpoint = json.load(f)
        # The architecture must match the stored weights, so the checkpoint's
        # model arguments take precedence over whatever the caller configured.
        for k in model_config_fields:
            if hasattr(config, k) and k in config_checkpoint['model_args']:
                setattr(config, k, config_checkpoint['model_args'][k])

        # NOTE(review): torch.load unpickles the file, so only load checkpoints
        # from trusted sources (consider `weights_only=True` if the minimum
        # supported torch version allows it).
        model_checkpoint = torch.load(os.path.join(ckpt_dir, f'model_{checkpoint_name}.pt'), map_location='cpu')
        self.__load_model(config, config_checkpoint, model_checkpoint, model_config_fields)
        # Drop the CPU copy of the weights as soon as the model owns them,
        # before anything else gets allocated.
        del model_checkpoint
        self.__load_tokenizer(config)

    def __load_model(self, config: AllamoConfiguration, config_checkpoint, model_checkpoint, model_config_fields):
        """Build the transformer, load its weights and store it in ``self.model``.

        Args:
            config: Run configuration; supplies ``dtype``, ``device``,
                ``compile`` and the model hyper-parameters.
            config_checkpoint: Parsed checkpoint JSON; only ``iter_num`` is
                read here (for logging), when present.
            model_checkpoint: State dict loaded from the ``.pt`` file. It is
                passed to ``remove_unwanted_prefix_from_model_state_dict``
                before being loaded into the model.
            model_config_fields: Field names of ``AllamoTransformerConfig``.

        Raises:
            KeyError: If ``config.dtype`` is not one of the supported names.
        """
        ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'bfloat16-true': torch.bfloat16, 'float16': torch.float16}[config.dtype]
        model_args = {k: getattr(config, k) for k in model_config_fields if hasattr(config, k)}
        modelConf = AllamoTransformerConfig(**model_args)
        model = AllamoTransformer(modelConf)

        remove_unwanted_prefix_from_model_state_dict(model_checkpoint)
        model.load_state_dict(model_checkpoint)
        model.eval()
        model.to(device=config.device, dtype=ptdtype)
        if config.compile:
            model = torch.compile(model)
        self.model = model
        logger.info("Model loaded from checkpoint")
        if 'iter_num' in config_checkpoint:
            logger.info(f"Last model iteration: {config_checkpoint['iter_num']}")

    def __load_tokenizer(self, config: AllamoConfiguration):
        """Load the configured tokenizer and store it in ``self.tokenizer``.

        A HuggingFace tokenizer path takes precedence over a Tiktoken
        encoding name when both are set.

        Raises:
            ValueError: If neither tokenizer option is configured.
        """
        tiktoken_tokenizer_name = config.tiktoken_tokenizer_name
        hf_tokenizer_path = config.hf_tokenizer_path
        if hf_tokenizer_path is not None:
            # Imported lazily so that only the selected backend is required.
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_path)
            logger.info(f"HuggingFace tokenizer loaded: {hf_tokenizer_path}")
        elif tiktoken_tokenizer_name is not None:
            import tiktoken
            tokenizer = tiktoken.get_encoding(tiktoken_tokenizer_name)
            logger.info(f"Tiktoken tokenizer loaded: {tiktoken_tokenizer_name}")
        else:
            raise ValueError('Tokenizer is not provided. Please specify either a Tiktoken tokenizer or a HuggingFace tokenizer')
        # NOTE: the tokenizer/model vocabulary size is deliberately not checked
        # here (disabled check: len(tokenizer) == self.model.config.vocab_size).
        self.tokenizer = tokenizer

    def tokenize_prompt(self, text: str):
        """Return whatever ``self.tokenizer.encode`` produces for *text*."""
        return self.tokenizer.encode(text)

    def encode_prompt(self, text: str):
        """Tokenize *text* into a ``torch.long`` tensor on the configured device.

        Returns:
            The token ids with a leading dimension of size 1 added.
        """
        prompt_tokens = self.tokenize_prompt(text)
        prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=self.config.device)
        prompt_tokens = torch.unsqueeze(prompt_tokens, 0)
        return prompt_tokens

    def generate_embeddings(self, text: str):
        """Return the model's embedding of *text* as nested Python lists.

        Only the last position along the second axis of the model output is
        kept. Returns an empty list when *text* is empty or ``None``.
        """
        if not text:
            return []
        with torch.no_grad():
            prompt_tokens = self.encode_prompt(text)
            embeddings = self.model.generate_embeddings(prompt_tokens)
            # assumes the output is (batch, positions, dim) - TODO confirm
            embeddings = torch.squeeze(embeddings[:, [-1], :])  # use only the last position
            return embeddings.tolist()

    def generate_completions(self, text: str, samples: int, new_tokens: int, temperature: float, top_k: int):
        """Generate *samples* independent completions for the prompt *text*.

        Args:
            text: Prompt to condition on.
            samples: Number of completions to generate.
            new_tokens: Number of tokens requested from the model per sample.
            temperature: Forwarded to ``self.model.generate``.
            top_k: Forwarded to ``self.model.generate``.

        Returns:
            A list with one decoded, whitespace-stripped string per sample.
        """
        result = []
        # perf_counter is monotonic, so the measured interval cannot be skewed
        # by system clock adjustments.
        timer = time.perf_counter()
        with torch.no_grad():
            prompt_tokens = self.encode_prompt(text)
            for _ in range(samples):
                y = self.model.generate(prompt_tokens, new_tokens, temperature=temperature, top_k=top_k)
                result.append(self.tokenizer.decode(y[0].tolist()).strip())
        dt = time.perf_counter() - timer
        total_tokens = new_tokens * samples
        # Guard the throughput figure against a zero-length interval.
        tokens_per_sec = total_tokens / dt if dt > 0 else float('inf')
        logger.info(f"{total_tokens} completion tokens generated in {dt:.2f}secs ({tokens_per_sec:.2f} tokens/sec) for {prompt_tokens.shape[1]} input tokens")
        # Diagnostic dump goes through the logger rather than raw stdout.
        logger.debug("Generated completions: %s", result)
        return result


if __name__ == '__main__':
    # Script entry point: build the sampler, resolve the prompt, log completions.
    config = AllamoConfiguration()
    sampler = AllamoSampler(config)

    # A prompt of the form "FILE:<path>" is replaced by the contents of that file.
    file_marker = 'FILE:'
    if config.prompt.startswith(file_marker):
        prompt_path = config.prompt[len(file_marker):]
        with open(prompt_path, 'r', encoding='utf-8') as prompt_file:
            config.prompt = prompt_file.read()

    completions = sampler.generate_completions(
        config.prompt,
        config.num_samples,
        config.max_new_tokens,
        temperature=config.temperature,
        top_k=config.top_k,
    )
    logger.info("Completions:")
    for generated_text in completions:
        logger.info(generated_text)
        logger.info('----------------')