Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
ec9f74ab
"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "f9e3c47d4a536205cc0c105f1ef3e2d9469460ce"
Commit
ec9f74ab
authored
Jul 22, 2023
by
Tri Dao
Browse files
[Rotary] Don't store inv_freq in state_dict
parent
a157cc8c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
2 additions
and
3 deletions
+2
-3
flash_attn/layers/rotary.py
flash_attn/layers/rotary.py
+2
-2
flash_attn/models/gpt.py
flash_attn/models/gpt.py
+0
-1
No files found.
flash_attn/layers/rotary.py
View file @
ec9f74ab
...
@@ -191,12 +191,12 @@ class RotaryEmbedding(torch.nn.Module):
...
@@ -191,12 +191,12 @@ class RotaryEmbedding(torch.nn.Module):
self
.
pos_idx_in_fp32
=
pos_idx_in_fp32
self
.
pos_idx_in_fp32
=
pos_idx_in_fp32
# Generate and save the inverse frequency buffer (non trainable)
# Generate and save the inverse frequency buffer (non trainable)
inv_freq
=
self
.
_compute_inv_freq
(
device
)
inv_freq
=
self
.
_compute_inv_freq
(
device
)
self
.
register_buffer
(
"inv_freq"
,
inv_freq
)
self
.
register_buffer
(
"inv_freq"
,
inv_freq
,
persistent
=
False
)
self
.
interleaved
=
interleaved
self
.
interleaved
=
interleaved
self
.
scale_base
=
scale_base
self
.
scale_base
=
scale_base
scale
=
((
torch
.
arange
(
0
,
dim
,
2
,
device
=
device
,
dtype
=
torch
.
float32
)
+
0.4
*
dim
)
scale
=
((
torch
.
arange
(
0
,
dim
,
2
,
device
=
device
,
dtype
=
torch
.
float32
)
+
0.4
*
dim
)
/
(
1.4
*
dim
)
if
scale_base
is
not
None
else
None
)
/
(
1.4
*
dim
)
if
scale_base
is
not
None
else
None
)
self
.
register_buffer
(
"scale"
,
scale
)
self
.
register_buffer
(
"scale"
,
scale
,
persistent
=
False
)
self
.
_seq_len_cached
=
0
self
.
_seq_len_cached
=
0
self
.
_cos_cached
=
None
self
.
_cos_cached
=
None
...
...
flash_attn/models/gpt.py
View file @
ec9f74ab
...
@@ -237,7 +237,6 @@ class GPTPreTrainedModel(nn.Module):
...
@@ -237,7 +237,6 @@ class GPTPreTrainedModel(nn.Module):
state_dict
=
remap_state_dict_hf_opt
(
state_dict
,
config
)
state_dict
=
remap_state_dict_hf_opt
(
state_dict
,
config
)
elif
model_name
.
startswith
(
'EleutherAI/gpt-j-'
):
elif
model_name
.
startswith
(
'EleutherAI/gpt-j-'
):
state_dict
=
remap_state_dict_hf_gptj
(
state_dict
,
config
)
state_dict
=
remap_state_dict_hf_gptj
(
state_dict
,
config
)
strict
=
False
# We have rotary_emb.inv_freq buffers not in the GPT-J checkpoint
elif
model_name
.
startswith
(
'EleutherAI/gpt-neox-'
):
elif
model_name
.
startswith
(
'EleutherAI/gpt-neox-'
):
state_dict
=
remap_state_dict_hf_gpt_neox
(
state_dict
,
config
)
state_dict
=
remap_state_dict_hf_gpt_neox
(
state_dict
,
config
)
else
:
else
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment