gaoqiong / flash-attention · Commits

Commit 0a146185, authored Dec 19, 2023 by Tri Dao

[Gen] Remove minor dead code

parent e4f726fc
Showing 2 changed files with 1 addition and 5 deletions (+1, -5)

flash_attn/models/gpt_neox.py      +1  -1
flash_attn/utils/generation.py     +0  -4
flash_attn/models/gpt_neox.py (view file @ 0a146185)
@@ -27,7 +27,7 @@ def remap_state_dict_hf_gpt_neox(state_dict, config):
     state_dict["transformer.embeddings.word_embeddings.weight"] = F.pad(
         word_embeddings, (0, 0, 0, vocab_size - word_embeddings.shape[0])
     )
-    if getattr(config, "tie_word_embeddings"):
+    if getattr(config, "tie_word_embeddings", False):
         state_dict["lm_head.weight"] = state_dict["transformer.embeddings.word_embeddings.weight"]
     else:
         output_embeddings = state_dict.pop("embed_out.weight")
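The one-line change above swaps the two-argument getattr for the three-argument form. Python's getattr raises AttributeError when the attribute is missing and no default is given; with False as the default, a config that never sets tie_word_embeddings falls through to the untied branch instead of crashing. A minimal sketch of the difference, using a SimpleNamespace as a stand-in for the real config object:

from types import SimpleNamespace

config = SimpleNamespace()  # hypothetical config with no tie_word_embeddings attribute

# Old form: getattr without a default raises AttributeError if the attribute is absent.
try:
    tied = getattr(config, "tie_word_embeddings")
except AttributeError:
    print("two-argument getattr raises AttributeError")

# New form: the False default makes the missing attribute behave as "not tied",
# so the remap function takes the else-branch that pops "embed_out.weight".
tied = getattr(config, "tie_word_embeddings", False)
print(tied)  # False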
flash_attn/utils/generation.py (view file @ 0a146185)
@@ -591,10 +591,6 @@ def allocate_inference_cache(
     dtype=torch.float16,
 ):
     assert dtype in [torch.float16, torch.bfloat16, torch.float32]
-    packsize = 4 if dtype == torch.float32 else 8
-    assert headdim % packsize == 0
-    k_cache_shape = (max_batch_size, nheads, headdim // packsize, max_seqlen, packsize)
-    v_cache_shape = (max_batch_size, nheads, max_seqlen, headdim)
     kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
     if isinstance(layers, int):
         layers = range(layers)
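The four deleted lines computed a packed K-cache shape (headdim split by packsize) and a separate V-cache shape, but nothing below them ever read those variables; only kv_cache_shape is used for allocation, which is why they are dead code. As a hedged sketch, the function after this commit reduces to roughly the following; the full parameter list and the return statement are outside the hunk and are assumed here for illustration:

import torch

def allocate_inference_cache(
    max_batch_size,
    max_seqlen,
    nheads,
    headdim,
    layers,          # int or a sequence of layer indices (assumed)
    device=None,     # assumed parameter, not shown in the hunk
    dtype=torch.float16,
):
    assert dtype in [torch.float16, torch.bfloat16, torch.float32]
    # One contiguous (batch, seqlen, 2, nheads, headdim) tensor holds both K and V;
    # the per-component k/v shapes removed by this commit were never used.
    kv_cache_shape = (max_batch_size, max_seqlen, 2, nheads, headdim)
    if isinstance(layers, int):
        layers = range(layers)
    # Assumed return: one empty cache tensor per layer.
    return {i: torch.empty(kv_cache_shape, device=device, dtype=dtype) for i in layers}

For example, under these assumptions allocate_inference_cache(2, 1024, 16, 64, 4) would return a dict of four tensors, each of shape (2, 1024, 2, 16, 64).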