Commit b0010cbc authored by PanZezhong's avatar PanZezhong
Browse files

support qwen2

parent 13f98ed3
...@@ -288,7 +288,9 @@ class JiugeForCauslLM: ...@@ -288,7 +288,9 @@ class JiugeForCauslLM:
with open(os.path.join(model_dir_path, "config.json"), "r") as f: with open(os.path.join(model_dir_path, "config.json"), "r") as f:
config = json.load(f) config = json.load(f)
self.config = config
eos_token_id = self.config["eos_token_id"]
self.eos_token_id = [eos_token_id] if type(eos_token_id) == int else eos_token_id
if "llama" == config["model_type"]: if "llama" == config["model_type"]:
model = transformers.LlamaForCausalLM.from_pretrained(model_dir_path).cpu().half() model = transformers.LlamaForCausalLM.from_pretrained(model_dir_path).cpu().half()
self.meta = JiugeMetaFromLlama(config) self.meta = JiugeMetaFromLlama(config)
...@@ -311,6 +313,8 @@ class JiugeForCauslLM: ...@@ -311,6 +313,8 @@ class JiugeForCauslLM:
self.tokenizer = transformers.AutoTokenizer.from_pretrained( self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_dir_path, trust_remote_code=True model_dir_path, trust_remote_code=True
) )
else:
raise ValueError("Unsupported weight naming")
elif "fm9g7b" == config["model_type"]: elif "fm9g7b" == config["model_type"]:
state_dict = torch.load( state_dict = torch.load(
os.path.join(model_dir_path, "pytorch_model.bin"), weights_only=True, map_location="cpu" os.path.join(model_dir_path, "pytorch_model.bin"), weights_only=True, map_location="cpu"
...@@ -323,7 +327,18 @@ class JiugeForCauslLM: ...@@ -323,7 +327,18 @@ class JiugeForCauslLM:
self.tokenizer = transformers.AutoTokenizer.from_pretrained( self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_dir_path, trust_remote_code=True model_dir_path, trust_remote_code=True
) )
else:
raise ValueError("Unsupported weight naming")
elif "qwen2" == config["model_type"]:
state_dict = load_all_safetensors_from_dir(model_dir_path)
if LlamaWeightsNaming.match(state_dict):
self.meta = JiugeMetaFromLlama(config)
self.weights = JiugeWeightsImpl(
self.meta, LlamaWeightsNaming(), state_dict, ndev=ndev
)
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
model_dir_path
)
else: else:
raise ValueError("Unsupported model architecture") raise ValueError("Unsupported model architecture")
...@@ -391,10 +406,10 @@ class JiugeForCauslLM: ...@@ -391,10 +406,10 @@ class JiugeForCauslLM:
.replace("▁", " ") .replace("▁", " ")
.replace("<0x0A>", "\n") .replace("<0x0A>", "\n")
) )
if output_str.endswith("</s>"):
break
output_content += output_str output_content += output_str
print(output_str, end="", flush=True) print(output_str, end="", flush=True)
if output_tokens[0] in self.eos_token_id:
break
req_pos[0] = req_pos[0] + ntok req_pos[0] = req_pos[0] + ntok
ntok = 1 ntok = 1
tokens = (c_uint * ntok)(*output_tokens) tokens = (c_uint * ntok)(*output_tokens)
......
Markdown is supported
Attach a file by drag &amp; drop or click to upload.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment