Unverified commit 7bedab57 authored by Qing, committed by GitHub

Add rope_scaling to Qwen (#1210)

parent 20f7cc4c
@@ -8,7 +8,7 @@
 The input of the model is flattened to a 1D tensor of tokens. The model uses
 InputMetadata to extract the original 2D shape of the input.
 """
-from typing import List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from torch import nn
@@ -76,13 +76,12 @@ class QWenMLP(nn.Module):
 
 class QWenAttention(nn.Module):
 
-    def __init__(
-        self,
+    def __init__(self,
         hidden_size: int,
         num_heads: int,
         max_position_embeddings: int,
         rope_theta: float = 10000,
-    ):
+        rope_scaling: Optional[Dict[str, Any]] = None):
         super().__init__()
         self.hidden_size = hidden_size
         tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
@@ -116,7 +115,7 @@ class QWenAttention(nn.Module):
             rotary_dim=self.head_dim,
             base=rope_theta,
             max_position=max_position_embeddings,
-        )
+            rope_scaling=rope_scaling)
 
     def forward(
         self,
@@ -144,10 +143,12 @@ class QWenBlock(nn.Module):
         self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
 
         rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
         self.attn = QWenAttention(config.hidden_size,
                                   config.num_attention_heads,
                                   config.max_position_embeddings,
-                                  rope_theta=rope_theta)
+                                  rope_theta=rope_theta,
+                                  rope_scaling=rope_scaling)
 
         self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)