Unverified Commit 775115e3 authored by OlivierDehaene's avatar OlivierDehaene Committed by GitHub
Browse files

feat(server): allow the server to use a local weight cache (#49)

parent 313194f6
......@@ -313,6 +313,12 @@ fn shard_manager(
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
};
// If the WEIGHTS_CACHE_OVERRIDE env var is set, pass it to the shard
// Useful when running inside a HuggingFace Inference Endpoint
if let Ok(weights_cache_override) = env::var("WEIGHTS_CACHE_OVERRIDE") {
env.push(("WEIGHTS_CACHE_OVERRIDE".into(), weights_cache_override.into()));
};
// If the CUDA_VISIBLE_DEVICES env var is set, pass it to the shard
if let Ok(cuda_visible_devices) = env::var("CUDA_VISIBLE_DEVICES") {
env.push(("CUDA_VISIBLE_DEVICES".into(), cuda_visible_devices.into()));
......
......@@ -25,6 +25,7 @@ from transformers.generation.logits_process import (
from text_generation.pb import generate_pb2
WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
class Sampling:
def __init__(self, seed: int, device: str = "cpu"):
......@@ -230,6 +231,9 @@ def try_to_load_from_cache(model_name, revision, filename):
def weight_files(model_name, revision=None, extension=".safetensors"):
"""Get the local safetensors filenames"""
if WEIGHTS_CACHE_OVERRIDE is not None:
return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
filenames = weight_hub_files(model_name, revision, extension)
files = []
for filename in filenames:
......@@ -249,6 +253,9 @@ def weight_files(model_name, revision=None, extension=".safetensors"):
def download_weights(model_name, revision=None, extension=".safetensors"):
"""Download the safetensors files from the hub"""
if WEIGHTS_CACHE_OVERRIDE is not None:
return list(Path(WEIGHTS_CACHE_OVERRIDE).glob(f"*{extension}"))
filenames = weight_hub_files(model_name, revision, extension)
download_function = partial(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment