"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "14a0d21d2ea2809ca9f88958edc459b5a1c81a16"
Unverified Commit ffb32a85 authored by harrisonlimh's avatar harrisonlimh Committed by GitHub
Browse files

Conditionally recapture cuda graph after model weight update from disk (#12060)

parent 14d80648
...@@ -1024,6 +1024,8 @@ class UpdateWeightFromDiskReqInput(BaseReq): ...@@ -1024,6 +1024,8 @@ class UpdateWeightFromDiskReqInput(BaseReq):
torch_empty_cache: bool = False torch_empty_cache: bool = False
# Whether to keep the scheduler paused after weight update # Whether to keep the scheduler paused after weight update
keep_pause: bool = False keep_pause: bool = False
# Whether to recapture cuda graph after weight update
recapture_cuda_graph: bool = False
# The trainer step id. Used to know which step's weights are used for sampling. # The trainer step id. Used to know which step's weights are used for sampling.
token_step: int = 0 token_step: int = 0
......
...@@ -100,7 +100,7 @@ class BaseTpWorker(ABC): ...@@ -100,7 +100,7 @@ class BaseTpWorker(ABC):
def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
    """Forward a disk weight-update request to the model runner.

    Args:
        recv_req: Request carrying ``model_path``, ``load_format`` and the
            ``recapture_cuda_graph`` flag.

    Returns:
        The ``(success, message)`` pair produced by
        ``self.model_runner.update_weights_from_disk``.
    """
    # Pass the flag by keyword: the runner's third positional parameter is
    # ``weight_name_filter``, so a positional bool would be bound to the
    # filter and ``recapture_cuda_graph`` would stay at its default False.
    success, message = self.model_runner.update_weights_from_disk(
        recv_req.model_path,
        recv_req.load_format,
        recapture_cuda_graph=recv_req.recapture_cuda_graph,
    )
    return success, message
......
...@@ -862,6 +862,7 @@ class ModelRunner: ...@@ -862,6 +862,7 @@ class ModelRunner:
model_path: str, model_path: str,
load_format: str, load_format: str,
weight_name_filter: Optional[Callable[[str], bool]] = None, weight_name_filter: Optional[Callable[[str], bool]] = None,
recapture_cuda_graph: bool = False,
) -> tuple[bool, str]: ) -> tuple[bool, str]:
"""Update engine weights in-place from the disk.""" """Update engine weights in-place from the disk."""
logger.info( logger.info(
...@@ -917,6 +918,9 @@ class ModelRunner: ...@@ -917,6 +918,9 @@ class ModelRunner:
self.server_args.load_format = load_format self.server_args.load_format = load_format
self.load_config = load_config self.load_config = load_config
if recapture_cuda_graph and self.device == "cuda":
self.init_device_graphs()
logger.info("Update weights end.") logger.info("Update weights end.")
return True, "Succeeded to update model weights." return True, "Succeeded to update model weights."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment