# llama3_chat.py
import os
import sys
import fire
import warnings

from typing import List, Optional
from llama import Dialog, Llama

warnings.filterwarnings('ignore', category=UserWarning)

def main(
    ckpt_dir: str,
    tokenizer_path: str,
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_seq_len: int = 512,
    max_batch_size: int = 4,
    max_gen_len: Optional[int] = None,
):
    """Run an interactive multi-turn chat loop against a Llama checkpoint.

    Args:
        ckpt_dir: Directory containing the model checkpoint files.
        tokenizer_path: Path to the tokenizer model file.
        temperature: Sampling temperature for generation.
        top_p: Nucleus-sampling threshold for generation.
        max_seq_len: Maximum model context length, in tokens.
        max_batch_size: Maximum batch size used when building the model.
        max_gen_len: Optional cap on generated tokens; model default if None.
    """
    generator = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
    )

    # Running conversation as a single Dialog (list of role/content message
    # dicts). The original annotated this List[Dialog], but message dicts are
    # appended directly, so it is one Dialog; it is batched as [dialog] below.
    dialog: Dialog = []

    # Under distributed launch each process runs this loop; only rank 0 reads
    # stdin. Hoisted out of the loop — the rank cannot change mid-run.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    try:
        # Continue until the user decides to stop.
        while True:
            if local_rank > 0:
                # Non-zero ranks cannot prompt; feed a placeholder turn so
                # every rank calls chat_completion in lockstep.
                # NOTE(review): ranks see different dialog content here, and
                # non-zero ranks have no way to observe rank 0's 'stop' —
                # verify this matches the launcher's expectations.
                dialog.append({"role": "user", "content": "tmplate"})
            else:
                user_input = input("You: ")
                # Allow the user to quit the dialogue.
                if user_input.lower() in ['stop', 'exit']:
                    break
                dialog.append({"role": "user", "content": user_input})

            # Generate a response from the full dialog context so far.
            results = generator.chat_completion(
                [dialog],  # batch of one dialog
                max_gen_len=max_gen_len,
                temperature=temperature,
                top_p=top_p,
            )
            response = results[0]['generation']['content']

            print(f"Assistant: {response}\n")
            # Append the model's turn so it sees its own replies next round.
            dialog.append({"role": "assistant", "content": response})
    except KeyboardInterrupt:
        print("Exiting dialogue.")

# CLI entry point: python-fire maps command-line flags onto main()'s
# keyword arguments (e.g. --ckpt_dir=... --tokenizer_path=...).
if __name__ == "__main__":
    fire.Fire(main)