Commit 45a7867b authored by Rayyyyy

update

parent 88c3b719
@@ -5,7 +5,12 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
 model_name_or_path = "deepseek-ai/DeepSeek-V2-Lite-Chat"
 tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
-model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype=torch.bfloat16).cuda()
+# `device_map` cannot be set to `auto`
+model = AutoModelForCausalLM.from_pretrained(
+    model_name_or_path,
+    trust_remote_code=True,
+    torch_dtype=torch.bfloat16).cuda()
 model.generation_config = GenerationConfig.from_pretrained(model_name_or_path)
 model.generation_config.pad_token_id = model.generation_config.eos_token_id
@@ -15,6 +20,6 @@ messages = [
 input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
 outputs = model.generate(input_tensor.to(model.device), max_new_tokens=100)
 result = tokenizer.decode(outputs[0][input_tensor.shape[1]:], skip_special_tokens=True)
 print("result", result)
@@ -5,24 +5,19 @@ model_name_or_path = "deepseek-ai/DeepSeek-V2-Lite"
 tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
-# `max_memory` should be set based on your devices
-max_memory = {i: "64GB" for i in range(8)}
 # `device_map` cannot be set to `auto`
 model = AutoModelForCausalLM.from_pretrained(
     model_name_or_path,
     trust_remote_code=True,
-    device_map="sequential",
-    torch_dtype=torch.bfloat16,
-    max_memory=max_memory,
-    attn_implementation="eager")
+    torch_dtype=torch.bfloat16).cuda()
 model.generation_config = GenerationConfig.from_pretrained(model_name_or_path)
 model.generation_config.pad_token_id = model.generation_config.eos_token_id
 text = "An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is"
 inputs = tokenizer(text, return_tensors="pt")
 outputs = model.generate(**inputs.to(model.device), max_new_tokens=100)
 result = tokenizer.decode(outputs[0], skip_special_tokens=True)
 print("result", result)