# llama3_chat.py
import fire

from typing import List, Optional
from llama import Dialog, Llama


def main(
    ckpt_dir: str,
    tokenizer_path: str,
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_seq_len: int = 512,
    max_batch_size: int = 4,
    max_gen_len: Optional[int] = None,
):
    generator = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
    )
    dialogs: List[Dialog] = [[]]  # Start with a single empty dialog
    try:
        # Continue until the user decides to stop
        while True:
            user_input = input("You: ")
            # Allow the user to quit the dialogue
            if user_input.lower() in ['stop', 'exit']:
                break
            dialogs[0].append({"role": "user", "content": user_input})
            # Generate a response based on the current dialog context
            result = generator.chat_completion(
                dialogs,
                max_gen_len=max_gen_len,
                temperature=temperature,
                top_p=top_p,
            )[0]
            response = result['generation']['content']
            print(f"Assistant: {response}\n")
            # Append the generated response to the dialog
            dialogs[0].append({"role": "assistant", "content": response})
    except KeyboardInterrupt:
        print("Exiting dialogue.")


if __name__ == "__main__":
    fire.Fire(main)
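
# Example invocation (a sketch, not part of the original script): Llama.build
# initializes torch.distributed, so the script is normally launched with
# torchrun. The checkpoint and tokenizer paths below are assumptions and
# depend on where your Llama 3 weights actually live:
#
#   torchrun --nproc_per_node 1 llama3_chat.py \
#       --ckpt_dir Meta-Llama-3-8B-Instruct/ \
#       --tokenizer_path Meta-Llama-3-8B-Instruct/tokenizer.model \
#       --max_seq_len 512 --max_batch_size 4
#
# Type "stop" or "exit" (or press Ctrl+C) to end the dialogue.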