Unverified Commit 82fbeae9 authored by Jun Duan's avatar Jun Duan Committed by GitHub
Browse files

[Misc] Accurately capture the time of loading weights (#14063)


Signed-off-by: default avatarJun Duan <jun.duan.phd@outlook.com>
parent cc5e8f6d
...@@ -10,6 +10,7 @@ import inspect ...@@ -10,6 +10,7 @@ import inspect
import itertools import itertools
import math import math
import os import os
import time
import warnings import warnings
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
...@@ -216,6 +217,9 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -216,6 +217,9 @@ class DefaultModelLoader(BaseModelLoader):
allow_patterns_overrides: Optional[list[str]] = None allow_patterns_overrides: Optional[list[str]] = None
"""If defined, weights will load exclusively using these patterns.""" """If defined, weights will load exclusively using these patterns."""
counter_before_loading_weights: float = 0.0
counter_after_loading_weights: float = 0.0
def __init__(self, load_config: LoadConfig): def __init__(self, load_config: LoadConfig):
super().__init__(load_config) super().__init__(load_config)
if load_config.model_loader_extra_config: if load_config.model_loader_extra_config:
...@@ -368,6 +372,8 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -368,6 +372,8 @@ class DefaultModelLoader(BaseModelLoader):
weights_iterator = _xla_weights_iterator(weights_iterator) weights_iterator = _xla_weights_iterator(weights_iterator)
if self.counter_before_loading_weights == 0.0:
self.counter_before_loading_weights = time.perf_counter()
# Apply the prefix. # Apply the prefix.
return ((source.prefix + name, tensor) return ((source.prefix + name, tensor)
for (name, tensor) in weights_iterator) for (name, tensor) in weights_iterator)
...@@ -412,6 +418,11 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -412,6 +418,11 @@ class DefaultModelLoader(BaseModelLoader):
weights_to_load = {name for name, _ in model.named_parameters()} weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights( loaded_weights = model.load_weights(
self._get_all_weights(model_config, model)) self._get_all_weights(model_config, model))
self.counter_after_loading_weights = time.perf_counter()
logger.info(
"Loading weights took %.2f seconds",
self.counter_after_loading_weights -
self.counter_before_loading_weights)
# We only enable strict check for non-quantized models # We only enable strict check for non-quantized models
# that have loaded weights tracking currently. # that have loaded weights tracking currently.
if model_config.quantization is None and loaded_weights is not None: if model_config.quantization is None and loaded_weights is not None:
......
...@@ -1061,7 +1061,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1061,7 +1061,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.device) self.device)
time_after_load = time.perf_counter() time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB and %.6f seconds", logger.info("Model loading took %.4f GB and %.6f seconds",
self.model_memory_usage / float(2**30), self.model_memory_usage / float(2**30),
time_after_load - time_before_load) time_after_load - time_before_load)
......
...@@ -1114,7 +1114,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1114,7 +1114,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
time_after_load = time.perf_counter() time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory self.model_memory_usage = m.consumed_memory
logger.info("Loading model weights took %.4f GB and %.6f seconds", logger.info("Model loading took %.4f GB and %.6f seconds",
self.model_memory_usage / float(2**30), self.model_memory_usage / float(2**30),
time_after_load - time_before_load) time_after_load - time_before_load)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment