Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
norm
vllm
Commits
7bedab57
Unverified
Commit
7bedab57
authored
Sep 28, 2023
by
Qing
Committed by
GitHub
Sep 28, 2023
Browse files
Add rope_scaling to Qwen (#1210)
parent
20f7cc4c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
10 deletions
+11
-10
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+11
-10
No files found.
vllm/model_executor/models/qwen.py
View file @
7bedab57
...
@@ -8,7 +8,7 @@
...
@@ -8,7 +8,7 @@
The input of the model is flattened to a 1D tensor of tokens. The model uses
The input of the model is flattened to a 1D tensor of tokens. The model uses
InputMetadata to extract the original 2D shape of the input.
InputMetadata to extract the original 2D shape of the input.
"""
"""
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -76,13 +76,12 @@ class QWenMLP(nn.Module):
...
@@ -76,13 +76,12 @@ class QWenMLP(nn.Module):
class
QWenAttention
(
nn
.
Module
):
class
QWenAttention
(
nn
.
Module
):
def
__init__
(
def
__init__
(
self
,
self
,
hidden_size
:
int
,
hidden_size
:
int
,
num_heads
:
int
,
num_heads
:
int
,
max_position_embeddings
:
int
,
max_position_embeddings
:
int
,
rope_theta
:
float
=
10000
,
rope_theta
:
float
=
10000
,
rope_scaling
:
Optional
[
Dict
[
str
,
Any
]]
=
None
):
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
tensor_model_parallel_world_size
=
get_tensor_model_parallel_world_size
(
tensor_model_parallel_world_size
=
get_tensor_model_parallel_world_size
(
...
@@ -116,7 +115,7 @@ class QWenAttention(nn.Module):
...
@@ -116,7 +115,7 @@ class QWenAttention(nn.Module):
rotary_dim
=
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
base
=
rope_theta
,
base
=
rope_theta
,
max_position
=
max_position_embeddings
,
max_position
=
max_position_embeddings
,
)
rope_scaling
=
rope_scaling
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -144,10 +143,12 @@ class QWenBlock(nn.Module):
...
@@ -144,10 +143,12 @@ class QWenBlock(nn.Module):
self
.
ln_1
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_1
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_theta
=
getattr
(
config
,
"rope_theta"
,
10000
)
rope_scaling
=
getattr
(
config
,
"rope_scaling"
,
None
)
self
.
attn
=
QWenAttention
(
config
.
hidden_size
,
self
.
attn
=
QWenAttention
(
config
.
hidden_size
,
config
.
num_attention_heads
,
config
.
num_attention_heads
,
config
.
max_position_embeddings
,
config
.
max_position_embeddings
,
rope_theta
=
rope_theta
)
rope_theta
=
rope_theta
,
rope_scaling
=
rope_scaling
)
self
.
ln_2
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
self
.
ln_2
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
layer_norm_epsilon
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment