Unverified commit 13c62be4, authored by Dean Wyatte, committed by GitHub

GPTNeoX: Use static rotary embedding (#1498)

# What does this PR do?

`transformers` 4.35 removed rotary embeddings from GPTNeoX's weights
([link to line
diff](https://github.com/huggingface/transformers/commit/253f9a3f9716d08a81fb305fe71f983122eb608b#diff-0e2a05d86c82e96f516db8c14070ceb36f53ca44c6bc21a9cd92ad2e777b9cf1R298)).
This applies the same fix as
https://github.com/huggingface/text-generation-inference/pull/793, which
generates them on the fly using the appropriate values from the config
file.

Fixes
https://github.com/huggingface/text-generation-inference/issues/1460
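
For context, here is a minimal sketch of how static rotary tables can be derived from the GPT-NeoX config values (`rotary_pct`, `rotary_emb_base`) instead of being loaded from checkpoint weights. The helper `build_static_rotary_cache` and the example hyperparameters are illustrative only, not TGI's actual `PositionRotaryEmbedding` implementation:

```python
import torch

def build_static_rotary_cache(head_size, rotary_pct, base, max_positions, device="cpu"):
    # Only a fraction of each head dimension is rotated; GPT-NeoX exposes this as rotary_pct.
    rotary_dim = int(rotary_pct * head_size)
    # One inverse frequency per pair of rotated dimensions (standard RoPE schedule).
    inv_freq = 1.0 / (
        base ** (torch.arange(0, rotary_dim, 2, device=device, dtype=torch.float32) / rotary_dim)
    )
    # Rotation angle for every (position, frequency) pair.
    positions = torch.arange(max_positions, device=device, dtype=torch.float32)
    freqs = torch.outer(positions, inv_freq)
    # Cos/sin caches that the attention layer indexes by position at runtime.
    return torch.cos(freqs), torch.sin(freqs)

# Example with GPT-NeoX-20B-like hyperparameters (illustrative values).
cos, sin = build_static_rotary_cache(
    head_size=96, rotary_pct=0.25, base=10000, max_positions=2048
)
```

Because the tables are computed at load time from the config alone, serving no longer depends on the checkpoint containing `rotary_emb.inv_freq` buffers, so it works regardless of which `transformers` version exported the weights.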

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [x] Was this discussed/approved via a GitHub issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?


## Who can review?

@OlivierDehaene OR @Narsil
parent 2ae36a97
```diff
@@ -91,6 +91,8 @@ class FlashNeoxAttention(torch.nn.Module):
         self.hidden_size = hidden_size
         self.head_size = hidden_size // num_heads
+        self.rotary_dim = int(config.rotary_pct * self.head_size)
+
         if self.num_heads % weights.process_group.size() != 0:
             raise ValueError(
                 f"`num_heads` must be divisible by `num_shards` (got `num_heads`: {self.num_heads} "
@@ -98,8 +100,11 @@ class FlashNeoxAttention(torch.nn.Module):
             )
         self.num_heads = self.num_heads // weights.process_group.size()
-        self.rotary_emb = PositionRotaryEmbedding.load(
-            config=config, prefix=f"{prefix}.rotary_emb", weights=weights
+        self.rotary_emb = PositionRotaryEmbedding.static(
+            config=config,
+            dim=self.rotary_dim,
+            base=config.rotary_emb_base,
+            device=weights.device,
         )
         self.softmax_scale = self.head_size ** (-0.5)
```