Commit 732c7f04 authored by Pan Zezhong's avatar Pan Zezhong
Browse files

9G7B support safetensors

parent c0799551
...@@ -416,11 +416,16 @@ class JiugeForCauslLM: ...@@ -416,11 +416,16 @@ class JiugeForCauslLM:
else: else:
raise ValueError("Unsupported weight naming") raise ValueError("Unsupported weight naming")
elif "fm9g7b" == config["model_type"]: elif "fm9g7b" == config["model_type"]:
state_dict = torch.load( if any(
os.path.join(model_dir_path, "pytorch_model.bin"), file.suffix == ".safetensors" for file in Path(model_dir_path).iterdir()
weights_only=True, ):
map_location="cpu", state_dict = load_all_safetensors_from_dir(model_dir_path)
) else:
state_dict = torch.load(
os.path.join(model_dir_path, "pytorch_model.bin"),
weights_only=True,
map_location="cpu",
)
if LlamaWeightsNaming.match(state_dict): if LlamaWeightsNaming.match(state_dict):
self.meta = JiugeMetaFromLlama(config) self.meta = JiugeMetaFromLlama(config)
self.weights = JiugeWeightsImpl( self.weights = JiugeWeightsImpl(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment