"""Export the nsql-llama-2-7b model to fastllm's .flm format.

Usage:
    python export.py [output_path] [dtype]

Arguments (both optional):
    output_path  destination .flm file; defaults to "nsql-llama-2-7b-<dtype>.flm"
    dtype        export precision, e.g. "float16" (default)
"""
import sys

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import LlamaTokenizer, LlamaForCausalLM

from fastllm_pytools import torch2flm

if __name__ == "__main__":
    MODEL_DIR = "models/llama7b/nsql-llama-2-7b"

    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
    model = AutoModelForCausalLM.from_pretrained(MODEL_DIR)
    # fastllm selects its conversion/runtime path from config.model_type,
    # so tag the model explicitly before export.
    model.config.model_type = "nsql-llama-2-7b"

    # NOTE(fix): the original script indexed sys.argv[1] unguarded
    # ("sys.argv[1] if sys.argv[1] is not None else ..."), which raises
    # IndexError when no arguments are given — a missing argv element
    # raises rather than being None. That dead/buggy assignment is
    # removed; only the length-guarded logic below is kept.
    dtype = sys.argv[2] if len(sys.argv) >= 3 else "float16"
    exportPath = sys.argv[1] if len(sys.argv) >= 2 else "nsql-llama-2-7b-" + dtype + ".flm"

    torch2flm.tofile(exportPath, model, tokenizer, dtype = dtype)