Unverified Commit ebede26e authored by Jie Li, committed by GitHub

Make InternLM follow `rope_scaling` in `config.json` (#1956)


Co-authored-by: lijie8 <lijie8@sensetime.com>
parent d940ce49
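
This change threads the optional `rope_scaling` dict from the model's Hugging Face config through to the attention layer's rotary embedding, instead of silently ignoring it. A minimal sketch of where that value originates (the model id below is illustrative; the dict shape follows the Hugging Face `rope_scaling` convention):

    from transformers import AutoConfig

    # Illustrative model id; any InternLM checkpoint with a "rope_scaling"
    # entry in its config.json behaves the same way.
    config = AutoConfig.from_pretrained("internlm/internlm-chat-7b",
                                        trust_remote_code=True)
    # None when config.json has no "rope_scaling" entry; otherwise a dict
    # such as {"type": "dynamic", "factor": 2.0}.
    rope_scaling = getattr(config, "rope_scaling", None)
    print(rope_scaling)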
 # -*- coding: utf-8 -*-
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 import torch
 from torch import nn
@@ -67,6 +67,7 @@ class InternLMAttention(nn.Module):
         rope_theta: float = 10000,
         max_position_embeddings: int = 8192,
         linear_method: Optional[LinearMethodBase] = None,
+        rope_scaling: Optional[Dict[str, Any]] = None,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -99,6 +100,7 @@ class InternLMAttention(nn.Module):
             rotary_dim=self.head_dim,
             max_position=self.max_position_embeddings,
             base=self.rope_theta,
+            rope_scaling=rope_scaling,
         )
         self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling)
@@ -139,6 +141,7 @@ class InternLMDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             max_position_embeddings=max_position_embeddings,
             linear_method=linear_method,
+            rope_scaling=getattr(config, "rope_scaling", None),
         )
         self.mlp = InternLMMLP(
             hidden_size=self.hidden_size,
...
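
For context, `rope_scaling` tells the rotary embedding how to stretch positions past the trained context window. A minimal sketch of what the "linear" variant does to position ids (an illustrative helper, not vLLM's actual rotary-embedding code):

    import torch

    def linear_scaled_rope_positions(seq_len: int, factor: float) -> torch.Tensor:
        # Linear RoPE scaling divides position ids by `factor` before the
        # rotary embedding is applied, stretching the trained window by
        # that factor.
        return torch.arange(seq_len, dtype=torch.float32) / factor

    # With factor=2.0, a model trained with max_position_embeddings=8192 keeps
    # its scaled positions within the original 0..8192 range out to ~16384
    # tokens.
    positions = linear_scaled_rope_positions(16384, factor=2.0)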