Commit 732c7f04 authored by Pan Zezhong's avatar Pan Zezhong
Browse files

9G7B support safetensors

parent c0799551
......@@ -416,11 +416,16 @@ class JiugeForCauslLM:
else:
raise ValueError("Unsupported weight naming")
elif "fm9g7b" == config["model_type"]:
state_dict = torch.load(
os.path.join(model_dir_path, "pytorch_model.bin"),
weights_only=True,
map_location="cpu",
)
if any(
file.suffix == ".safetensors" for file in Path(model_dir_path).iterdir()
):
state_dict = load_all_safetensors_from_dir(model_dir_path)
else:
state_dict = torch.load(
os.path.join(model_dir_path, "pytorch_model.bin"),
weights_only=True,
map_location="cpu",
)
if LlamaWeightsNaming.match(state_dict):
self.meta = JiugeMetaFromLlama(config)
self.weights = JiugeWeightsImpl(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment