"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7c59e32d470dfb02a1360179b9cfa955344d4370"
Unverified commit 6fd93fe9 authored by Jacky Lee, committed by GitHub

Fix rope theta for OpenLlama (#29893)

fix: rope_theta for open llama
parent 5ad7f170
@@ -66,6 +66,8 @@ class OpenLlamaConfig(PretrainedConfig):
             relevant if `config.is_decoder=True`.
         tie_word_embeddings(`bool`, *optional*, defaults to `False`):
             Whether to tie weight embeddings
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
         rope_scaling (`Dict`, *optional*):
             Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
             strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
@@ -113,6 +115,7 @@ class OpenLlamaConfig(PretrainedConfig):
         attention_dropout_prob=0.1,
         use_stable_embedding=True,
         shared_input_output_embedding=True,
+        rope_theta=10000.0,
         rope_scaling=None,
         **kwargs,
     ):
@@ -133,6 +136,7 @@ class OpenLlamaConfig(PretrainedConfig):
         self.attention_dropout_prob = attention_dropout_prob
         self.use_stable_embedding = use_stable_embedding
         self.shared_input_output_embedding = shared_input_output_embedding
+        self.rope_theta = rope_theta
         self.rope_scaling = rope_scaling
         self._rope_scaling_validation()
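For context, a minimal usage sketch of the new config field. It assumes `OpenLlamaConfig` is importable from the top-level `transformers` namespace in your installed version; the Open-Llama model lives under the deprecated models, so the import path may differ.

```python
# Minimal sketch: configure Open-Llama with a custom RoPE base period.
# Assumption: `OpenLlamaConfig` is exposed at the top level of `transformers`;
# adjust the import if your version keeps it under the deprecated models module.
from transformers import OpenLlamaConfig

config = OpenLlamaConfig(
    hidden_size=4096,
    num_attention_heads=32,
    rope_theta=500000.0,  # larger base period than the 10000.0 default
    rope_scaling={"type": "linear", "factor": 2.0},  # optional, per the docstring above
)
print(config.rope_theta)  # 500000.0
```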
@@ -214,6 +214,7 @@ class OpenLlamaAttention(nn.Module):
         self.head_dim = self.hidden_size // self.num_heads
         self.max_position_embeddings = config.max_position_embeddings
         self.dropout_prob = config.attention_dropout_prob
+        self.rope_theta = config.rope_theta
         if (self.head_dim * self.num_heads) != self.hidden_size:
             raise ValueError(
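For reference, `rope_theta` is the base of the geometric progression that sets the rotary embedding frequencies; the attention module now carries the value from the config rather than relying on the default. A small illustrative sketch (not the model's own code) of how the base period maps to RoPE inverse frequencies:

```python
import torch

def rope_inverse_frequencies(head_dim: int, rope_theta: float = 10000.0) -> torch.Tensor:
    # Standard RoPE schedule: one frequency per pair of channels,
    # decaying geometrically with rope_theta as the base.
    return 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))

# A larger base period yields lower frequencies (slower-rotating positions),
# which is what long-context variants typically rely on.
print(rope_inverse_frequencies(head_dim=8, rope_theta=10000.0))
print(rope_inverse_frequencies(head_dim=8, rope_theta=500000.0))
```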