Unverified Commit ffb32a85 authored by harrisonlimh, committed by GitHub

Conditionally recapture cuda graph after model weight update from disk (#12060)

parent 14d80648
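
Background (not part of the commit): a captured CUDA graph replays against the exact device allocations it recorded at capture time. If a weight reload copies new values into the existing parameter tensors, an old graph keeps working; if it rebinds parameters to freshly allocated tensors, the graph silently keeps reading the stale memory, which is why an explicit recapture option is useful. The sketch below is a minimal, standalone PyTorch illustration of that failure mode, not SGLang code; it assumes a CUDA-capable GPU.

```python
import torch

assert torch.cuda.is_available()

static_in = torch.randn(8, 16, device="cuda")
weight = torch.randn(16, 16, device="cuda")

# Warm up on a side stream before capture, as the PyTorch docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out = static_in @ weight
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = static_in @ weight  # the graph records these exact allocations

# In-place update: captured pointers stay valid, so replay sees the new values.
weight.copy_(torch.randn(16, 16, device="cuda"))
g.replay()

# Rebinding allocates a new tensor; the captured graph still reads the old
# allocation, so it must be recaptured to pick up the "updated" weights.
weight = torch.randn(16, 16, device="cuda")
g.replay()  # silently computes with the stale weight tensor
```
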
@@ -1024,6 +1024,8 @@ class UpdateWeightFromDiskReqInput(BaseReq):
     torch_empty_cache: bool = False
     # Whether to keep the scheduler paused after weight update
     keep_pause: bool = False
+    # Whether to recapture the CUDA graph after the weight update
+    recapture_cuda_graph: bool = False
     # The trainer step id. Used to know which step's weights are used for sampling.
     token_step: int = 0
@@ -100,7 +100,7 @@ class BaseTpWorker(ABC):
     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
         success, message = self.model_runner.update_weights_from_disk(
-            recv_req.model_path, recv_req.load_format
+            recv_req.model_path, recv_req.load_format, recv_req.recapture_cuda_graph
         )
         return success, message
@@ -862,6 +862,7 @@ class ModelRunner:
         model_path: str,
         load_format: str,
         weight_name_filter: Optional[Callable[[str], bool]] = None,
+        recapture_cuda_graph: bool = False,
     ) -> tuple[bool, str]:
         """Update engine weights in-place from the disk."""
         logger.info(
@@ -917,6 +918,9 @@ class ModelRunner:
         self.server_args.load_format = load_format
         self.load_config = load_config
+        if recapture_cuda_graph and self.device == "cuda":
+            self.init_device_graphs()
         logger.info("Update weights end.")
         return True, "Succeeded to update model weights."
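
For reference, a usage sketch of the new flag from a client's point of view. It assumes a running SGLang server on the default port and that the /update_weights_from_disk HTTP endpoint deserializes its JSON body into UpdateWeightFromDiskReqInput; the URL and checkpoint path are placeholders.

```python
import requests

# Hypothetical client call: reload weights from disk and opt into CUDA graph
# recapture via the field added by this commit (it defaults to False when omitted).
resp = requests.post(
    "http://localhost:30000/update_weights_from_disk",
    json={
        "model_path": "/path/to/updated/checkpoint",
        "recapture_cuda_graph": True,
    },
)
print(resp.status_code, resp.text)
```
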