import os
import sys
import warnings
from typing import List, Optional

import fire

from llama import Dialog, Llama

warnings.filterwarnings('ignore', category=UserWarning)


def main(
    ckpt_dir: str,
    tokenizer_path: str,
    temperature: float = 0.6,
    top_p: float = 0.9,
    max_seq_len: int = 512,
    max_batch_size: int = 4,
    max_gen_len: Optional[int] = None,
):
    """Run an interactive multi-turn chat loop against a Llama checkpoint.

    Reads user turns from stdin (rank 0 only), sends the accumulated
    conversation to ``Llama.chat_completion``, prints the assistant reply,
    and appends it to the conversation. Type ``stop`` or ``exit`` (any
    case), press Ctrl-C, or close stdin to quit.

    Args:
        ckpt_dir: Directory containing the model checkpoint shards.
        tokenizer_path: Path to the tokenizer model file.
        temperature: Sampling temperature passed to generation.
        top_p: Nucleus-sampling probability mass passed to generation.
        max_seq_len: Maximum sequence length the model is built with.
        max_batch_size: Maximum batch size the model is built with.
        max_gen_len: Optional cap on generated tokens per reply; ``None``
            lets the generator choose.
    """
    generator = Llama.build(
        ckpt_dir=ckpt_dir,
        tokenizer_path=tokenizer_path,
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
    )

    # One conversation: a list of {"role", "content"} messages. It is a
    # single Dialog (not List[Dialog]) — it gets wrapped in a one-element
    # batch when passed to chat_completion below.
    # NOTE(review): the conversation grows without bound; once its token
    # count exceeds max_seq_len, generation will fail or truncate — confirm
    # desired behavior for long sessions.
    dialog: Dialog = []

    try:
        # Continue until the user decides to stop.
        while True:
            local_rank = int(os.environ.get("LOCAL_RANK", 0))
            if local_rank > 0:
                # Non-zero ranks cannot read stdin; feed a fixed placeholder
                # so every rank reaches chat_completion in lockstep.
                # NOTE(review): this placeholder differs from rank 0's real
                # input, and rank>0 never sees the stop keyword — verify the
                # generator broadcasts rank 0's tokens, or these ranks will
                # diverge/hang on exit.
                dialog.append({"role": "user", "content": "template"})
            else:
                try:
                    user_input = input("You: ")
                except EOFError:
                    # stdin closed (e.g. piped input exhausted): exit cleanly
                    # rather than crashing with a traceback.
                    break
                # Allow the user to quit the dialogue.
                if user_input.lower() in {"stop", "exit"}:
                    break
                dialog.append({"role": "user", "content": user_input})

            # Generate a response from the full conversation so far,
            # submitted as a batch of one.
            results = generator.chat_completion(
                [dialog],
                max_gen_len=max_gen_len,
                temperature=temperature,
                top_p=top_p,
            )
            response = results[0]['generation']['content']
            print(f"Assistant: {response}\n")

            # Append the generated response so the next turn has context.
            dialog.append({"role": "assistant", "content": response})
    except KeyboardInterrupt:
        print("Exiting dialogue.")


if __name__ == "__main__":
    fire.Fire(main)