Unverified Commit b194557a authored by Benjamin Bartels's avatar Benjamin Bartels Committed by GitHub
Browse files

Adds parallel model weight loading for runai_streamer (#21330)


Signed-off-by: default avatarbbartels <benjamin@bartels.dev>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
parent 774d0c01
...@@ -659,7 +659,8 @@ setup( ...@@ -659,7 +659,8 @@ setup(
"bench": ["pandas", "datasets"], "bench": ["pandas", "datasets"],
"tensorizer": ["tensorizer==2.10.1"], "tensorizer": ["tensorizer==2.10.1"],
"fastsafetensors": ["fastsafetensors >= 0.1.10"], "fastsafetensors": ["fastsafetensors >= 0.1.10"],
"runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"], "runai":
["runai-model-streamer >= 0.13.3", "runai-model-streamer-s3", "boto3"],
"audio": ["librosa", "soundfile", "audio": ["librosa", "soundfile",
"mistral_common[audio]"], # Required for audio processing "mistral_common[audio]"], # Required for audio processing
"video": [] # Kept for backwards compatibility "video": [] # Kept for backwards compatibility
......
...@@ -482,14 +482,20 @@ def runai_safetensors_weights_iterator( ...@@ -482,14 +482,20 @@ def runai_safetensors_weights_iterator(
) -> Generator[tuple[str, torch.Tensor], None, None]: ) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files.""" """Iterate over the weights in the model safetensor files."""
with SafetensorsStreamer() as streamer: with SafetensorsStreamer() as streamer:
for st_file in tqdm( streamer.stream_files(hf_weights_files)
hf_weights_files, total_tensors = sum(
desc="Loading safetensors using Runai Model Streamer", len(tensors_meta)
disable=not enable_tqdm(use_tqdm_on_load), for tensors_meta in streamer.files_to_tensors_metadata.values())
bar_format=_BAR_FORMAT,
): tensor_iter = tqdm(
streamer.stream_file(st_file) streamer.get_tensors(),
yield from streamer.get_tensors() total=total_tensors,
desc="Loading safetensors using Runai Model Streamer",
bar_format=_BAR_FORMAT,
disable=not enable_tqdm(use_tqdm_on_load),
)
yield from tensor_iter
def fastsafetensors_weights_iterator( def fastsafetensors_weights_iterator(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment