Unverified Commit c016c95b authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Use helper function instead of looping through attribute names (#29788)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 1339878e
......@@ -1094,11 +1094,10 @@ class ModelConfig:
# The size of inputs_embeds is usually identical to the size
# of the hidden states, however there are exceptions, such as
# embedding models like CLIP and SigLIP
for target_attr in ("projection_dim", "projection_size"):
if hasattr(self.hf_text_config, target_attr):
return getattr(self.hf_text_config, target_attr)
return self.get_hidden_size()
names = ("projection_dim", "projection_size")
return getattr_iter(
self.hf_text_config, names, default_factory=self.get_hidden_size
)
@property
def is_deepseek_mla(self) -> bool:
......@@ -1231,14 +1230,12 @@ class ModelConfig:
# For ChatGLM:
"multi_query_group_num",
]
for attr in attributes:
num_kv_heads = getattr(self.hf_text_config, attr, None)
if num_kv_heads is not None:
return num_kv_heads
# For non-grouped-query attention models, the number of KV heads is
# equal to the number of attention heads.
return self.hf_text_config.num_attention_heads
default_factory = lambda: self.hf_text_config.num_attention_heads
return getattr_iter(
self.hf_text_config, attributes, default_factory=default_factory
)
def get_num_kv_heads(self, parallel_config: ParallelConfig) -> int:
"""Returns the number of KV heads per GPU."""
......
......@@ -9,7 +9,7 @@ import inspect
import json
import pathlib
import textwrap
from collections.abc import Iterable, Mapping, Sequence, Set
from collections.abc import Callable, Iterable, Mapping, Sequence, Set
from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace
from itertools import pairwise
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
......@@ -74,7 +74,11 @@ def get_field(cls: ConfigType, name: str) -> Field:
def getattr_iter(
object: object, names: Iterable[str], default: Any, warn: bool = False
object: object,
names: Iterable[str],
default: Any | None = None,
default_factory: Callable[[], Any] | None = None,
warn: bool = False,
) -> Any:
"""
A helper function that retrieves an attribute from an object which may
......@@ -96,7 +100,7 @@ def getattr_iter(
names[0],
)
return getattr(object, name)
return default
return default_factory() if default_factory is not None else default
def contains_object_print(text: str) -> bool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment