Commit b0010cbc authored by PanZezhong's avatar PanZezhong
Browse files

support qwen2

parent 13f98ed3
......@@ -288,7 +288,9 @@ class JiugeForCauslLM:
with open(os.path.join(model_dir_path, "config.json"), "r") as f:
config = json.load(f)
self.config = config
eos_token_id = self.config["eos_token_id"]
self.eos_token_id = [eos_token_id] if type(eos_token_id) == int else eos_token_id
if "llama" == config["model_type"]:
model = transformers.LlamaForCausalLM.from_pretrained(model_dir_path).cpu().half()
self.meta = JiugeMetaFromLlama(config)
......@@ -311,6 +313,8 @@ class JiugeForCauslLM:
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_dir_path, trust_remote_code=True
)
else:
raise ValueError("Unsupported weight naming")
elif "fm9g7b" == config["model_type"]:
state_dict = torch.load(
os.path.join(model_dir_path, "pytorch_model.bin"), weights_only=True, map_location="cpu"
......@@ -323,7 +327,18 @@ class JiugeForCauslLM:
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_dir_path, trust_remote_code=True
)
else:
raise ValueError("Unsupported weight naming")
elif "qwen2" == config["model_type"]:
state_dict = load_all_safetensors_from_dir(model_dir_path)
if LlamaWeightsNaming.match(state_dict):
self.meta = JiugeMetaFromLlama(config)
self.weights = JiugeWeightsImpl(
self.meta, LlamaWeightsNaming(), state_dict, ndev=ndev
)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_dir_path
)
else:
raise ValueError("Unsupported model architecture")
......@@ -391,10 +406,10 @@ class JiugeForCauslLM:
.replace("▁", " ")
.replace("<0x0A>", "\n")
)
if output_str.endswith("</s>"):
break
output_content += output_str
print(output_str, end="", flush=True)
if output_tokens[0] in self.eos_token_id:
break
req_pos[0] = req_pos[0] + ntok
ntok = 1
tokens = (c_uint * ntok)(*output_tokens)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment