"vscode:/vscode.git/clone" did not exist on "017fedcc415ee760e38d35830b4fc571092a0200"
Unverified commit a38c1497, authored by wangyu, committed by GitHub
Browse files

feat(draft_model): support draft_model for RemoteModelLoader (#6407)


Signed-off-by: wangyu <wangyu.steph@bytedance.com>
parent 74dd4249
...@@ -34,6 +34,12 @@ parser.add_argument( ...@@ -34,6 +34,12 @@ parser.add_argument(
type=str, type=str,
help="remote address to store model weights", help="remote address to store model weights",
) )
parser.add_argument(
"--remote-draft-model-save-url",
default=None,
type=str,
help="remote address to store draft model weights",
)
def main(args): def main(args):
...@@ -43,7 +49,10 @@ def main(args): ...@@ -43,7 +49,10 @@ def main(args):
raise ValueError("model path must be a local directory") raise ValueError("model path must be a local directory")
# Create LLM instance from arguments # Create LLM instance from arguments
llm = Engine(**dataclasses.asdict(engine_args)) llm = Engine(**dataclasses.asdict(engine_args))
llm.save_remote_model(url=args.remote_model_save_url) llm.save_remote_model(
url=args.remote_model_save_url, draft_url=args.remote_draft_model_save_url
)
print("save remote (draft) model successfully")
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -126,6 +126,14 @@ def get_config( ...@@ -126,6 +126,14 @@ def get_config(
kwargs["gguf_file"] = model kwargs["gguf_file"] = model
model = Path(model).parent model = Path(model).parent
if is_remote_url(model):
# BaseConnector implements __del__() to clean up the local dir.
# Because the config files need to exist all the time, we DO NOT use a
# with statement here, which would close the client.
client = create_remote_connector(model)
client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
model = client.get_local_dir()
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code, revision=revision, **kwargs model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
) )
......
...@@ -121,9 +121,16 @@ class SchedulerUpdateWeightsMixin: ...@@ -121,9 +121,16 @@ class SchedulerUpdateWeightsMixin:
url = params["url"] url = params["url"]
worker = self.tp_worker.worker worker = self.tp_worker.worker
worker.model_runner.save_remote_model(url) worker.model_runner.save_remote_model(url)
if self.draft_worker is not None:
draft_url = params.get("draft_url", None)
assert (
draft_url is not None
), "draft_url must be provided when draft model is enabled"
draft_worker = self.draft_worker.worker
draft_worker.model_runner.save_remote_model(draft_url)
def save_sharded_model(self, params): def save_sharded_model(self, params):
worker = self.tp_worker.worker worker = self.tp_worker.worker
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment