Unverified Commit 754b699c authored by Jon Gill's avatar Jon Gill Committed by GitHub
Browse files

[Bug]: Fix S3 model/tokenizer path resolution (#18083)


Signed-off-by: default avatarJon Gill <jon@yurts.ai>
parent 6e27c6d8
...@@ -611,23 +611,30 @@ class ModelConfig: ...@@ -611,23 +611,30 @@ class ModelConfig:
def maybe_pull_model_tokenizer_for_s3(self, model: str, def maybe_pull_model_tokenizer_for_s3(self, model: str,
tokenizer: str) -> None: tokenizer: str) -> None:
""" """Pull model/tokenizer from S3 to temporary directory when needed.
Pull the model config or tokenizer to a temporary
directory in case of S3.
Args: Args:
model: The model name or path. model: Model name or path
tokenizer: The tokenizer name or path. tokenizer: Tokenizer name or path
""" """
if is_s3(model) or is_s3(tokenizer): if not (is_s3(model) or is_s3(tokenizer)):
return
if is_s3(model): if is_s3(model):
s3_model = S3Model() s3_model = S3Model()
s3_model.pull_files( s3_model.pull_files(model,
model, allow_pattern=["*.model", "*.py", "*.json"]) allow_pattern=["*.model", "*.py", "*.json"])
self.model_weights = self.model self.model_weights = model
self.model = s3_model.dir self.model = s3_model.dir
# If tokenizer is same as model, download to same directory
if model == tokenizer:
s3_model.pull_files(
model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
self.tokenizer = s3_model.dir
return
# Only download tokenizer if needed and not already handled
if is_s3(tokenizer): if is_s3(tokenizer):
s3_tokenizer = S3Model() s3_tokenizer = S3Model()
s3_tokenizer.pull_files( s3_tokenizer.pull_files(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment