Commit 441cca77 authored by Chen Xuechen Li, committed by Ying Sheng

Support GPT-J-style RoPE in the Llama model

parent c7709d3a
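For context, the new `rope_is_neox_style` flag selects between the two common rotary-embedding layouts: NeoX style rotates the two contiguous halves of the head dimension, while GPT-J style rotates interleaved (even, odd) pairs. The sketch below is only an illustration of that distinction using plain PyTorch, not the `get_rope` implementation the model actually calls; the helper names and the assumption that `cos`/`sin` are pre-broadcast to match the chosen layout are mine.

```python
# Minimal sketch (illustration only, not the vLLM/SGLang get_rope kernel).
import torch


def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    # NeoX style: split the head dim into two contiguous halves and
    # rotate them as a block: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    # GPT-J style: rotate interleaved (even, odd) pairs:
    # (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...).
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


def apply_rope(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool = True,
) -> torch.Tensor:
    # Assumes cos/sin are already laid out to match the chosen style
    # (repeated per half for NeoX, repeated per pair for GPT-J).
    rotate = rotate_neox if is_neox_style else rotate_gptj
    return x * cos + rotate(x) * sin
```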
# Adapted from
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
"""Inference-only LLaMA model compatible with HuggingFace weights."""
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
@@ -76,6 +77,7 @@ class LlamaAttention(nn.Module):
         layer_id: int = 0,
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
+        rope_is_neox_style: bool = True,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
@@ -123,6 +125,7 @@ class LlamaAttention(nn.Module):
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
+            is_neox_style=rope_is_neox_style,
         )
         self.attn = RadixAttention(
             self.num_heads,
@@ -160,9 +163,10 @@ class LlamaDecoderLayer(nn.Module):
         if rope_scaling is not None and getattr(
             config, "original_max_position_embeddings", None
         ):
-            rope_scaling[
-                "original_max_position_embeddings"
-            ] = config.original_max_position_embeddings
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
+        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.self_attn = LlamaAttention(
             hidden_size=self.hidden_size,
@@ -171,6 +175,7 @@ class LlamaDecoderLayer(nn.Module):
             layer_id=layer_id,
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
+            rope_is_neox_style=rope_is_neox_style,
             max_position_embeddings=max_position_embeddings,
             quant_config=quant_config,
         )