Unverified commit 17bc841b, authored by OlivierDehaene and committed by GitHub
Browse files

feat(server): enable hf-transfer (#76)

parent 6796d38c
...@@ -33,6 +33,7 @@ ENV LANG=C.UTF-8 \ ...@@ -33,6 +33,7 @@ ENV LANG=C.UTF-8 \
LC_ALL=C.UTF-8 \ LC_ALL=C.UTF-8 \
DEBIAN_FRONTEND=noninteractive \ DEBIAN_FRONTEND=noninteractive \
HUGGINGFACE_HUB_CACHE=/data \ HUGGINGFACE_HUB_CACHE=/data \
HF_HUB_ENABLE_HF_TRANSFER=1 \
MODEL_ID=bigscience/bloom-560m \ MODEL_ID=bigscience/bloom-560m \
QUANTIZE=false \ QUANTIZE=false \
NUM_SHARD=1 \ NUM_SHARD=1 \
......
...@@ -98,23 +98,18 @@ fn main() -> ExitCode { ...@@ -98,23 +98,18 @@ fn main() -> ExitCode {
}) })
.expect("Error setting Ctrl-C handler"); .expect("Error setting Ctrl-C handler");
// Download weights // Download weights for sharded models
if weights_cache_override.is_none() { if weights_cache_override.is_none() && num_shard > 1 {
let mut download_argv = vec![ let mut download_argv = vec![
"text-generation-server".to_string(), "text-generation-server".to_string(),
"download-weights".to_string(), "download-weights".to_string(),
model_id.clone(), model_id.clone(),
"--extension".to_string(),
".safetensors".to_string(),
"--logger-level".to_string(), "--logger-level".to_string(),
"INFO".to_string(), "INFO".to_string(),
"--json-output".to_string(), "--json-output".to_string(),
]; ];
if num_shard == 1 {
download_argv.push("--extension".to_string());
download_argv.push(".bin".to_string());
} else {
download_argv.push("--extension".to_string());
download_argv.push(".safetensors".to_string());
}
// Model optional revision // Model optional revision
if let Some(ref revision) = revision { if let Some(ref revision) = revision {
...@@ -131,6 +126,9 @@ fn main() -> ExitCode { ...@@ -131,6 +126,9 @@ fn main() -> ExitCode {
env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into())); env.push(("HUGGINGFACE_HUB_CACHE".into(), huggingface_hub_cache.into()));
}; };
// Enable hf transfer for insane download speeds
env.push(("HF_HUB_ENABLE_HF_TRANSFER".into(), "1".into()));
// Start process // Start process
tracing::info!("Starting download process."); tracing::info!("Starting download process.");
let mut download_process = match Popen::create( let mut download_process = match Popen::create(
...@@ -209,12 +207,6 @@ fn main() -> ExitCode { ...@@ -209,12 +207,6 @@ fn main() -> ExitCode {
} }
sleep(Duration::from_millis(100)); sleep(Duration::from_millis(100));
} }
} else {
tracing::info!(
"weights_cache_override is set to {:?}.",
weights_cache_override
);
tracing::info!("Skipping download.")
} }
// Shared shutdown bool // Shared shutdown bool
...@@ -479,6 +471,9 @@ fn shard_manager( ...@@ -479,6 +471,9 @@ fn shard_manager(
// Safetensors load fast // Safetensors load fast
env.push(("SAFETENSORS_FAST_GPU".into(), "1".into())); env.push(("SAFETENSORS_FAST_GPU".into(), "1".into()));
// Enable hf transfer for insane download speeds
env.push(("HF_HUB_ENABLE_HF_TRANSFER".into(), "1".into()));
// If huggingface_hub_cache is some, pass it to the shard // If huggingface_hub_cache is some, pass it to the shard
// Useful when running inside a docker container // Useful when running inside a docker container
if let Some(huggingface_hub_cache) = huggingface_hub_cache { if let Some(huggingface_hub_cache) = huggingface_hub_cache {
......
...@@ -192,6 +192,14 @@ grpcio = ">=1.51.1" ...@@ -192,6 +192,14 @@ grpcio = ">=1.51.1"
protobuf = ">=4.21.6,<5.0dev" protobuf = ">=4.21.6,<5.0dev"
setuptools = "*" setuptools = "*"
[[package]]
name = "hf-transfer"
version = "0.1.0"
description = ""
category = "main"
optional = false
python-versions = ">=3.7"
[[package]] [[package]]
name = "idna" name = "idna"
version = "3.4" version = "3.4"
...@@ -622,7 +630,7 @@ bnb = ["bitsandbytes"] ...@@ -622,7 +630,7 @@ bnb = ["bitsandbytes"]
[metadata] [metadata]
lock-version = "1.1" lock-version = "1.1"
python-versions = "^3.9" python-versions = "^3.9"
content-hash = "f3cab6881b52045770a90ec9be7415a0ee499d9e980892d544f68073700cf321" content-hash = "ef6da62cff76be3eeb45eac98326d6e4fac5d35796b8bdcf555575323ce97ba2"
[metadata.files] [metadata.files]
accelerate = [ accelerate = [
...@@ -861,6 +869,28 @@ grpcio-tools = [ ...@@ -861,6 +869,28 @@ grpcio-tools = [
{file = "grpcio_tools-1.51.1-cp39-cp39-win32.whl", hash = "sha256:40ef70e8c5d0310dedff9af502b520b4c7e215bce94094527fb959150a0c594a"}, {file = "grpcio_tools-1.51.1-cp39-cp39-win32.whl", hash = "sha256:40ef70e8c5d0310dedff9af502b520b4c7e215bce94094527fb959150a0c594a"},
{file = "grpcio_tools-1.51.1-cp39-cp39-win_amd64.whl", hash = "sha256:15b8acf4eaa0ebe37e2f69108de49efd935b7abe9c7e58ba737490b99906aa76"}, {file = "grpcio_tools-1.51.1-cp39-cp39-win_amd64.whl", hash = "sha256:15b8acf4eaa0ebe37e2f69108de49efd935b7abe9c7e58ba737490b99906aa76"},
] ]
hf-transfer = [
{file = "hf_transfer-0.1.0-cp310-cp310-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:0f41bb04898d041b774220048f237d10560ec27e1decd01a04d323c64202e8fe"},
{file = "hf_transfer-0.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:94510d4e3a66aa5afa06b61ff537895c3f1b93d689575a8a840ea9ec3189d3d8"},
{file = "hf_transfer-0.1.0-cp310-none-win_amd64.whl", hash = "sha256:e85134084c7e9e9daa74331c4690f821a30afedacb97355d1a66c243317c3e7a"},
{file = "hf_transfer-0.1.0-cp311-cp311-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:3441b0cba24afad7fffbcfc0eb7e31d3df127d092feeee73e5b08bb5752c903b"},
{file = "hf_transfer-0.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14938d44c71a07b452612a90499b5f021b2277e1c93c66f60d06d746b2d0661d"},
{file = "hf_transfer-0.1.0-cp311-cp311-manylinux_2_34_x86_64.whl", hash = "sha256:ae01b9844995622beee0f1b7ff0240e269bfc28ea46149eb4abbf63b4683f3e2"},
{file = "hf_transfer-0.1.0-cp311-none-win_amd64.whl", hash = "sha256:79e9505bffd3a1086be13033a805c8e6f4bb763de03a4197b959984def587e7f"},
{file = "hf_transfer-0.1.0-cp37-cp37m-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:5656deb183e271d37925de0d6989d7a4b1eefae42d771f10907f41fce08bdada"},
{file = "hf_transfer-0.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bcb16ec07f1ad343b7189745b6f659b7f82b864a55d3b84fe34361f64e4abc76"},
{file = "hf_transfer-0.1.0-cp37-none-win_amd64.whl", hash = "sha256:5314b708bc2a8cf844885d350cd13ba0b528466d3eb9766e4d8d39d080e718c0"},
{file = "hf_transfer-0.1.0-cp38-cp38-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:fe245d0d84bbc113870144c56b50425f9b6cacc3e361b3559b0786ac076ba260"},
{file = "hf_transfer-0.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49099d41c05a19202dca0306bfa7f42cedaea57ccc783641b1533de860b6a1f4"},
{file = "hf_transfer-0.1.0-cp38-none-win_amd64.whl", hash = "sha256:0d7bb607a7372908ffa2d55f1e6790430c5706434c2d1d664db4928730c2c7e4"},
{file = "hf_transfer-0.1.0-cp39-cp39-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:70f40927bc4a19ab50605bb542bd3858eb465ad65c94cfcaf36cf36d68fc5169"},
{file = "hf_transfer-0.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d92203155f451a9b517267d2a0966b282615037e286ff0420a2963f67d451de3"},
{file = "hf_transfer-0.1.0-cp39-none-win_amd64.whl", hash = "sha256:c1c2799154e4bd03d2b2e2907d494005f707686b80d5aa4421c859ffa612ffa3"},
{file = "hf_transfer-0.1.0-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b0900bd5698a77fb44c95639eb3ec97202d13e1bd4282cde5c81ed48e3f9341e"},
{file = "hf_transfer-0.1.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:427032df4e83a1bedaa76383a5b825cf779fdfc206681c28b21476bc84089280"},
{file = "hf_transfer-0.1.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e36fe371af6e31795621ffab53a2dcb107088fcfeb2951eaa2a10bfdc8b8863b"},
{file = "hf_transfer-0.1.0.tar.gz", hash = "sha256:f692ef717ded50e441b4d40b6aea625772f68b90414aeef86bb33eab40cb09a4"},
]
idna = [ idna = [
{file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"}, {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"},
{file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"}, {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"},
......
...@@ -22,6 +22,7 @@ loguru = "^0.6.0" ...@@ -22,6 +22,7 @@ loguru = "^0.6.0"
opentelemetry-api = "^1.15.0" opentelemetry-api = "^1.15.0"
opentelemetry-exporter-otlp = "^1.15.0" opentelemetry-exporter-otlp = "^1.15.0"
opentelemetry-instrumentation-grpc = "^0.36b0" opentelemetry-instrumentation-grpc = "^0.36b0"
hf-transfer = "^0.1.0"
[tool.poetry.extras] [tool.poetry.extras]
bnb = ["bitsandbytes"] bnb = ["bitsandbytes"]
......
import time import time
import concurrent
import os import os
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta from datetime import timedelta
from loguru import logger from loguru import logger
from pathlib import Path from pathlib import Path
...@@ -147,20 +145,17 @@ def download_weights( ...@@ -147,20 +145,17 @@ def download_weights(
) )
return Path(local_file) return Path(local_file)
executor = ThreadPoolExecutor(max_workers=5)
futures = [
executor.submit(download_file, filename=filename) for filename in filenames
]
# We do this instead of using tqdm because we want to parse the logs with the launcher # We do this instead of using tqdm because we want to parse the logs with the launcher
start_time = time.time() start_time = time.time()
files = [] files = []
for i, future in enumerate(concurrent.futures.as_completed(futures)): for i, filename in enumerate(filenames):
file = download_file(filename)
elapsed = timedelta(seconds=int(time.time() - start_time)) elapsed = timedelta(seconds=int(time.time() - start_time))
remaining = len(futures) - (i + 1) remaining = len(filenames) - (i + 1)
eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0 eta = (elapsed / (i + 1)) * remaining if remaining > 0 else 0
logger.info(f"Download: [{i + 1}/{len(futures)}] -- ETA: {eta}") logger.info(f"Download: [{i + 1}/{len(filenames)}] -- ETA: {eta}")
files.append(future.result()) files.append(file)
return files return files
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment