Unverified Commit 71ce4404 authored by omer-dayan's avatar omer-dayan Committed by GitHub
Browse files

Support S3 Sharded loading with RunAI Model Streamer (#16317)


Signed-off-by: default avatarOmer Dayan (SW-GPU) <omer@run.ai>
Co-authored-by: default avatarCyrus Leung <cyrus.tl.leung@gmail.com>
parent 188b7f9b
......@@ -1489,6 +1489,7 @@ class LoadFormat(str, enum.Enum):
BITSANDBYTES = "bitsandbytes"
MISTRAL = "mistral"
RUNAI_STREAMER = "runai_streamer"
RUNAI_STREAMER_SHARDED = "runai_streamer_sharded"
FASTSAFETENSORS = "fastsafetensors"
......
......@@ -611,8 +611,12 @@ class ShardedStateLoader(BaseModelLoader):
DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
def __init__(self, load_config: LoadConfig):
def __init__(self,
load_config: LoadConfig,
runai_model_streamer: bool = False):
super().__init__(load_config)
self.runai_model_streamer = runai_model_streamer
extra_config = ({} if load_config.model_loader_extra_config is None
else load_config.model_loader_extra_config.copy())
self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
......@@ -659,7 +663,7 @@ class ShardedStateLoader(BaseModelLoader):
def _prepare_weights(self, model_name_or_path: str,
revision: Optional[str]):
if os.path.isdir(model_name_or_path):
if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path):
return model_name_or_path
else:
allow_patterns = ["*.safetensors"]
......@@ -678,12 +682,13 @@ class ShardedStateLoader(BaseModelLoader):
device_config = vllm_config.device_config
model_config = vllm_config.model_config
target_device = torch.device(device_config.device)
from safetensors.torch import safe_open
from vllm.distributed import get_tensor_model_parallel_rank
local_model_path = self._prepare_weights(model_config.model,
model_config.revision)
model_weights = model_config.model
if hasattr(model_config, "model_weights"):
model_weights = model_config.model_weights
local_model_path = model_weights
with set_default_torch_dtype(model_config.dtype):
with target_device:
......@@ -695,6 +700,13 @@ class ShardedStateLoader(BaseModelLoader):
local_model_path,
self.pattern.format(rank=rank, part="*"),
)
filepaths = []
if is_s3(local_model_path):
file_pattern = f"*{self.pattern.format(rank=rank, part=" * ")}"
filepaths = s3_glob(path=local_model_path,
allow_pattern=[file_pattern])
else:
filepaths = glob.glob(pattern)
if not filepaths:
# TODO: support un-sharded checkpoints too
......@@ -702,10 +714,7 @@ class ShardedStateLoader(BaseModelLoader):
f"Could not find checkpoint files '{pattern}', only "
f"pre-sharded checkpoints are currently supported!")
state_dict = self._filter_subtensors(model.state_dict())
for path in filepaths:
with safe_open(path, framework="pt") as f:
for key in f.keys(): # noqa: SIM118
tensor = f.get_tensor(key)
for key, tensor in self.iterate_over_files(filepaths):
# If loading with LoRA enabled, additional padding may
# be added to certain parameters. We only load into a
# narrowed view of the parameter data.
......@@ -729,6 +738,18 @@ class ShardedStateLoader(BaseModelLoader):
f"Missing keys {tuple(state_dict)} in loaded state!")
return model.eval()
def iterate_over_files(
self, paths) -> Generator[Tuple[str, torch.Tensor], None, None]:
if self.runai_model_streamer:
yield from runai_safetensors_weights_iterator(paths, True)
else:
from safetensors.torch import safe_open
for path in paths:
with safe_open(path, framework="pt") as f:
for key in f.keys(): # noqa: SIM118
tensor = f.get_tensor(key)
yield key, tensor
@staticmethod
def save_model(
model: torch.nn.Module,
......@@ -1515,4 +1536,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
if load_config.load_format == LoadFormat.RUNAI_STREAMER:
return RunaiModelStreamerLoader(load_config)
if load_config.load_format == LoadFormat.RUNAI_STREAMER_SHARDED:
return ShardedStateLoader(load_config, runai_model_streamer=True)
return DefaultModelLoader(load_config)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment