Unverified commit 6d68030f, authored by Naveassaf, committed by GitHub
Browse files

[Model] Add support for YARN in NemotronNAS models (#18427)


Signed-off-by: Nave Assaf <nassaf@nvidia.com>
parent 5a2c76cb
...@@ -162,20 +162,9 @@ class LlamaAttention(nn.Module): ...@@ -162,20 +162,9 @@ class LlamaAttention(nn.Module):
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
) )
is_neox_style = True self._init_rotary_emb(config,
is_gguf = quant_config and quant_config.get_name() == "gguf"
if is_gguf and config.model_type == "llama":
is_neox_style = False
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
is_neox_style=is_neox_style, quant_config=quant_config)
partial_rotary_factor=self.partial_rotary_factor,
)
if hasattr(config, "interleaved_sliding_window"): if hasattr(config, "interleaved_sliding_window"):
interleaved_sliding_window = config.interleaved_sliding_window interleaved_sliding_window = config.interleaved_sliding_window
...@@ -214,6 +203,24 @@ class LlamaAttention(nn.Module): ...@@ -214,6 +203,24 @@ class LlamaAttention(nn.Module):
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
def _init_rotary_emb(self, config: LlamaConfig,
                     rope_scaling: Optional[dict[str, Any]],
                     quant_config: Optional[QuantizationConfig]) -> None:
    """Create the rotary embedding and store it on ``self.rotary_emb``.

    This lives in its own method so that subclasses can override how the
    rotary embedding is configured without re-implementing ``__init__``.

    Args:
        config: Model configuration; only ``model_type`` is read here.
        rope_scaling: Optional RoPE scaling settings, forwarded unchanged
            to ``get_rope``.
        quant_config: Optional quantization configuration; only its
            ``get_name()`` result is inspected.
    """
    # NeoX-style rotation is the default. It is switched off only for
    # GGUF-quantized checkpoints whose model type is exactly "llama".
    is_neox_style = True
    is_gguf = quant_config and quant_config.get_name() == "gguf"
    # Read the ``config`` argument rather than ``self.config``: the argument
    # is what the caller hands in, and it does not rely on the instance
    # having stored a ``config`` attribute beforehand.
    if is_gguf and config.model_type == "llama":
        is_neox_style = False

    self.rotary_emb = get_rope(
        self.head_dim,
        rotary_dim=self.head_dim,
        max_position=self.max_position_embeddings,
        base=self.rope_theta,
        rope_scaling=rope_scaling,
        is_neox_style=is_neox_style,
        partial_rotary_factor=self.partial_rotary_factor,
    )
class LlamaDecoderLayer(nn.Module): class LlamaDecoderLayer(nn.Module):
......
...@@ -23,18 +23,20 @@ ...@@ -23,18 +23,20 @@
# limitations under the License. # limitations under the License.
"""Inference-only deci model compatible with HuggingFace weights.""" """Inference-only deci model compatible with HuggingFace weights."""
from collections.abc import Iterable from collections.abc import Iterable
from typing import Optional, Union from typing import Any, Optional, Union
import torch import torch
from torch import nn from torch import nn
from transformers import LlamaConfig from transformers import LlamaConfig
from vllm.attention import AttentionType
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
...@@ -62,6 +64,48 @@ def _find_multiple(n: int, k: int) -> int: ...@@ -62,6 +64,48 @@ def _find_multiple(n: int, k: int) -> int:
return n + k - (n % k) return n + k - (n % k)
class DeciLMAttention(LlamaAttention):
    """Attention layer that reuses ``LlamaAttention`` with its own rotary setup.

    Construction is delegated entirely to the parent class; only
    ``_init_rotary_emb`` differs, choosing the rotary layout from
    ``config.position_embedding_type`` when that attribute is present.
    """

    def __init__(
        self,
        config: LlamaConfig,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        bias: bool = False,
        bias_o_proj: bool = False,
        cache_config: Optional[CacheConfig] = None,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
    ) -> None:
        """Forward every argument, positionally and in order, to the parent."""
        super().__init__(
            config,
            hidden_size,
            num_heads,
            num_kv_heads,
            rope_theta,
            rope_scaling,
            max_position_embeddings,
            quant_config,
            bias,
            bias_o_proj,
            cache_config,
            prefix,
            attn_type,
        )

    def _init_rotary_emb(self, config, rope_scaling: Optional[dict[str, Any]],
                         quant_config: Optional[QuantizationConfig]) -> None:
        """Create the rotary embedding and store it on ``self.rotary_emb``.

        NeoX-style rotation is used unless ``config.position_embedding_type``
        is one of the listed non-NeoX types. ``quant_config`` is accepted
        for signature compatibility but is not consulted here.
        """
        # Enables YARN for Mistral and LLaMA4 derivatives.
        non_neox_types = ("mistral_yarn", "rope_llama4")
        # A config lacking the attribute yields None, which is never in the
        # tuple, so the layout stays NeoX-style in that case.
        embedding_type = getattr(config, "position_embedding_type", None)
        use_neox_style = embedding_type not in non_neox_types

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=use_neox_style,
            partial_rotary_factor=self.partial_rotary_factor,
        )
class DeciLMDecoderLayer(nn.Module): class DeciLMDecoderLayer(nn.Module):
def __init__( def __init__(
...@@ -98,7 +142,7 @@ class DeciLMDecoderLayer(nn.Module): ...@@ -98,7 +142,7 @@ class DeciLMDecoderLayer(nn.Module):
if not self._is_no_op_attention: if not self._is_no_op_attention:
num_kv_heads = (config.num_attention_heads // num_kv_heads = (config.num_attention_heads //
block_config.attention.n_heads_in_group) block_config.attention.n_heads_in_group)
self.self_attn = LlamaAttention( self.self_attn = DeciLMAttention(
config=config, config=config,
hidden_size=self.hidden_size, hidden_size=self.hidden_size,
num_heads=config.num_attention_heads, num_heads=config.num_attention_heads,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment