Unverified Commit 3302f0ae authored by Antoni Baum, committed by GitHub

rope_theta and max_position_embeddings from config (#1096)


Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: wnma3mz <wnma3mz@gmail.com>
parent 6f2dd6c3
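Summary: the model implementations below now read `rope_theta` and `max_position_embeddings` from the Hugging Face model config (falling back to 10000 and 8192 when absent) and pass them to `PagedAttentionWithRoPE` as `base` and `max_position`. In `ModelConfig`, the max-length derivation moves into a new `_get_and_verify_max_len` helper, which raises an error instead of only warning when a user-specified `max_model_len` exceeds the derived limit. A minimal sketch of the lookup pattern each model file adopts (the `SimpleNamespace` config here is purely illustrative, standing in for a `PretrainedConfig`):

```python
# Illustrative only: SimpleNamespace stands in for a Hugging Face config object.
from types import SimpleNamespace

config = SimpleNamespace(max_position_embeddings=4096)  # rope_theta not set

rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)

assert rope_theta == 10000              # fallback default used
assert max_position_embeddings == 4096  # taken from the config
```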
@@ -57,7 +57,7 @@ class ModelConfig:
         load_format: str,
         dtype: str,
         seed: int,
-        revision: Optional[str],
+        revision: Optional[str] = None,
         max_model_len: Optional[int] = None,
         quantization: Optional[str] = None,
     ) -> None:
@@ -73,19 +73,11 @@ class ModelConfig:
         self.hf_config = get_config(model, trust_remote_code, revision)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
+        self.max_model_len = _get_and_verify_max_len(self.hf_config,
+                                                     max_model_len)
         self._verify_load_format()
         self._verify_tokenizer_mode()
         self._verify_quantization()
-        self.max_model_len = None
-        if max_model_len is not None:
-            derived_max_model_len = self.get_max_model_len()
-            if max_model_len > derived_max_model_len:
-                logger.warning(
-                    f"User-specified max_model_len ({max_model_len}) is "
-                    f"greater than the derived max_model_len "
-                    f"({derived_max_model_len}). Make sure the value is "
-                    "correct and within the model context size.")
-            self.max_model_len = max_model_len

     def _verify_load_format(self) -> None:
         load_format = self.load_format.lower()
@@ -168,26 +160,7 @@ class ModelConfig:
         return total_num_attention_heads // parallel_config.tensor_parallel_size

     def get_max_model_len(self) -> int:
-        if self.max_model_len is not None:
-            return self.max_model_len
-        max_model_len = float("inf")
-        possible_keys = [
-            # OPT
-            "max_position_embeddings",
-            # GPT-2
-            "n_positions",
-            # MPT
-            "max_seq_len",
-            # Others
-            "max_sequence_length",
-            "max_seq_length",
-            "seq_len",
-        ]
-        for key in possible_keys:
-            max_len_key = getattr(self.hf_config, key, None)
-            if max_len_key is not None:
-                max_model_len = min(max_model_len, max_len_key)
-        return max_model_len
+        return self.max_model_len

     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
         total_num_hidden_layers = self.hf_config.num_hidden_layers
@@ -348,3 +321,38 @@ def _get_and_verify_dtype(
             f"of at least 8.0. Your {gpu_name} GPU has compute capability "
             f"{compute_capability[0]}.{compute_capability[1]}.")
     return torch_dtype
+
+
+def _get_and_verify_max_len(
+    hf_config: PretrainedConfig,
+    max_model_len: Optional[int],
+) -> int:
+    """Get and verify the model's maximum length."""
+    derived_max_model_len = float("inf")
+    possible_keys = [
+        # OPT
+        "max_position_embeddings",
+        # GPT-2
+        "n_positions",
+        # MPT
+        "max_seq_len",
+        # Others
+        "max_sequence_length",
+        "max_seq_length",
+        "seq_len",
+    ]
+    for key in possible_keys:
+        max_len_key = getattr(hf_config, key, None)
+        if max_len_key is not None:
+            derived_max_model_len = min(derived_max_model_len, max_len_key)
+    if max_model_len is None:
+        max_model_len = derived_max_model_len
+    elif max_model_len > derived_max_model_len:
+        raise ValueError(
+            f"User-specified max_model_len ({max_model_len}) is greater than "
+            f"the derived max_model_len ({max_len_key}={derived_max_model_len}"
+            " in model's config.json). This may lead to incorrect model "
+            "outputs or CUDA errors. Make sure the value is correct and "
+            "within the model context size.")
+    return max_model_len
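The intended behavior of the new helper, shown as a standalone sketch (this is a re-implementation for illustration, not the vLLM source; `resolve_max_len` and the dict-based config are hypothetical): take the smallest context-length key present in the config, use it when no override is given, and reject an override that exceeds it.

```python
from typing import Optional

def resolve_max_len(config_dict: dict, max_model_len: Optional[int]) -> int:
    """Illustrative mirror of _get_and_verify_max_len's resolution logic."""
    possible_keys = ["max_position_embeddings", "n_positions", "max_seq_len",
                     "max_sequence_length", "max_seq_length", "seq_len"]
    # Smallest context-length value among the keys the config actually defines.
    derived = min((config_dict[k] for k in possible_keys if k in config_dict),
                  default=float("inf"))
    if max_model_len is None:
        return derived
    if max_model_len > derived:
        raise ValueError(f"max_model_len ({max_model_len}) exceeds the "
                         f"derived maximum ({derived}).")
    return max_model_len

# A config that exposes max_position_embeddings=4096:
assert resolve_max_len({"max_position_embeddings": 4096}, None) == 4096
assert resolve_max_len({"max_position_embeddings": 4096}, 2048) == 2048
# Requesting more than the model supports now raises instead of only warning.
```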
@@ -105,6 +105,8 @@ class AquilaAttention(nn.Module):
         hidden_size: int,
         num_heads: int,
         num_kv_heads: int,
+        rope_theta: float = 10000,
+        max_position_embeddings: int = 8192,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -119,6 +121,8 @@ class AquilaAttention(nn.Module):
         self.q_size = self.num_heads * self.head_dim
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings

         self.qkv_proj = ColumnParallelLinear(
             hidden_size,
@@ -140,6 +144,8 @@ class AquilaAttention(nn.Module):
             self.head_dim,
             self.scaling,
             rotary_dim=self.head_dim,
+            base=self.rope_theta,
+            max_position=self.max_position_embeddings,
         )

     def forward(
@@ -164,10 +170,15 @@ class AquilaDecoderLayer(nn.Module):
     def __init__(self, config: AquilaConfig):
         super().__init__()
         self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
         self.self_attn = AquilaAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_attention_heads,
+            rope_theta=rope_theta,
+            max_position_embeddings=max_position_embeddings,
         )
         self.mlp = AquilaMLP(
             hidden_size=self.hidden_size,
...
@@ -111,6 +111,8 @@ class BaiChuanAttention(nn.Module):
         hidden_size: int,
         num_heads: int,
         position_embedding: str,
+        rope_theta: float = 10000,
+        max_position_embeddings: int = 8192,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -122,6 +124,8 @@ class BaiChuanAttention(nn.Module):
                                 tensor_model_parallel_world_size)
         self.head_dim = hidden_size // self.total_num_heads
         self.postion_embedding = position_embedding
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings

         # pylint: disable=invalid-name
         self.W_pack = ColumnParallelLinear(
@@ -151,10 +155,13 @@ class BaiChuanAttention(nn.Module):
                                              scaling, alibi_slopes)
         else:
             self.scaling = self.head_dim**-0.5
-            self.attn = PagedAttentionWithRoPE(self.num_heads,
-                                               self.head_dim,
-                                               self.scaling,
-                                               rotary_dim=self.head_dim)
+            self.attn = PagedAttentionWithRoPE(
+                self.num_heads,
+                self.head_dim,
+                self.scaling,
+                rotary_dim=self.head_dim,
+                base=self.rope_theta,
+                max_position=self.max_position_embeddings)

     def forward(
         self,
@@ -183,10 +190,15 @@ class BaiChuanDecoderLayer(nn.Module):
     def __init__(self, config: BaiChuanConfig, position_embedding: str):
         super().__init__()
         self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
         self.self_attn = BaiChuanAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             position_embedding=position_embedding,
+            rope_theta=rope_theta,
+            max_position_embeddings=max_position_embeddings,
         )
         self.mlp = BaiChuanMLP(
             hidden_size=self.hidden_size,
...
@@ -161,12 +161,17 @@ class FalconAttention(nn.Module):
                 "Rotary and alibi are mutually exclusive.")

         if self.use_rotary:
-            # TODO(zhuohan): Pass in correct `max_position``
-            self.attn = PagedAttentionWithRoPE(self.num_heads,
-                                               self.head_dim,
-                                               self.inv_norm_factor,
-                                               rotary_dim=self.head_dim,
-                                               num_kv_heads=self.num_kv_heads)
+            rope_theta = getattr(config, "rope_theta", 10000)
+            max_position_embeddings = getattr(config,
+                                              "max_position_embeddings", 8192)
+            self.attn = PagedAttentionWithRoPE(
+                self.num_heads,
+                self.head_dim,
+                self.inv_norm_factor,
+                base=rope_theta,
+                max_position=max_position_embeddings,
+                rotary_dim=self.head_dim,
+                num_kv_heads=self.num_kv_heads)
         elif self.use_alibi:
             tp_rank = get_tensor_model_parallel_rank()
             head_start = tp_rank * self.num_heads
...
@@ -67,11 +67,17 @@ class GPTJAttention(nn.Module):
         scaling = self.head_size**-0.5
         assert getattr(config, "rotary", True)
         assert config.rotary_dim % 2 == 0
-        self.attn = PagedAttentionWithRoPE(self.num_heads,
-                                           self.head_size,
-                                           scaling,
-                                           config.rotary_dim,
-                                           is_neox_style=False)
+        rope_theta = getattr(config, "rope_theta", 10000)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
+        self.attn = PagedAttentionWithRoPE(
+            self.num_heads,
+            self.head_size,
+            scaling,
+            config.rotary_dim,
+            base=rope_theta,
+            max_position=max_position_embeddings,
+            is_neox_style=False)
         self.warmup = False

     def forward(
...
@@ -68,8 +68,16 @@ class GPTNeoXAttention(nn.Module):
         scaling = self.head_size**-0.5
         rotary_dim = int(self.head_size * config.rotary_pct)
         assert rotary_dim % 2 == 0
-        self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
-                                           scaling, rotary_dim)
+        rope_theta = getattr(config, "rope_theta", 10000)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
+        self.attn = PagedAttentionWithRoPE(
+            self.num_heads,
+            self.head_size,
+            scaling,
+            rotary_dim,
+            base=rope_theta,
+            max_position=max_position_embeddings)

     def forward(
         self,
...
@@ -59,6 +59,8 @@ class InternLMAttention(nn.Module):
         self,
         hidden_size: int,
         num_heads: int,
+        rope_theta: float = 10000,
+        max_position_embeddings: int = 8192,
     ):
         super().__init__()
         self.hidden_size = hidden_size
@@ -70,6 +72,8 @@ class InternLMAttention(nn.Module):
                                tensor_model_parallel_world_size)
         self.head_dim = hidden_size // self.total_num_heads
         self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings

         self.qkv_proj = ColumnParallelLinear(
             hidden_size,
@@ -85,10 +89,13 @@ class InternLMAttention(nn.Module):
             input_is_parallel=True,
             perform_initialization=False,
         )
-        self.attn = PagedAttentionWithRoPE(self.num_heads,
-                                           self.head_dim,
-                                           self.scaling,
-                                           rotary_dim=self.head_dim)
+        self.attn = PagedAttentionWithRoPE(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            base=self.rope_theta,
+            max_position=self.max_position_embeddings,
+            rotary_dim=self.head_dim)

     def forward(
         self,
@@ -112,9 +119,14 @@ class InternLMDecoderLayer(nn.Module):
     def __init__(self, config: LlamaConfig):
         super().__init__()
         self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
         self.self_attn = InternLMAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
+            rope_theta=rope_theta,
+            max_position_embeddings=max_position_embeddings,
         )
         self.mlp = InternLMMLP(
             hidden_size=self.hidden_size,
...
@@ -92,6 +92,7 @@ class LlamaAttention(nn.Module):
         num_heads: int,
         num_kv_heads: int,
         rope_theta: float = 10000,
+        max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -108,6 +109,7 @@ class LlamaAttention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings

         self.qkv_proj = ParallelLinear.column(
             hidden_size,
@@ -126,12 +128,14 @@ class LlamaAttention(nn.Module):
             perform_initialization=False,
             quant_config=quant_config,
         )
-        self.attn = PagedAttentionWithRoPE(self.num_heads,
-                                           self.head_dim,
-                                           self.scaling,
-                                           base=self.rope_theta,
-                                           rotary_dim=self.head_dim,
-                                           num_kv_heads=self.num_kv_heads)
+        self.attn = PagedAttentionWithRoPE(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            base=self.rope_theta,
+            max_position=self.max_position_embeddings,
+            rotary_dim=self.head_dim,
+            num_kv_heads=self.num_kv_heads)

     def forward(
         self,
@@ -161,11 +165,14 @@ class LlamaDecoderLayer(nn.Module):
         self.hidden_size = config.hidden_size
         # Requires transformers > 4.32.0
         rope_theta = getattr(config, "rope_theta", 10000)
+        max_position_embeddings = getattr(config, "max_position_embeddings",
+                                          8192)
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
             num_kv_heads=config.num_key_value_heads,
             rope_theta=rope_theta,
+            max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
         )
         self.mlp = LlamaMLP(
...
@@ -76,8 +76,13 @@ class QWenMLP(nn.Module):

 class QWenAttention(nn.Module):

-    def __init__(self, hidden_size: int, num_heads: int,
-                 max_position_embeddings: int):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        max_position_embeddings: int,
+        rope_theta: float = 10000,
+    ):
         super().__init__()
         self.hidden_size = hidden_size
         tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
@@ -109,6 +114,7 @@ class QWenAttention(nn.Module):
             self.head_dim,
             self.scaling,
             rotary_dim=self.head_dim,
+            base=rope_theta,
             max_position=max_position_embeddings,
         )
@@ -137,8 +143,11 @@ class QWenBlock(nn.Module):
         super().__init__()
         self.ln_1 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
-        self.attn = QWenAttention(config.n_embd, config.num_attention_heads,
-                                  config.max_position_embeddings)
+        rope_theta = getattr(config, "rope_theta", 10000)
+        self.attn = QWenAttention(config.n_embd,
+                                  config.num_attention_heads,
+                                  config.max_position_embeddings,
+                                  rope_theta=rope_theta)
         self.ln_2 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
...
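For context on why these two values are plumbed through to the attention layers: `base` (rope_theta) sets the rotary frequency base, and `max_position` bounds how many positions the rotary embedding cache must cover. A generic sketch of the standard RoPE cache computation, assuming the usual formulation (this is not vLLM's kernel code; `build_rope_cache` is a hypothetical helper):

```python
# Standard RoPE precomputation: inverse frequencies from `base`, one
# (cos, sin) row per position up to `max_position`.
import math

def build_rope_cache(head_dim: int, max_position: int, base: float = 10000.0):
    # Inverse frequency for each dimension pair.
    inv_freq = [base ** (-2 * i / head_dim) for i in range(head_dim // 2)]
    cos_table, sin_table = [], []
    for pos in range(max_position):
        angles = [pos * f for f in inv_freq]
        cos_table.append([math.cos(a) for a in angles])
        sin_table.append([math.sin(a) for a in angles])
    return cos_table, sin_table

# A larger rope_theta stretches the rotation wavelengths; max_position
# determines how far the precomputed table extends.
cos_t, sin_t = build_rope_cache(head_dim=128, max_position=8192, base=10000.0)
```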