Unverified commit 28b47d1e authored by Qing, committed by GitHub

Add rope_scaling to Aquila model (#1457)

parent 1f24755b
@@ -25,7 +25,7 @@
 The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple

 import torch
 from torch import nn
@@ -110,6 +110,7 @@ class AquilaAttention(nn.Module):
         num_kv_heads: int,
         rope_theta: float = 10000,
         max_position_embeddings: int = 8192,
+        rope_scaling: Optional[Dict[str, Any]] = None,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -148,6 +149,7 @@ class AquilaAttention(nn.Module):
             base=self.rope_theta,
             max_position=self.max_position_embeddings,
             num_kv_heads=self.num_kv_heads,
+            rope_scaling=rope_scaling,
         )

     def forward(
@@ -173,6 +175,7 @@ class AquilaDecoderLayer(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings",
                                           8192)
         self.self_attn = AquilaAttention(
@@ -181,6 +184,7 @@ class AquilaDecoderLayer(nn.Module):
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
             max_position_embeddings=max_position_embeddings,
+            rope_scaling=rope_scaling,
         )
         self.mlp = AquilaMLP(
             hidden_size=self.hidden_size,
...
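For context, `rope_scaling` in Hugging Face-style configs is typically a small dict such as `{"type": "linear", "factor": 2.0}`, and the diff above simply threads that dict from the model config down into the attention layer's rotary embedding. Below is a minimal, self-contained sketch (not vLLM code) of what consuming such a dict can look like for linear RoPE scaling; the class name `LinearScalingRotaryEmbedding` is hypothetical and stands in for the real rotary-embedding kernel.

```python
# Minimal sketch, assuming the HF-style rope_scaling convention
# {"type": "linear", "factor": <float>}; this is illustrative only and
# does not reproduce vLLM's actual rotary-embedding implementation.
from typing import Any, Dict, Optional, Tuple

import torch


class LinearScalingRotaryEmbedding:
    """Rotary embedding with optional linear position scaling."""

    def __init__(self, head_dim: int, base: float = 10000.0,
                 rope_scaling: Optional[Dict[str, Any]] = None) -> None:
        self.head_dim = head_dim
        self.base = base
        # No scaling when the config does not define rope_scaling.
        self.scaling_factor = 1.0
        if rope_scaling is not None and rope_scaling.get("type") == "linear":
            self.scaling_factor = float(rope_scaling["factor"])

    def cos_sin(self, positions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Linear RoPE scaling divides position ids by the factor so that
        # longer sequences map back into the trained position range.
        positions = positions.float() / self.scaling_factor
        inv_freq = 1.0 / (self.base ** (
            torch.arange(0, self.head_dim, 2).float() / self.head_dim))
        freqs = torch.outer(positions, inv_freq)
        return freqs.cos(), freqs.sin()


# Usage: a config whose rope_scaling is {"type": "linear", "factor": 2.0}.
rope = LinearScalingRotaryEmbedding(
    head_dim=128, rope_scaling={"type": "linear", "factor": 2.0})
cos, sin = rope.cos_sin(torch.arange(16))
```

Reading the dict with `getattr(config, "rope_scaling", None)`, as the patch does, keeps the change backward compatible: checkpoints without the field fall through to the unscaled default.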