Unverified Commit 34d94094 authored by Joao Gante, committed by GitHub

Llama/GPTNeoX: add RoPE scaling (#24653)



* add rope_scaling

* tmp commit

* add gptneox

* add tests

* GPTNeoX can now handle long inputs, so the pipeline test was wrong

* Update src/transformers/models/open_llama/configuration_open_llama.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* remove ntk

* remove redundant validation

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 9342c8fb
@@ -78,6 +78,15 @@ class GPTNeoXConfig(PretrainedConfig):
use_parallel_residual (`bool`, *optional*, defaults to `True`):
Whether to use a "parallel" formulation in each Transformer layer, which can provide a slight training
speedup at large scales (e.g. 20B).
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format
is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
Example:
```python
@@ -115,6 +124,7 @@ class GPTNeoXConfig(PretrainedConfig):
eos_token_id=2,
tie_word_embeddings=False,
use_parallel_residual=True,
rope_scaling=None,
**kwargs,
):
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -135,7 +145,32 @@ class GPTNeoXConfig(PretrainedConfig):
self.use_cache = use_cache
self.tie_word_embeddings = tie_word_embeddings
self.use_parallel_residual = use_parallel_residual
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
if self.hidden_size % self.num_attention_heads != 0:
raise ValueError(
"The hidden size is not divisible by the number of attention heads! Make sure to update them!"
)
# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
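For reference (not part of this diff), a minimal sketch of how the new argument interacts with this validation; the factor values below are illustrative:

```python
from transformers import GPTNeoXConfig

# Linear RoPE scaling: positions are divided by `factor`, so this config is
# meant to be used with sequences up to roughly 2x the original maximum.
config = GPTNeoXConfig(rope_scaling={"type": "linear", "factor": 2.0})

# The validation above rejects unknown strategy names and non-float factors <= 1.
try:
    GPTNeoXConfig(rope_scaling={"type": "ntk", "factor": 2.0})
except ValueError as err:
    print(err)  # type field must be one of ['linear', 'dynamic'], ...
```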
@@ -86,6 +86,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
class GPTNeoXAttention(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.num_attention_heads = config.num_attention_heads
self.hidden_size = config.hidden_size
if self.hidden_size % self.num_attention_heads != 0:
@@ -94,18 +95,11 @@ class GPTNeoXAttention(nn.Module):
)
self.head_size = self.hidden_size // self.num_attention_heads
self.rotary_ndims = int(self.head_size * config.rotary_pct)
self._init_bias(config.max_position_embeddings)
self.register_buffer("masked_bias", torch.tensor(-1e9), persistent=False)
self._init_rope()
self.register_buffer(
"norm_factor",
torch.sqrt(torch.tensor(self.head_size, dtype=torch.float32)).to(torch.get_default_dtype()),
@@ -113,9 +107,44 @@ class GPTNeoXAttention(nn.Module):
)
self.query_key_value = nn.Linear(config.hidden_size, 3 * config.hidden_size)
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.attention_dropout = nn.Dropout(config.attention_dropout)
def _init_bias(self, max_positions, device=None):
self.register_buffer(
"bias",
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
persistent=False,
)
if device is not None:
self.bias = self.bias.to(device)
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = GPTNeoXRotaryEmbedding(
self.rotary_ndims, self.config.max_position_embeddings, base=self.config.rotary_emb_base
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = GPTNeoXLinearScalingRotaryEmbedding(
self.rotary_ndims,
self.config.max_position_embeddings,
base=self.config.rotary_emb_base,
scaling_factor=scaling_factor,
)
elif scaling_type == "dynamic":
self.rotary_emb = GPTNeoXDynamicNTKScalingRotaryEmbedding(
self.rotary_ndims,
self.config.max_position_embeddings,
base=self.config.rotary_emb_base,
scaling_factor=scaling_factor,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def forward(
self,
hidden_states: torch.FloatTensor,
@@ -210,6 +239,9 @@ class GPTNeoXAttention(nn.Module):
batch_size, num_attention_heads, query_length, attn_head_size = query.size()
key_length = key.size(-2)
# dynamically increase the causal mask with the key length, if needed.
if key_length > self.bias.shape[-1]:
self._init_bias(key_length, device=key.device)
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
query = query.view(batch_size * num_attention_heads, query_length, attn_head_size)
@@ -258,15 +290,23 @@ def attention_mask_func(attention_scores, ltor_mask):
return attention_scores
class GPTNeoXRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device)
def _set_cos_sin_cache(self, seq_len, device):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
@@ -275,16 +315,54 @@ class RotaryEmbedding(torch.nn.Module):
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device)
return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)
class GPTNeoXLinearScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
"""GPTNeoXRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
def __init__(self, dim, max_position_embeddings, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = t / self.scaling_factor
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = emb.cos()[None, None, :, :]
self.sin_cached = emb.sin()[None, None, :, :]
class GPTNeoXDynamicNTKScalingRotaryEmbedding(GPTNeoXRotaryEmbedding):
"""GPTNeoXRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
def __init__(self, dim, max_position_embeddings, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device):
self.max_seq_len_cached = seq_len
if seq_len > self.max_position_embeddings:
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.cos_cached = emb.cos()[None, None, :, :]
self.sin_cached = emb.sin()[None, None, :, :]
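For intuition, a small self-contained sketch (not part of this commit) of the base rescaling rule implemented in `_set_cos_sin_cache` above; the dimensions and sequence lengths are illustrative:

```python
# Dynamic NTK scaling only kicks in once the sequence grows past the original
# maximum; it then enlarges the rotary base, lowering the rotation frequencies.
dim, base, max_pos, factor = 64, 10000.0, 2048, 2.0

def scaled_base(seq_len):
    if seq_len <= max_pos:
        return base  # cache built from the unmodified base, identical to vanilla RoPE
    return base * ((factor * seq_len / max_pos) - (factor - 1)) ** (dim / (dim - 2))

print(scaled_base(2048))  # 10000.0
print(scaled_base(4096))  # ~31000 -> longer wavelengths for the same head dimension
```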
def rotate_half(x):
...
@@ -238,16 +238,24 @@ class GPTNeoXJapaneseAttention(nn.Module):
return attn_output, attn_weights
# Copied from transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXRotaryEmbedding
class RotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(seq_len=max_position_embeddings, device=self.inv_freq.device)
def _set_cos_sin_cache(self, seq_len, device):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
@@ -256,15 +264,8 @@ class RotaryEmbedding(torch.nn.Module):
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device)
return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)
...
@@ -64,6 +64,15 @@ class LlamaConfig(PretrainedConfig):
relevant if `config.is_decoder=True`.
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format
is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
Example:
```python
@@ -97,6 +106,7 @@ class LlamaConfig(PretrainedConfig):
bos_token_id=1,
eos_token_id=2,
tie_word_embeddings=False,
rope_scaling=None,
**kwargs,
):
self.vocab_size = vocab_size
@@ -109,6 +119,9 @@ class LlamaConfig(PretrainedConfig):
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
@@ -116,3 +129,24 @@ class LlamaConfig(PretrainedConfig):
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
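A hedged end-to-end sketch (not part of this diff) of passing the new option through `LlamaConfig`; the small dimensions and the scaling factor are illustrative and chosen only to keep the randomly initialised model light, mirroring the tests added below:

```python
from transformers import LlamaConfig, LlamaModel

config = LlamaConfig(
    vocab_size=1000,
    hidden_size=128,
    intermediate_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    max_position_embeddings=512,
    rope_scaling={"type": "dynamic", "factor": 4.0},  # illustrative factor
)
# Random weights; the rotary embedding variant is selected in LlamaAttention._init_rope
model = LlamaModel(config)
```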
@@ -91,36 +91,84 @@ class LlamaRMSNorm(nn.Module):
class LlamaRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = t / self.scaling_factor
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
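The `t = t / self.scaling_factor` line above is the whole trick behind linear scaling: position `m` is rotated as if it were position `m / factor`, so out-of-range indices are compressed back into the range the model was trained on. A tiny illustration with assumed values (not from the diff):

```python
import torch

# With factor 2.0, positions up to ~4096 map back into the trained range [0, 2048).
scaling_factor, trained_max = 2.0, 2048
positions = torch.tensor([0.0, 1024.0, 2048.0, 4095.0])
effective = positions / scaling_factor
print(effective)  # tensor([0., 512., 1024., 2047.5]) -- all inside [0, trained_max)
```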
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
if seq_len > self.max_position_embeddings:
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
@@ -176,7 +224,24 @@ class LlamaAttention(nn.Module):
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self._init_rope()
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
)
elif scaling_type == "dynamic":
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
...
@@ -67,6 +67,15 @@ class OpenLlamaConfig(PretrainedConfig):
relevant if `config.is_decoder=True`.
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format
is `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum. See the following thread for more information on how
these scaling strategies behave:
https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
experimental feature, subject to breaking API changes in future versions.
Example:
```python
@@ -104,6 +113,7 @@ class OpenLlamaConfig(PretrainedConfig):
attention_dropout_prob=0.1,
use_stable_embedding=True,
shared_input_output_embedding=True,
rope_scaling=None,
**kwargs,
):
self.vocab_size = vocab_size
@@ -123,6 +133,9 @@ class OpenLlamaConfig(PretrainedConfig):
self.attention_dropout_prob = attention_dropout_prob
self.use_stable_embedding = use_stable_embedding
self.shared_input_output_embedding = shared_input_output_embedding
self.rope_scaling = rope_scaling
self._rope_scaling_validation()
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
@@ -130,3 +143,25 @@ class OpenLlamaConfig(PretrainedConfig):
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
# Copied from transformers.models.llama.configuration_llama.LlamaConfig._rope_scaling_validation
def _rope_scaling_validation(self):
"""
Validate the `rope_scaling` configuration.
"""
if self.rope_scaling is None:
return
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
raise ValueError(
"`rope_scaling` must be a dictionary with two fields, `type` and `factor`, "
f"got {self.rope_scaling}"
)
rope_scaling_type = self.rope_scaling.get("type", None)
rope_scaling_factor = self.rope_scaling.get("factor", None)
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
raise ValueError(
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
)
if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
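Since the validation is copied verbatim from `LlamaConfig`, Open-Llama behaves the same way. A brief sketch (assuming `OpenLlamaConfig` is still exposed at the top level of `transformers`):

```python
from transformers import OpenLlamaConfig

# An integer factor (or any factor <= 1.0) is rejected by the copied validation.
try:
    OpenLlamaConfig(rope_scaling={"type": "linear", "factor": 2})  # int, not float
except ValueError as err:
    print(err)  # factor field must be a float > 1, got 2
```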
@@ -102,36 +102,86 @@ class OpenLlamaRMSNorm(nn.Module):
class OpenLlamaRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)
# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->OpenLlama
class OpenLlamaLinearScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
"""OpenLlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = t / self.scaling_factor
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->OpenLlama
class OpenLlamaDynamicNTKScalingRotaryEmbedding(OpenLlamaRotaryEmbedding):
"""OpenLlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
if seq_len > self.max_position_embeddings:
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
@@ -190,7 +240,27 @@ class OpenLlamaAttention(nn.Module):
self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self._init_rope()
# Copied from transformers.models.llama.modeling_llama.LlamaAttention._init_rope with Llama->OpenLlama
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = OpenLlamaRotaryEmbedding(
self.head_dim, max_position_embeddings=self.max_position_embeddings
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = OpenLlamaLinearScalingRotaryEmbedding(
self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
)
elif scaling_type == "dynamic":
self.rotary_emb = OpenLlamaDynamicNTKScalingRotaryEmbedding(
self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
...
@@ -17,7 +17,9 @@
import unittest
from parameterized import parameterized
from transformers import AutoTokenizer, GPTNeoXConfig, is_torch_available, set_seed
from transformers.testing_utils import require_torch, slow, torch_device
from ...generation.test_utils import GenerationTesterMixin
@@ -49,7 +51,7 @@ class GPTNeoXModelTester:
use_token_type_ids=True,
use_labels=True,
vocab_size=99,
hidden_size=64,
num_hidden_layers=5,
num_attention_heads=4,
intermediate_size=37,
@@ -298,6 +300,37 @@ class GPTNeoXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
def test_feed_forward_chunking(self):
pass
@parameterized.expand([("linear",), ("dynamic",)])
def test_model_rope_scaling(self, scaling_type):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
short_input = ids_tensor([1, 10], config.vocab_size)
long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
set_seed(42) # Fixed seed at init time so the two models get the same random weights
original_model = GPTNeoXModel(config)
original_model.to(torch_device)
original_model.eval()
original_short_output = original_model(short_input).last_hidden_state
original_long_output = original_model(long_input).last_hidden_state
set_seed(42) # Fixed seed at init time so the two models get the same random weights
config.rope_scaling = {"type": scaling_type, "factor": 10.0}
scaled_model = GPTNeoXModel(config)
scaled_model.to(torch_device)
scaled_model.eval()
scaled_short_output = scaled_model(short_input).last_hidden_state
scaled_long_output = scaled_model(long_input).last_hidden_state
# Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
# maximum sequence length, so the outputs for the short input should match.
if scaling_type == "dynamic":
self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
else:
self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
@require_torch
class GPTNeoXLanguageGenerationTest(unittest.TestCase):
...
@@ -17,7 +17,9 @@
import unittest
from parameterized import parameterized
from transformers import LlamaConfig, is_torch_available, set_seed
from transformers.testing_utils import require_torch, torch_device
from ...generation.test_utils import GenerationTesterMixin
@@ -332,3 +334,34 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
@unittest.skip("LLaMA buffers include complex numbers, which breaks this test")
def test_save_load_fast_init_from_base(self):
pass
@parameterized.expand([("linear",), ("dynamic",)])
def test_model_rope_scaling(self, scaling_type):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
short_input = ids_tensor([1, 10], config.vocab_size)
long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
set_seed(42) # Fixed seed at init time so the two models get the same random weights
original_model = LlamaModel(config)
original_model.to(torch_device)
original_model.eval()
original_short_output = original_model(short_input).last_hidden_state
original_long_output = original_model(long_input).last_hidden_state
set_seed(42) # Fixed seed at init time so the two models get the same random weights
config.rope_scaling = {"type": scaling_type, "factor": 10.0}
scaled_model = LlamaModel(config)
scaled_model.to(torch_device)
scaled_model.eval()
scaled_short_output = scaled_model(short_input).last_hidden_state
scaled_long_output = scaled_model(long_input).last_hidden_state
# Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
# maximum sequence length, so the outputs for the short input should match.
if scaling_type == "dynamic":
self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
else:
self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
@@ -17,7 +17,9 @@
import unittest
from parameterized import parameterized
from transformers import OpenLlamaConfig, is_torch_available, set_seed
from transformers.testing_utils import require_torch, torch_device
from ...generation.test_utils import GenerationTesterMixin
@@ -335,3 +337,34 @@ class OpenLlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
@unittest.skip("Open-Llama buffers include complex numbers, which breaks this test")
def test_save_load_fast_init_from_base(self):
pass
@parameterized.expand([("linear",), ("dynamic",)])
def test_model_rope_scaling(self, scaling_type):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
short_input = ids_tensor([1, 10], config.vocab_size)
long_input = ids_tensor([1, int(config.max_position_embeddings * 1.5)], config.vocab_size)
set_seed(42) # Fixed seed at init time so the two models get the same random weights
original_model = OpenLlamaModel(config)
original_model.to(torch_device)
original_model.eval()
original_short_output = original_model(short_input).last_hidden_state
original_long_output = original_model(long_input).last_hidden_state
set_seed(42) # Fixed seed at init time so the two models get the same random weights
config.rope_scaling = {"type": scaling_type, "factor": 10.0}
scaled_model = OpenLlamaModel(config)
scaled_model.to(torch_device)
scaled_model.eval()
scaled_short_output = scaled_model(short_input).last_hidden_state
scaled_long_output = scaled_model(long_input).last_hidden_state
# Dynamic scaling does not change the RoPE embeddings until it receives an input longer than the original
# maximum sequence length, so the outputs for the short input should match.
if scaling_type == "dynamic":
self.assertTrue(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
else:
self.assertFalse(torch.allclose(original_short_output, scaled_short_output, atol=1e-5))
# The output should be different for long inputs
self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
@@ -240,7 +240,7 @@ class TextGenerationPipelineTests(unittest.TestCase):
# We don't care about infinite range models.
# They already work.
# Skip this test for XGLM, since it uses sinusoidal positional embeddings which are resized on-the-fly.
EXTRA_MODELS_CAN_HANDLE_LONG_INPUTS = ["RwkvForCausalLM", "XGLMForCausalLM", "GPTNeoXForCausalLM"]
if (
tokenizer.model_max_length < 10000
and text_generator.model.__class__.__name__ not in EXTRA_MODELS_CAN_HANDLE_LONG_INPUTS
...