Unverified commit e5c760d6, authored by Arthur, committed by GitHub
Browse files

[GPTNeoX] Nit in config (#24349)

* raise a ValueError when the hidden size is not divisible by the number of attention heads

* nits to fix test_config

* style
parent c2882403
...@@ -126,3 +126,7 @@ class GPTNeoXConfig(PretrainedConfig): ...@@ -126,3 +126,7 @@ class GPTNeoXConfig(PretrainedConfig):
self.use_cache = use_cache self.use_cache = use_cache
self.tie_word_embeddings = tie_word_embeddings self.tie_word_embeddings = tie_word_embeddings
self.use_parallel_residual = use_parallel_residual self.use_parallel_residual = use_parallel_residual
if self.hidden_size % self.num_attention_heads != 0:
raise ValueError(
"The hidden size is not divisble by the number of attention heads! Make sure to update them!"
)
...@@ -88,6 +88,10 @@ class GPTNeoXAttention(nn.Module): ...@@ -88,6 +88,10 @@ class GPTNeoXAttention(nn.Module):
super().__init__() super().__init__()
self.num_attention_heads = config.num_attention_heads self.num_attention_heads = config.num_attention_heads
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
if self.hidden_size % self.num_attention_heads != 0:
raise ValueError(
"The hidden size is not divisble by the number of attention heads! Make sure to update them"
)
self.head_size = self.hidden_size // self.num_attention_heads self.head_size = self.hidden_size // self.num_attention_heads
self.rotary_ndims = int(self.head_size * config.rotary_pct) self.rotary_ndims = int(self.head_size * config.rotary_pct)
max_positions = config.max_position_embeddings max_positions = config.max_position_embeddings
......
...@@ -253,7 +253,7 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi ...@@ -253,7 +253,7 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def setUp(self): def setUp(self):
self.model_tester = GPTNeoXModelTester(self) self.model_tester = GPTNeoXModelTester(self)
self.config_tester = ConfigTester(self, config_class=GPTNeoXConfig, hidden_size=37) self.config_tester = ConfigTester(self, config_class=GPTNeoXConfig, hidden_size=64, num_attention_heads=8)
def test_config(self): def test_config(self):
self.config_tester.run_common_tests() self.config_tester.run_common_tests()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment