Unverified Commit 2a101207 authored by Antoni Baum's avatar Antoni Baum Committed by GitHub
Browse files

fix(server): Handle loading from local files for MPT (#534)

This PR allows the MPT model to be loaded from local files. Without this
change, an exception will be thrown by `hf_hub_download` function if
`model_id` is a local path.
parent e6888d0e
import torch import torch
import torch.distributed import torch.distributed
from pathlib import Path
from typing import Optional, Type from typing import Optional, Type
from opentelemetry import trace from opentelemetry import trace
from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase from transformers import AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
...@@ -60,6 +61,11 @@ class MPTSharded(CausalLM): ...@@ -60,6 +61,11 @@ class MPTSharded(CausalLM):
) )
tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token = tokenizer.eos_token
# If model_id is a local path, load the file directly
local_path = Path(model_id, "config.json")
if local_path.exists():
filename = str(local_path.resolve())
else:
filename = hf_hub_download(model_id, revision=revision, filename="config.json") filename = hf_hub_download(model_id, revision=revision, filename="config.json")
with open(filename, "r") as f: with open(filename, "r") as f:
config = json.load(f) config = json.load(f)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment