"vscode:/vscode.git/clone" did not exist on "2db4469808158700036de79bd41a9c463bb89bdc"
Unverified commit 3302f0ae authored by Antoni Baum, committed by GitHub

rope_theta and max_position_embeddings from config (#1096)


Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: wnma3mz <wnma3mz@gmail.com>
parent 6f2dd6c3
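
The model-side changes below all follow one pattern: read `rope_theta` and `max_position_embeddings` from the Hugging Face config with `getattr` fallbacks (10000 and 8192) and forward them to the attention layer's rotary embedding as `base` and `max_position`. A minimal, self-contained sketch of that pattern is shown here; `ToyConfig` and `ToyAttention` are illustrative stand-ins, not vLLM classes.

```python
from dataclasses import dataclass

@dataclass
class ToyConfig:
    # Stand-in for a Hugging Face model config; newer configs carry
    # rope_theta, older ones may omit it entirely.
    hidden_size: int = 4096
    num_attention_heads: int = 32
    rope_theta: float = 1_000_000.0

class ToyAttention:
    # Stand-in for the per-model attention classes edited in this commit
    # (AquilaAttention, BaiChuanAttention, LlamaAttention, ...).
    def __init__(self, hidden_size: int, num_heads: int,
                 rope_theta: float = 10000,
                 max_position_embeddings: int = 8192):
        # In vLLM these are forwarded to PagedAttentionWithRoPE as
        # base=rope_theta and max_position=max_position_embeddings.
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

config = ToyConfig()
# The getattr fallbacks keep older configs (without rope_theta or
# max_position_embeddings) working with the historical defaults.
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
attn = ToyAttention(config.hidden_size, config.num_attention_heads,
                    rope_theta=rope_theta,
                    max_position_embeddings=max_position_embeddings)
print(attn.rope_theta, attn.max_position_embeddings)  # 1000000.0 8192
```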
@@ -57,7 +57,7 @@ class ModelConfig:
load_format: str,
dtype: str,
seed: int,
revision: Optional[str],
revision: Optional[str] = None,
max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
) -> None:
@@ -73,19 +73,11 @@ class ModelConfig:
self.hf_config = get_config(model, trust_remote_code, revision)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self.max_model_len = _get_and_verify_max_len(self.hf_config,
max_model_len)
self._verify_load_format()
self._verify_tokenizer_mode()
self._verify_quantization()
self.max_model_len = None
if max_model_len is not None:
derived_max_model_len = self.get_max_model_len()
if max_model_len > derived_max_model_len:
logger.warning(
f"User-specified max_model_len ({max_model_len}) is "
f"greater than the derived max_model_len "
f"({derived_max_model_len}). Make sure the value is "
"correct and within the model context size.")
self.max_model_len = max_model_len
def _verify_load_format(self) -> None:
load_format = self.load_format.lower()
@@ -168,26 +160,7 @@ class ModelConfig:
return total_num_attention_heads // parallel_config.tensor_parallel_size
def get_max_model_len(self) -> int:
if self.max_model_len is not None:
return self.max_model_len
max_model_len = float("inf")
possible_keys = [
# OPT
"max_position_embeddings",
# GPT-2
"n_positions",
# MPT
"max_seq_len",
# Others
"max_sequence_length",
"max_seq_length",
"seq_len",
]
for key in possible_keys:
max_len_key = getattr(self.hf_config, key, None)
if max_len_key is not None:
max_model_len = min(max_model_len, max_len_key)
return max_model_len
return self.max_model_len
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
total_num_hidden_layers = self.hf_config.num_hidden_layers
@@ -348,3 +321,38 @@ def _get_and_verify_dtype(
f"of at least 8.0. Your {gpu_name} GPU has compute capability "
f"{compute_capability[0]}.{compute_capability[1]}.")
return torch_dtype
def _get_and_verify_max_len(
hf_config: PretrainedConfig,
max_model_len: Optional[int],
) -> int:
"""Get and verify the model's maximum length."""
derived_max_model_len = float("inf")
possible_keys = [
# OPT
"max_position_embeddings",
# GPT-2
"n_positions",
# MPT
"max_seq_len",
# Others
"max_sequence_length",
"max_seq_length",
"seq_len",
]
for key in possible_keys:
max_len_key = getattr(hf_config, key, None)
if max_len_key is not None:
derived_max_model_len = min(derived_max_model_len, max_len_key)
if max_model_len is None:
max_model_len = derived_max_model_len
elif max_model_len > derived_max_model_len:
raise ValueError(
f"User-specified max_model_len ({max_model_len}) is greater than "
f"the derived max_model_len ({max_len_key}={derived_max_model_len}"
" in model's config.json). This may lead to incorrect model "
"outputs or CUDA errors. Make sure the value is correct and "
"within the model context size.")
return max_model_len
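
For reference, here is how the new `_get_and_verify_max_len` resolves the final value. This is only an illustration: the `SimpleNamespace` object stands in for a real Hugging Face `PretrainedConfig` (the function reads attributes purely via `getattr`), and it assumes the function above is in scope.

```python
from types import SimpleNamespace

# Toy config standing in for a PretrainedConfig with a 4096-token context.
cfg = SimpleNamespace(max_position_embeddings=4096)

# No user override: the length derived from the config is returned.
assert _get_and_verify_max_len(cfg, max_model_len=None) == 4096
# A user override below the derived limit is accepted as-is.
assert _get_and_verify_max_len(cfg, max_model_len=2048) == 2048
# An override above the derived limit now raises instead of only warning.
try:
    _get_and_verify_max_len(cfg, max_model_len=8192)
except ValueError as err:
    print(err)
```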
@@ -105,6 +105,8 @@ class AquilaAttention(nn.Module):
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
max_position_embeddings: int = 8192,
):
super().__init__()
self.hidden_size = hidden_size
@@ -119,6 +121,8 @@ class AquilaAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = ColumnParallelLinear(
hidden_size,
@@ -140,6 +144,8 @@ class AquilaAttention(nn.Module):
self.head_dim,
self.scaling,
rotary_dim=self.head_dim,
base=self.rope_theta,
max_position=self.max_position_embeddings,
)
def forward(
@@ -164,10 +170,15 @@ class AquilaDecoderLayer(nn.Module):
def __init__(self, config: AquilaConfig):
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = AquilaAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_attention_heads,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
)
self.mlp = AquilaMLP(
hidden_size=self.hidden_size,
......
@@ -111,6 +111,8 @@ class BaiChuanAttention(nn.Module):
hidden_size: int,
num_heads: int,
position_embedding: str,
rope_theta: float = 10000,
max_position_embeddings: int = 8192,
):
super().__init__()
self.hidden_size = hidden_size
@@ -122,6 +124,8 @@ class BaiChuanAttention(nn.Module):
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads
self.postion_embedding = position_embedding
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
# pylint: disable=invalid-name
self.W_pack = ColumnParallelLinear(
@@ -151,10 +155,13 @@ class BaiChuanAttention(nn.Module):
scaling, alibi_slopes)
else:
self.scaling = self.head_dim**-0.5
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim,
self.scaling,
rotary_dim=self.head_dim)
self.attn = PagedAttentionWithRoPE(
self.num_heads,
self.head_dim,
self.scaling,
rotary_dim=self.head_dim,
base=self.rope_theta,
max_position=self.max_position_embeddings)
def forward(
self,
@@ -183,10 +190,15 @@ class BaiChuanDecoderLayer(nn.Module):
def __init__(self, config: BaiChuanConfig, position_embedding: str):
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = BaiChuanAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
position_embedding=position_embedding,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
)
self.mlp = BaiChuanMLP(
hidden_size=self.hidden_size,
......
@@ -161,12 +161,17 @@ class FalconAttention(nn.Module):
"Rotary and alibi are mutually exclusive.")
if self.use_rotary:
# TODO(zhuohan): Pass in correct `max_position``
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim,
self.inv_norm_factor,
rotary_dim=self.head_dim,
num_kv_heads=self.num_kv_heads)
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config,
"max_position_embeddings", 8192)
self.attn = PagedAttentionWithRoPE(
self.num_heads,
self.head_dim,
self.inv_norm_factor,
base=rope_theta,
max_position=max_position_embeddings,
rotary_dim=self.head_dim,
num_kv_heads=self.num_kv_heads)
elif self.use_alibi:
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
......
@@ -67,11 +67,17 @@ class GPTJAttention(nn.Module):
scaling = self.head_size**-0.5
assert getattr(config, "rotary", True)
assert config.rotary_dim % 2 == 0
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_size,
scaling,
config.rotary_dim,
is_neox_style=False)
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.attn = PagedAttentionWithRoPE(
self.num_heads,
self.head_size,
scaling,
config.rotary_dim,
base=rope_theta,
max_position=max_position_embeddings,
is_neox_style=False)
self.warmup = False
def forward(
......
@@ -68,8 +68,16 @@ class GPTNeoXAttention(nn.Module):
scaling = self.head_size**-0.5
rotary_dim = int(self.head_size * config.rotary_pct)
assert rotary_dim % 2 == 0
self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
scaling, rotary_dim)
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.attn = PagedAttentionWithRoPE(
self.num_heads,
self.head_size,
scaling,
rotary_dim,
base=rope_theta,
max_position=max_position_embeddings)
def forward(
self,
......
@@ -59,6 +59,8 @@ class InternLMAttention(nn.Module):
self,
hidden_size: int,
num_heads: int,
rope_theta: float = 10000,
max_position_embeddings: int = 8192,
):
super().__init__()
self.hidden_size = hidden_size
@@ -70,6 +72,8 @@ class InternLMAttention(nn.Module):
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = ColumnParallelLinear(
hidden_size,
@@ -85,10 +89,13 @@ class InternLMAttention(nn.Module):
input_is_parallel=True,
perform_initialization=False,
)
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim,
self.scaling,
rotary_dim=self.head_dim)
self.attn = PagedAttentionWithRoPE(
self.num_heads,
self.head_dim,
self.scaling,
base=self.rope_theta,
max_position=self.max_position_embeddings,
rotary_dim=self.head_dim)
def forward(
self,
@@ -112,9 +119,14 @@ class InternLMDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = InternLMAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
)
self.mlp = InternLMMLP(
hidden_size=self.hidden_size,
......
@@ -92,6 +92,7 @@ class LlamaAttention(nn.Module):
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
@@ -108,6 +109,7 @@ class LlamaAttention(nn.Module):
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = ParallelLinear.column(
hidden_size,
@@ -126,12 +128,14 @@ class LlamaAttention(nn.Module):
perform_initialization=False,
quant_config=quant_config,
)
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim,
self.scaling,
base=self.rope_theta,
rotary_dim=self.head_dim,
num_kv_heads=self.num_kv_heads)
self.attn = PagedAttentionWithRoPE(
self.num_heads,
self.head_dim,
self.scaling,
base=self.rope_theta,
max_position=self.max_position_embeddings,
rotary_dim=self.head_dim,
num_kv_heads=self.num_kv_heads)
def forward(
self,
@@ -161,11 +165,14 @@ class LlamaDecoderLayer(nn.Module):
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = LlamaAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
)
self.mlp = LlamaMLP(
......
@@ -76,8 +76,13 @@ class QWenMLP(nn.Module):
class QWenAttention(nn.Module):
def __init__(self, hidden_size: int, num_heads: int,
max_position_embeddings: int):
def __init__(
self,
hidden_size: int,
num_heads: int,
max_position_embeddings: int,
rope_theta: float = 10000,
):
super().__init__()
self.hidden_size = hidden_size
tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
@@ -109,6 +114,7 @@ class QWenAttention(nn.Module):
self.head_dim,
self.scaling,
rotary_dim=self.head_dim,
base=rope_theta,
max_position=max_position_embeddings,
)
@@ -137,8 +143,11 @@ class QWenBlock(nn.Module):
super().__init__()
self.ln_1 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
self.attn = QWenAttention(config.n_embd, config.num_attention_heads,
config.max_position_embeddings)
rope_theta = getattr(config, "rope_theta", 10000)
self.attn = QWenAttention(config.n_embd,
config.num_attention_heads,
config.max_position_embeddings,
rope_theta=rope_theta)
self.ln_2 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
......
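
For context on why both values are threaded through: the rotary-embedding layer precomputes a cos/sin table whose length is bounded by `max_position` and whose rotation frequencies are set by `base` (the config's `rope_theta`). Below is a minimal, generic sketch of such a cache under standard RoPE assumptions; `build_rope_cache` is illustrative and is not vLLM's `PagedAttentionWithRoPE` implementation.

```python
import torch

def build_rope_cache(head_dim: int, max_position: int, base: float = 10000.0):
    # Geometrically spaced inverse frequencies, one per pair of channels.
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    positions = torch.arange(max_position, dtype=torch.float32)
    # Outer product gives the rotation angle for every (position, channel-pair).
    angles = torch.outer(positions, inv_freq)  # [max_position, head_dim // 2]
    return angles.cos(), angles.sin()

# A larger base (rope_theta) stretches the rotation wavelengths, which is how
# long-context model variants keep distant positions distinguishable, while
# max_position only caps how many positions the cache covers.
cos, sin = build_rope_cache(head_dim=128, max_position=8192, base=1_000_000.0)
print(cos.shape)  # torch.Size([8192, 64])
```

Passing these values explicitly, rather than relying on the hard-coded 10000 and 8192, is what lets models whose config.json declares different settings use them.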