Unverified Commit 53a1ba6e authored by Ning Xie's avatar Ning Xie Committed by GitHub
Browse files

[log] add weights loading time log to sharded_state loader (#28628)


Signed-off-by: default avatarAndy Xie <andy.xning@gmail.com>
parent 1840c5cb
......@@ -4,6 +4,7 @@
import collections
import glob
import os
import time
from collections.abc import Generator
from typing import Any
......@@ -132,6 +133,7 @@ class ShardedStateLoader(BaseModelLoader):
f"pre-sharded checkpoints are currently supported!"
)
state_dict = self._filter_subtensors(model.state_dict())
counter_before_loading_weights = time.perf_counter()
for key, tensor in self.iterate_over_files(filepaths):
# If loading with LoRA enabled, additional padding may
# be added to certain parameters. We only load into a
......@@ -150,6 +152,12 @@ class ShardedStateLoader(BaseModelLoader):
)
param_data.copy_(tensor)
state_dict.pop(key)
counter_after_loading_weights = time.perf_counter()
logger.info_once(
"Loading weights took %.2f seconds",
counter_after_loading_weights - counter_before_loading_weights,
scope="local",
)
if state_dict:
raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment