Unverified Commit 53a1ba6e authored by Ning Xie's avatar Ning Xie Committed by GitHub
Browse files

[log] add weights loading time log to sharded_state loader (#28628)


Signed-off-by: default avatarAndy Xie <andy.xning@gmail.com>
parent 1840c5cb
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
import collections import collections
import glob import glob
import os import os
import time
from collections.abc import Generator from collections.abc import Generator
from typing import Any from typing import Any
...@@ -132,6 +133,7 @@ class ShardedStateLoader(BaseModelLoader): ...@@ -132,6 +133,7 @@ class ShardedStateLoader(BaseModelLoader):
f"pre-sharded checkpoints are currently supported!" f"pre-sharded checkpoints are currently supported!"
) )
state_dict = self._filter_subtensors(model.state_dict()) state_dict = self._filter_subtensors(model.state_dict())
counter_before_loading_weights = time.perf_counter()
for key, tensor in self.iterate_over_files(filepaths): for key, tensor in self.iterate_over_files(filepaths):
# If loading with LoRA enabled, additional padding may # If loading with LoRA enabled, additional padding may
# be added to certain parameters. We only load into a # be added to certain parameters. We only load into a
...@@ -150,6 +152,12 @@ class ShardedStateLoader(BaseModelLoader): ...@@ -150,6 +152,12 @@ class ShardedStateLoader(BaseModelLoader):
) )
param_data.copy_(tensor) param_data.copy_(tensor)
state_dict.pop(key) state_dict.pop(key)
counter_after_loading_weights = time.perf_counter()
logger.info_once(
"Loading weights took %.2f seconds",
counter_after_loading_weights - counter_before_loading_weights,
scope="local",
)
if state_dict: if state_dict:
raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!") raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment