Unverified commit c016c95b, authored by Harry Mellor and committed by GitHub
Browse files

Use helper function instead of looping through attribute names (#29788)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 1339878e
...@@ -1094,11 +1094,10 @@ class ModelConfig: ...@@ -1094,11 +1094,10 @@ class ModelConfig:
# The size of inputs_embeds is usually identical to the size # The size of inputs_embeds is usually identical to the size
# of the hidden states, however there are exceptions, such as # of the hidden states, however there are exceptions, such as
# embedding models like CLIP and SigLIP # embedding models like CLIP and SigLIP
for target_attr in ("projection_dim", "projection_size"): names = ("projection_dim", "projection_size")
if hasattr(self.hf_text_config, target_attr): return getattr_iter(
return getattr(self.hf_text_config, target_attr) self.hf_text_config, names, default_factory=self.get_hidden_size
)
return self.get_hidden_size()
@property @property
def is_deepseek_mla(self) -> bool: def is_deepseek_mla(self) -> bool:
...@@ -1231,14 +1230,12 @@ class ModelConfig: ...@@ -1231,14 +1230,12 @@ class ModelConfig:
# For ChatGLM: # For ChatGLM:
"multi_query_group_num", "multi_query_group_num",
] ]
for attr in attributes:
num_kv_heads = getattr(self.hf_text_config, attr, None)
if num_kv_heads is not None:
return num_kv_heads
# For non-grouped-query attention models, the number of KV heads is # For non-grouped-query attention models, the number of KV heads is
# equal to the number of attention heads. # equal to the number of attention heads.
return self.hf_text_config.num_attention_heads default_factory = lambda: self.hf_text_config.num_attention_heads
return getattr_iter(
self.hf_text_config, attributes, default_factory=default_factory
)
def get_num_kv_heads(self, parallel_config: ParallelConfig) -> int: def get_num_kv_heads(self, parallel_config: ParallelConfig) -> int:
"""Returns the number of KV heads per GPU.""" """Returns the number of KV heads per GPU."""
......
...@@ -9,7 +9,7 @@ import inspect ...@@ -9,7 +9,7 @@ import inspect
import json import json
import pathlib import pathlib
import textwrap import textwrap
from collections.abc import Iterable, Mapping, Sequence, Set from collections.abc import Callable, Iterable, Mapping, Sequence, Set
from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace from dataclasses import MISSING, Field, dataclass, field, fields, is_dataclass, replace
from itertools import pairwise from itertools import pairwise
from typing import TYPE_CHECKING, Any, Protocol, TypeVar from typing import TYPE_CHECKING, Any, Protocol, TypeVar
...@@ -74,7 +74,11 @@ def get_field(cls: ConfigType, name: str) -> Field: ...@@ -74,7 +74,11 @@ def get_field(cls: ConfigType, name: str) -> Field:
def getattr_iter( def getattr_iter(
object: object, names: Iterable[str], default: Any, warn: bool = False object: object,
names: Iterable[str],
default: Any | None = None,
default_factory: Callable[[], Any] | None = None,
warn: bool = False,
) -> Any: ) -> Any:
""" """
A helper function that retrieves an attribute from an object which may A helper function that retrieves an attribute from an object which may
...@@ -96,7 +100,7 @@ def getattr_iter( ...@@ -96,7 +100,7 @@ def getattr_iter(
names[0], names[0],
) )
return getattr(object, name) return getattr(object, name)
return default return default_factory() if default_factory is not None else default
def contains_object_print(text: str) -> bool: def contains_object_print(text: str) -> bool:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment