"vscode:/vscode.git/clone" did not exist on "4795e1634f8eb734e6f8b55f8f1840782d3bbe35"
Unverified Commit 78e5b22f authored by Yineng Zhang, committed by GitHub

feat: use get_rope for gemma2 (#2954)

parent 7a15e9ad
@@ -20,6 +20,7 @@ from typing import Iterable, Optional, Set, Tuple, Union
 import torch
 from torch import nn
 from transformers import PretrainedConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope

 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import GeluAndMul
@@ -48,19 +49,6 @@ def get_attention_sliding_window_size(config):
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
-
-
-class GemmaRotaryEmbedding(RotaryEmbedding):
-    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
-        # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
-        inv_freq = 1.0 / (
-            base
-            ** (
-                torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float()
-                / self.rotary_dim
-            )
-        )
-        return inv_freq


 class Gemma2MLP(nn.Module):
     def __init__(
         self,
@@ -143,14 +131,12 @@ class Gemma2Attention(nn.Module):
             bias=config.attention_bias,
             quant_config=quant_config,
         )
-        # from vLLM: TODO(woosuk): Use the `get_rope` interface.
-        self.rotary_emb = GemmaRotaryEmbedding(
-            self.head_dim,
+        self.rotary_emb = get_rope(
             self.head_dim,
-            max_position_embeddings,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
             base=self.rope_theta,
             is_neox_style=True,
-            dtype=torch.get_default_dtype(),
         )
         use_sliding_window = layer_id % 2 == 0 and hasattr(config, "sliding_window")
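Why this is a drop-in change: the removed GemmaRotaryEmbedding only overrode _compute_inv_freq with the standard RoPE inverse-frequency formula, inv_freq[i] = base ** (-2i / rotary_dim), which is the same quantity vLLM's default RotaryEmbedding (what get_rope returns when no rope scaling is configured) computes. A minimal sketch of that equivalence follows; the helper names and the head_dim / rope_theta values are illustrative assumptions, not part of the commit.

# Illustrative sketch only; these helper names are not from the commit.
import torch


def gemma_inv_freq(rotary_dim: int, base: float) -> torch.Tensor:
    # Formula from the removed GemmaRotaryEmbedding._compute_inv_freq.
    return 1.0 / (
        base
        ** (torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim)
    )


def standard_inv_freq(rotary_dim: int, base: float) -> torch.Tensor:
    # Standard RoPE inverse frequencies, base ** (-2i / rotary_dim), as computed
    # by vLLM's default RotaryEmbedding.
    return 1.0 / (
        base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim)
    )


if __name__ == "__main__":
    head_dim, rope_theta = 256, 10000.0  # assumed Gemma-2-like values
    assert torch.allclose(
        gemma_inv_freq(head_dim, rope_theta),
        standard_inv_freq(head_dim, rope_theta),
    )
    print("inv_freq values match")

Call sites should not need to change: the object returned by get_rope is applied the same way as before, e.g. q, k = self.rotary_emb(positions, q, k) in the attention forward pass.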