gaoqiong / flash-attention · Commits

Commit 187c2a06 (unverified)
Fix E1136 (#563)
Authored Sep 22, 2023 by Yuchao Dai; committed by GitHub on Sep 21, 2023
Parent: 229080b9
Changes: 2 changed files, with 10 additions and 9 deletions (+10 -9)

flash_attn/models/gpt.py    +2 -1
flash_attn/models/llama.py  +8 -8
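For context on the fix: pylint's E1136 (unsubscriptable-object) fires because the built-in generics list[...] and dict[...] are only subscriptable at runtime on Python 3.9+ (PEP 585); on Python 3.8 and earlier, evaluating such an annotation raises TypeError: 'type' object is not subscriptable. Below is a minimal sketch of the failure mode and of the fix this commit applies; the function body is a toy stand-in, not code from this repo.

from typing import Dict, List

import torch

# On Python <= 3.8, defining this raises TypeError at import time, and
# pylint reports E1136 (unsubscriptable-object) on the annotation:
#
#     def combine(state_dicts: list[dict[str, torch.Tensor]]): ...

# The fix used in this commit: the typing aliases subscript fine on
# every Python version.
def combine(state_dicts: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    merged: Dict[str, torch.Tensor] = {}  # toy merge: later shards win on key clashes
    for sd in state_dicts:
        merged.update(sd)
    return merged

An alternative would have been "from __future__ import annotations" (PEP 563), which defers annotation evaluation on Python 3.7+; this commit takes the explicit typing-alias route instead.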
flash_attn/models/gpt.py (view file @ 187c2a06)

@@ -6,6 +6,7 @@ import re
 from collections import OrderedDict, namedtuple
 from collections.abc import Sequence
 from functools import partial
+from typing import Dict, List

 import torch
 import torch.nn as nn
@@ -810,7 +811,7 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
     return state_dict


-def combine_state_dicts_tp(state_dicts: list[dict[str, torch.Tensor]], config: GPT2Config):
+def combine_state_dicts_tp(state_dicts: List[Dict[str, torch.Tensor]], config: GPT2Config):
     """Convert the list of sharded state_dict of a GPT model with tensor parallel to
     the state_dict of a standard GPT model.
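To make the new combine_state_dicts_tp signature concrete, here is a toy sketch with the same shape: a list of per-rank shard dicts in, one merged dict out. The row-wise torch.cat is purely illustrative; the real function dispatches per parameter and also takes a GPT2Config.

from typing import Dict, List

import torch

def toy_combine_tp(state_dicts: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    # Concatenate each rank's shard of every weight along dim 0.
    # (Illustrative only: real tensor parallelism shards different
    # parameters along different dimensions, or replicates them.)
    return {
        name: torch.cat([sd[name] for sd in state_dicts], dim=0)
        for name in state_dicts[0]
    }

# Two fake rank shards of one 4x2 weight, split row-wise.
full = torch.arange(8.0).reshape(4, 2)
shards = [{"w": full[:2]}, {"w": full[2:]}]
assert torch.equal(toy_combine_tp(shards)["w"], full)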
flash_attn/models/llama.py (view file @ 187c2a06)

@@ -6,7 +6,7 @@ import os
 import re
 from collections import OrderedDict
 from pathlib import Path
-from typing import Union
+from typing import Dict, List, Union

 import torch
 import torch.nn.functional as F
@@ -17,8 +17,8 @@ from einops import rearrange

 def remap_state_dict_meta_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config
-) -> dict[str, torch.Tensor]:
+    state_dict: Dict[str, torch.Tensor], config: GPT2Config
+) -> Dict[str, torch.Tensor]:
     """Convert the state_dict in Meta format to standard GPT format.

     This function modifies state_dict in place.
@@ -113,8 +113,8 @@ def remap_state_dict_meta_llama(
 def remap_state_dict_hf_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config
-) -> dict[str, torch.Tensor]:
+    state_dict: Dict[str, torch.Tensor], config: GPT2Config
+) -> Dict[str, torch.Tensor]:
     """Convert the state_dict in Hugging Face format to standard GPT format.

     This function modifies state_dict in place.
@@ -217,8 +217,8 @@ def remap_state_dict_hf_llama(
 def inv_remap_state_dict_hf_llama(
-    state_dict: dict[str, torch.Tensor], config: GPT2Config
-) -> dict[str, torch.Tensor]:
+    state_dict: Dict[str, torch.Tensor], config: GPT2Config
+) -> Dict[str, torch.Tensor]:
     """Convert the state_dict in standard GPT format to Hugging Face format.

     This function is meant to be the inverse of remap_state_dict_hf_llama, up to a
@@ -382,7 +382,7 @@ def config_from_checkpoint(
 def state_dicts_from_checkpoint(
     checkpoint_path: Union[str, os.PathLike], model_name: str
-) -> list[dict]:
+) -> List[dict]:
     # Need to sort, otherwise we mess up the ordering and the weights are wrong
     return [
         torch.load(path, map_location="cpu")
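The last hunk also shows why the List[dict] annotation matters in practice: state_dicts_from_checkpoint loads every tensor-parallel shard under a checkpoint directory in a stable order. A hedged sketch of that pattern follows; the consolidated.*.pth glob is an assumption about the Meta checkpoint layout, not something shown in this diff.

import os
from pathlib import Path
from typing import List, Union

import torch

def toy_state_dicts_from_checkpoint(
    checkpoint_path: Union[str, os.PathLike], model_name: str
) -> List[dict]:
    # sorted() mirrors the comment in the real function: unsorted directory
    # order would shuffle the tensor-parallel ranks and corrupt the weights.
    paths = sorted(Path(checkpoint_path, model_name).glob("consolidated.*.pth"))  # assumed layout
    return [torch.load(p, map_location="cpu") for p in paths]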