Unverified commit 04a49669, authored by omer-dayan and committed by GitHub
Browse files

RayLLM Bugfix - Preserve obj store URL for multi engine_config creation (#30803)


Signed-off-by: Omer Dayan <omdayan@nvidia.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent 96fcd3c2
...@@ -106,6 +106,10 @@ class ModelConfig: ...@@ -106,6 +106,10 @@ class ModelConfig:
"""Name or path of the Hugging Face model to use. It is also used as the """Name or path of the Hugging Face model to use. It is also used as the
content for `model_name` tag in metrics output when `served_model_name` is content for `model_name` tag in metrics output when `served_model_name` is
not specified.""" not specified."""
model_weights: str = ""
"""Original model weights path. Used when the model is pulled from object
storage (e.g., RunAI) to preserve the original URI while `model` points to
the local directory."""
runner: RunnerOption = "auto" runner: RunnerOption = "auto"
"""The type of model runner to use. Each vLLM instance only supports one """The type of model runner to use. Each vLLM instance only supports one
model runner, even if the same model can be used for multiple types.""" model runner, even if the same model can be used for multiple types."""
...@@ -705,6 +709,10 @@ class ModelConfig: ...@@ -705,6 +709,10 @@ class ModelConfig:
tokenizer: Tokenizer name or path tokenizer: Tokenizer name or path
""" """
# Skip if model_weights is already set (model already pulled)
if self.model_weights:
return
if not (is_runai_obj_uri(model) or is_runai_obj_uri(tokenizer)): if not (is_runai_obj_uri(model) or is_runai_obj_uri(tokenizer)):
return return
......
...@@ -354,6 +354,7 @@ class EngineArgs: ...@@ -354,6 +354,7 @@ class EngineArgs:
"""Arguments for vLLM engine.""" """Arguments for vLLM engine."""
model: str = ModelConfig.model model: str = ModelConfig.model
model_weights: str = ModelConfig.model_weights
served_model_name: str | list[str] | None = ModelConfig.served_model_name served_model_name: str | list[str] | None = ModelConfig.served_model_name
tokenizer: str | None = ModelConfig.tokenizer tokenizer: str | None = ModelConfig.tokenizer
hf_config_path: str | None = ModelConfig.hf_config_path hf_config_path: str | None = ModelConfig.hf_config_path
...@@ -1206,6 +1207,7 @@ class EngineArgs: ...@@ -1206,6 +1207,7 @@ class EngineArgs:
return ModelConfig( return ModelConfig(
model=self.model, model=self.model,
model_weights=self.model_weights,
hf_config_path=self.hf_config_path, hf_config_path=self.hf_config_path,
runner=self.runner, runner=self.runner,
convert=self.convert, convert=self.convert,
...@@ -1349,6 +1351,7 @@ class EngineArgs: ...@@ -1349,6 +1351,7 @@ class EngineArgs:
model_config = self.create_model_config() model_config = self.create_model_config()
self.model = model_config.model self.model = model_config.model
self.model_weights = model_config.model_weights
self.tokenizer = model_config.tokenizer self.tokenizer = model_config.tokenizer
self._check_feature_supported(model_config) self._check_feature_supported(model_config)
......
...@@ -108,8 +108,8 @@ class RunaiModelStreamerLoader(BaseModelLoader): ...@@ -108,8 +108,8 @@ class RunaiModelStreamerLoader(BaseModelLoader):
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None: def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
"""Load weights into a model.""" """Load weights into a model."""
model_weights = model_config.model model_weights = model_config.model
if hasattr(model_config, "model_weights"): if model_weights_override := model_config.model_weights:
model_weights = model_config.model_weights model_weights = model_weights_override
model.load_weights( model.load_weights(
self._get_weights_iterator(model_weights, model_config.revision) self._get_weights_iterator(model_weights, model_config.revision)
) )
...@@ -110,8 +110,8 @@ class ShardedStateLoader(BaseModelLoader): ...@@ -110,8 +110,8 @@ class ShardedStateLoader(BaseModelLoader):
from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed import get_tensor_model_parallel_rank
model_weights = model_config.model model_weights = model_config.model
if hasattr(model_config, "model_weights"): if model_weights_override := model_config.model_weights:
model_weights = model_config.model_weights model_weights = model_weights_override
local_model_path = model_weights local_model_path = model_weights
rank = get_tensor_model_parallel_rank() rank = get_tensor_model_parallel_rank()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment