Unverified commit ed8cbfed, authored by Thomas Parnell, committed by GitHub
Browse files

Let GraniteMoeAttention use YaRN (#21174)


Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
parent 45badd05
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
# limitations under the License. # limitations under the License.
"""Inference-only GraniteMoe model.""" """Inference-only GraniteMoe model."""
from collections.abc import Iterable from collections.abc import Iterable
from typing import Optional from typing import Any, Optional
import torch import torch
from torch import nn from torch import nn
...@@ -113,6 +113,7 @@ class GraniteMoeAttention(nn.Module): ...@@ -113,6 +113,7 @@ class GraniteMoeAttention(nn.Module):
num_kv_heads: int, num_kv_heads: int,
max_position: int = 4096 * 32, max_position: int = 4096 * 32,
rope_theta: float = 10000, rope_theta: float = 10000,
rope_scaling: Optional[dict[str, Any]] = None,
cache_config: Optional[CacheConfig] = None, cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
attention_multiplier: Optional[float] = None, attention_multiplier: Optional[float] = None,
...@@ -163,6 +164,7 @@ class GraniteMoeAttention(nn.Module): ...@@ -163,6 +164,7 @@ class GraniteMoeAttention(nn.Module):
max_position=max_position, max_position=max_position,
base=int(self.rope_theta), base=int(self.rope_theta),
is_neox_style=True, is_neox_style=True,
rope_scaling=rope_scaling,
) )
self.attn = Attention(self.num_heads, self.attn = Attention(self.num_heads,
self.head_dim, self.head_dim,
...@@ -198,12 +200,14 @@ class GraniteMoeDecoderLayer(nn.Module): ...@@ -198,12 +200,14 @@ class GraniteMoeDecoderLayer(nn.Module):
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0 # Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
self.self_attn = GraniteMoeAttention( self.self_attn = GraniteMoeAttention(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings, max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads, num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta, rope_theta=rope_theta,
rope_scaling=rope_scaling,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
......
...@@ -81,12 +81,14 @@ class GraniteMoeSharedDecoderLayer(nn.Module): ...@@ -81,12 +81,14 @@ class GraniteMoeSharedDecoderLayer(nn.Module):
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0 # Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000) rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
self.self_attn = GraniteMoeAttention( self.self_attn = GraniteMoeAttention(
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings, max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads, num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta, rope_theta=rope_theta,
rope_scaling=rope_scaling,
cache_config=cache_config, cache_config=cache_config,
quant_config=quant_config, quant_config=quant_config,
prefix=f"{prefix}.self_attn", prefix=f"{prefix}.self_attn",
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment