Unverified Commit 89fa54e6 authored by Cyrus Leung's avatar Cyrus Leung Committed by GitHub
Browse files

[Optimization] Use a cheaper cache key in `get_model_architecture` (#25682)


Signed-off-by: default avatarDarkLight1337 <tlleungac@connect.ust.hk>
parent 3d54bdcb
...@@ -165,7 +165,7 @@ def device_loading_context(module: torch.nn.Module, ...@@ -165,7 +165,7 @@ def device_loading_context(module: torch.nn.Module,
# New parameters or parameters already on target device are untouched # New parameters or parameters already on target device are untouched
_MODEL_ARCH_BY_HASH = dict[str, tuple[type[nn.Module], str]]() _MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
"""Caches the outputs of `_get_model_architecture`.""" """Caches the outputs of `_get_model_architecture`."""
...@@ -215,7 +215,14 @@ def _get_model_architecture( ...@@ -215,7 +215,14 @@ def _get_model_architecture(
def get_model_architecture( def get_model_architecture(
model_config: ModelConfig) -> tuple[type[nn.Module], str]: model_config: ModelConfig) -> tuple[type[nn.Module], str]:
key = model_config.compute_hash() key = hash((
model_config.model,
model_config.convert_type,
model_config.runner_type,
model_config.trust_remote_code,
model_config.model_impl,
tuple(getattr(model_config.hf_config, "architectures", [])),
))
if key in _MODEL_ARCH_BY_HASH: if key in _MODEL_ARCH_BY_HASH:
return _MODEL_ARCH_BY_HASH[key] return _MODEL_ARCH_BY_HASH[key]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment