"vscode:/vscode.git/clone" did not exist on "2dffac464c82ac7c509c78f7d12a7c72ea765a63"
Unverified commit 1937e298, authored by Aurick Qiao and committed by GitHub
Browse files

[Core] Sharded State Loader download from HF (#4889)

parent f0eecee6
......@@ -423,6 +423,16 @@ class ShardedStateLoader(BaseModelLoader):
result[k] = t
return result
def _prepare_weights(self, model_name_or_path: str,
                     revision: Optional[str]):
    """Resolve the location of the sharded weight files.

    A ``model_name_or_path`` that is already a local directory is returned
    unchanged. Anything else is passed to ``download_weights_from_hf``
    together with ``self.load_config.download_dir``, a ``*.safetensors``
    allow-list and ``revision``; that call's result is returned as-is.
    """
    # A local directory needs no download step.
    if os.path.isdir(model_name_or_path):
        return model_name_or_path

    # Only safetensors files are requested from the download helper.
    safetensors_only = ["*.safetensors"]
    return download_weights_from_hf(
        model_name_or_path,
        self.load_config.download_dir,
        safetensors_only,
        revision,
    )
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
......@@ -433,6 +443,10 @@ class ShardedStateLoader(BaseModelLoader):
from safetensors.torch import safe_open
from vllm.distributed import get_tensor_model_parallel_rank
local_model_path = self._prepare_weights(model_config.model,
model_config.revision)
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
......@@ -440,7 +454,7 @@ class ShardedStateLoader(BaseModelLoader):
cache_config)
rank = get_tensor_model_parallel_rank()
pattern = os.path.join(
model_config.model,
local_model_path,
self.pattern.format(rank=rank, part="*"),
)
filepaths = glob.glob(pattern)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment