OpenDAS / text-generation-inference · Commits

Commit 90184df7 (unverified)
Authored Jun 12, 2024 by OlivierDehaene, committed by GitHub on Jun 12, 2024
fix(layers): fix SuRotaryEmbedding (#2060)
* fix(layers): fix SuRotaryEmbedding
* change arange
* remove logs
Parent: 521de6ca
Showing 2 changed files with 15 additions and 14 deletions:

server/text_generation_server/layers/rotary.py  (+14, -12)
server/text_generation_server/models/flash_phi.py  (+1, -2)
server/text_generation_server/layers/rotary.py
@@ -267,19 +267,21 @@ class SuRotaryEmbedding(PositionRotaryEmbedding):
             or self._cos_cached.dtype != dtype
         ):
             self._seq_len_cached = seqlen
-            if seqlen > self.original_max_position_embeddings:
-                inv_freq = self.long_inv_freq
-            else:
-                inv_freq = self.short_inv_freq
-            t = torch.arange(seqlen, device=device, dtype=inv_freq.dtype)
-            if self.scaling_factor is not None:
-                t /= self.scaling_factor
-            # Don't do einsum, it converts fp32 to fp16
-            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            freqs = torch.outer(t, inv_freq.to(device=t.device))
-            self._cos_cached = torch.cos(freqs).to(dtype)
-            self._sin_cached = torch.sin(freqs).to(dtype)
+            t = torch.arange(seqlen, device=device, dtype=self.short_inv_freq.dtype)
+
+            short_freqs = torch.outer(
+                t[: self.original_max_position_embeddings],
+                self.short_inv_freq.to(device=t.device),
+            )
+            long_freqs = torch.outer(
+                t[self.original_max_position_embeddings :],
+                self.long_inv_freq.to(device=t.device),
+            )
+
+            freqs = torch.cat([short_freqs, long_freqs])
+            self._cos_cached = (torch.cos(freqs) * self.scaling_factor).to(dtype)
+            self._sin_cached = (torch.sin(freqs) * self.scaling_factor).to(dtype)

 class DynamicPositionRotaryEmbedding(PositionRotaryEmbedding):
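The substance of the fix: the old code picked a single inverse-frequency table (short or long) for the entire cache based on `seqlen`, and never applied `scaling_factor` to the cached values. Phi-3-style "su" (SuScaled) rotary embeddings instead use the short factors for positions up to `original_max_position_embeddings` and the long factors beyond that, with the attention scaling factor folded into the cached cos/sin. A minimal standalone sketch of the corrected computation, written as a hypothetical free function for illustration (not the upstream class; tensor names mirror the diff):

```python
import torch


def su_rope_cos_sin(
    seqlen: int,
    short_inv_freq: torch.Tensor,
    long_inv_freq: torch.Tensor,
    original_max_position_embeddings: int,
    scaling_factor: float,
    device: str = "cpu",
    dtype: torch.dtype = torch.float32,
):
    # Positions are generated in the inv_freq dtype (usually fp32) so the
    # outer products below stay in full precision.
    t = torch.arange(seqlen, device=device, dtype=short_inv_freq.dtype)
    # Positions inside the original context window use the short factors...
    short_freqs = torch.outer(
        t[: original_max_position_embeddings],
        short_inv_freq.to(device=t.device),
    )
    # ...and positions beyond it use the long factors.
    long_freqs = torch.outer(
        t[original_max_position_embeddings:],
        long_inv_freq.to(device=t.device),
    )
    freqs = torch.cat([short_freqs, long_freqs])
    # The "su" scaling factor is applied to the cached cos/sin directly.
    cos = (torch.cos(freqs) * scaling_factor).to(dtype)
    sin = (torch.sin(freqs) * scaling_factor).to(dtype)
    return cos, sin


# Example: an original 4K window extended to 8K cached positions. The long
# factors and scaling value are stand-ins; real ones come from the model config.
dim = 64
short = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
cos, sin = su_rope_cos_sin(8192, short, short / 4.0, 4096, scaling_factor=1.19)
```

Note that when `seqlen` is at most the original window, the long slice is empty and `torch.cat` simply returns the short frequencies, so the two regimes need no explicit branch, which is exactly the conditional the diff removes.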
server/text_generation_server/models/flash_phi.py
@@ -8,7 +8,6 @@ from typing import Optional
 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_phi_modeling import (
     FlashPhiForCausalLM,
-    PhiConfig,
 )
 from text_generation_server.utils import (
     initialize_torch_distributed,
@@ -44,7 +43,7 @@ class FlashPhi(FlashCausalLM):
             trust_remote_code=trust_remote_code,
         )
-        config = PhiConfig.from_pretrained(
+        config = AutoConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
         config.quantize = quantize
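The flash_phi.py change drops the hard-coded `PhiConfig` in favor of `AutoConfig`, which resolves the config class from the `model_type` recorded in the checkpoint's `config.json`, so Phi variants whose configs don't match the custom class still load. A minimal sketch of the pattern (the model id below is a placeholder, not from the diff):

```python
from transformers import AutoConfig

# AutoConfig dispatches on config.json instead of assuming a specific config
# class, which is what the hunk above switches the server to.
config = AutoConfig.from_pretrained(
    "microsoft/phi-1_5",  # placeholder model id for illustration
    revision=None,
    trust_remote_code=False,
)
config.quantize = None  # TGI attaches its quantize setting onto the config
```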