Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
df845b2b
Unverified
Commit
df845b2b
authored
Aug 19, 2024
by
Woosuk Kwon
Committed by
GitHub
Aug 19, 2024
Browse files
[Misc] Remove Gemma RoPE (#7638)
parent
1a36287b
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
7 additions
and
26 deletions
+7
-26
vllm/model_executor/layers/rotary_embedding.py
vllm/model_executor/layers/rotary_embedding.py
+0
-15
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+3
-5
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+4
-6
No files found.
vllm/model_executor/layers/rotary_embedding.py
View file @
df845b2b
...
@@ -93,11 +93,6 @@ class RotaryEmbedding(CustomOp):
...
@@ -93,11 +93,6 @@ class RotaryEmbedding(CustomOp):
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
"""Compute the inverse frequency."""
"""Compute the inverse frequency."""
# NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`.
# However, we use `torch.arange(..., dtype=torch.float)` instead to
# avoid numerical issues with large base values (e.g., 10000000).
# This may cause a slight numerical difference between the HF
# implementation and ours.
# NOTE(woosuk): To exactly match the HF implementation, we need to
# NOTE(woosuk): To exactly match the HF implementation, we need to
# use CPU to compute the cache and then move it to GPU. However, we
# use CPU to compute the cache and then move it to GPU. However, we
# create the cache on GPU for faster initialization. This may cause
# create the cache on GPU for faster initialization. This may cause
...
@@ -724,16 +719,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
...
@@ -724,16 +719,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
return
query
,
key
return
query
,
key
class
GemmaRotaryEmbedding
(
RotaryEmbedding
):
def
_compute_inv_freq
(
self
,
base
:
Union
[
int
,
float
])
->
torch
.
Tensor
:
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
inv_freq
=
1.0
/
(
base
**
(
torch
.
arange
(
0
,
self
.
rotary_dim
,
2
,
dtype
=
torch
.
int64
).
float
()
/
self
.
rotary_dim
))
return
inv_freq
class
Llama3RotaryEmbedding
(
RotaryEmbedding
):
class
Llama3RotaryEmbedding
(
RotaryEmbedding
):
def
__init__
(
def
__init__
(
...
...
vllm/model_executor/models/gemma.py
View file @
df845b2b
...
@@ -33,7 +33,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...
@@ -33,7 +33,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
GemmaRotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
...
@@ -148,14 +148,12 @@ class GemmaAttention(nn.Module):
...
@@ -148,14 +148,12 @@ class GemmaAttention(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
)
)
# TODO(woosuk): Use the `get_rope` interface.
self
.
rotary_emb
=
get_rope
(
self
.
rotary_emb
=
GemmaRotaryEmbedding
(
self
.
head_dim
,
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position
_embeddings
=
max_position_embeddings
,
max_position
=
max_position_embeddings
,
base
=
self
.
rope_theta
,
base
=
self
.
rope_theta
,
is_neox_style
=
True
,
is_neox_style
=
True
,
dtype
=
torch
.
get_default_dtype
(),
)
)
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
attn
=
Attention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
head_dim
,
...
...
vllm/model_executor/models/gemma2.py
View file @
df845b2b
...
@@ -32,7 +32,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
...
@@ -32,7 +32,7 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.rotary_embedding
import
GemmaRotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
...
@@ -130,14 +130,12 @@ class Gemma2Attention(nn.Module):
...
@@ -130,14 +130,12 @@ class Gemma2Attention(nn.Module):
bias
=
config
.
attention_bias
,
bias
=
config
.
attention_bias
,
quant_config
=
quant_config
,
quant_config
=
quant_config
,
)
)
# TODO(woosuk): Use the `get_rope` interface.
self
.
rotary_emb
=
get_rope
(
self
.
rotary_emb
=
GemmaRotaryEmbedding
(
self
.
head_dim
,
self
.
head_dim
,
self
.
head_dim
,
rotary_dim
=
self
.
head_dim
,
max_position_embeddings
,
max_position
=
max_position_embeddings
,
base
=
self
.
rope_theta
,
base
=
self
.
rope_theta
,
is_neox_style
=
True
,
is_neox_style
=
True
,
dtype
=
torch
.
get_default_dtype
(),
)
)
# FIXME(woosuk): While Gemma 2 uses sliding window attention for every
# FIXME(woosuk): While Gemma 2 uses sliding window attention for every
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment