Unverified Commit 17d055f5 authored by Benjamin Bartels's avatar Benjamin Bartels Committed by GitHub
Browse files

[Feat] Adds runai distributed streamer (#27230)


Signed-off-by: default avatarbbartels <benjamin@bartels.dev>
Signed-off-by: default avatarBenjamin Bartels <benjamin@bartels.dev>
Co-authored-by: default avataromer-dayan <omdayan@nvidia.com>
Co-authored-by: default avatarCyrus Leung <tlleungac@connect.ust.hk>
parent 2ce5c5d3
......@@ -495,7 +495,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
else \
BITSANDBYTES_VERSION="0.46.1"; \
fi; \
uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' 'runai-model-streamer[s3,gcs]>=0.14.0'
uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm>=1.0.17' 'runai-model-streamer[s3,gcs]>=0.15.0'
ENV VLLM_USAGE_SOURCE production-docker-image
......
......@@ -45,6 +45,15 @@ vllm serve s3://core-llm/Llama-3-8b \
You can tune parameters using `--model-loader-extra-config`:
You can tune `distributed` that controls whether distributed streaming should be used. This is currently only possible on CUDA and ROCM devices. This can significantly improve loading times from object storage or high-throughput network fileshares.
You can read further about Distributed streaming [here](https://github.com/run-ai/runai-model-streamer/blob/master/docs/src/usage.md#distributed-streaming)
```bash
vllm serve /home/meta-llama/Llama-3.2-3B-Instruct \
--load-format runai_streamer \
--model-loader-extra-config '{"distributed":true}'
```
You can tune `concurrency` that controls the level of concurrency and number of OS threads reading tensors from the file to the CPU buffer.
For reading from S3, it will be the number of client instances the host is opening to the S3 server.
......
......@@ -42,6 +42,6 @@ tritonclient==2.51.0
numba == 0.61.2 # Required for N-gram speculative decoding
numpy
runai-model-streamer[s3,gcs]==0.14.0
runai-model-streamer[s3,gcs]==0.15.0
fastsafetensors>=0.1.10
pydantic>=2.12 # 2.11 leads to error on python 3.13
......@@ -12,6 +12,6 @@ tensorizer==2.10.1
packaging>=24.2
setuptools>=77.0.3,<80.0.0
setuptools-scm>=8
runai-model-streamer[s3,gcs]==0.14.0
runai-model-streamer[s3,gcs]==0.15.0
conch-triton-kernels==1.2.1
timm>=1.0.17
......@@ -50,7 +50,7 @@ tritonclient==2.51.0
numba == 0.61.2 # Required for N-gram speculative decoding
numpy
runai-model-streamer[s3,gcs]==0.14.0
runai-model-streamer[s3,gcs]==0.15.0
fastsafetensors>=0.1.10
pydantic>=2.12 # 2.11 leads to error on python 3.13
decord==0.6.0
......
......@@ -965,11 +965,11 @@ rsa==4.9.1
# via google-auth
rtree==1.4.0
# via torchgeo
runai-model-streamer==0.14.0
runai-model-streamer==0.15.0
# via -r requirements/test.in
runai-model-streamer-gcs==0.14.0
runai-model-streamer-gcs==0.15.0
# via runai-model-streamer
runai-model-streamer-s3==0.14.0
runai-model-streamer-s3==0.15.0
# via runai-model-streamer
s3transfer==0.10.3
# via boto3
......
......@@ -712,7 +712,7 @@ setup(
"bench": ["pandas", "matplotlib", "seaborn", "datasets"],
"tensorizer": ["tensorizer==2.10.1"],
"fastsafetensors": ["fastsafetensors >= 0.1.10"],
"runai": ["runai-model-streamer[s3,gcs] >= 0.14.0"],
"runai": ["runai-model-streamer[s3,gcs] >= 0.15.0"],
"audio": [
"librosa",
"soundfile",
......
......@@ -27,9 +27,16 @@ class RunaiModelStreamerLoader(BaseModelLoader):
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
self._is_distributed = False
if load_config.model_loader_extra_config:
extra_config = load_config.model_loader_extra_config
if "distributed" in extra_config and isinstance(
extra_config.get("distributed"), bool
):
self._is_distributed = extra_config.get("distributed")
if "concurrency" in extra_config and isinstance(
extra_config.get("concurrency"), int
):
......@@ -92,8 +99,7 @@ class RunaiModelStreamerLoader(BaseModelLoader):
"""Get an iterator for the model weights based on the load format."""
hf_weights_files = self._prepare_weights(model_or_path, revision)
return runai_safetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
hf_weights_files, self.load_config.use_tqdm_on_load, self._is_distributed
)
def download_model(self, model_config: ModelConfig) -> None:
......
......@@ -657,10 +657,22 @@ def multi_thread_safetensors_weights_iterator(
def runai_safetensors_weights_iterator(
hf_weights_files: list[str],
use_tqdm_on_load: bool,
is_distributed: bool = False,
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
with SafetensorsStreamer() as streamer:
streamer.stream_files(hf_weights_files)
is_cuda_alike = current_platform.is_cuda_alike()
device = (
f"cuda:{current_platform.current_device()}"
if is_distributed and is_cuda_alike
else "cpu"
)
streamer.stream_files(
hf_weights_files,
device=device,
is_distributed=is_distributed,
)
total_tensors = sum(
len(tensors_meta)
for tensors_meta in streamer.files_to_tensors_metadata.values()
......@@ -672,6 +684,7 @@ def runai_safetensors_weights_iterator(
desc="Loading safetensors using Runai Model Streamer",
bar_format=_BAR_FORMAT,
disable=not enable_tqdm(use_tqdm_on_load),
mininterval=2,
)
yield from tensor_iter
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment