Unverified Commit 2a101207 authored by Antoni Baum's avatar Antoni Baum Committed by GitHub
Browse files

fix(server): Handle loading from local files for MPT (#534)

This PR allows the MPT model to be loaded from local files. Without this
change, an exception will be thrown by `hf_hub_download` function if
`model_id` is a local path.
parent e6888d0e
import torch
import torch.distributed
from pathlib import Path
from typing import Optional, Type
from opentelemetry import trace
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
......@@ -60,7 +61,12 @@ class MPTSharded(CausalLM):
)
tokenizer.pad_token = tokenizer.eos_token
filename = hf_hub_download(model_id, revision=revision, filename="config.json")
# If model_id is a local path, load the file directly
local_path = Path(model_id, "config.json")
if local_path.exists():
filename = str(local_path.resolve())
else:
filename = hf_hub_download(model_id, revision=revision, filename="config.json")
with open(filename, "r") as f:
config = json.load(f)
config = PretrainedConfig(**config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment