gaoqiong / flash-attention · Commits

Commit 78b7a1dc, authored Jan 22, 2023 by Tri Dao
[OPT] Load fp16 weights on CPU before moving to GPU
Parent: 33e0860c
Showing 6 changed files with 27 additions and 12 deletions (+27 -12).
flash_attn/models/gpt.py                         +3 -3
flash_attn/models/opt.py                         +2 -0
flash_attn/utils/generation.py                   +11 -4
flash_attn/utils/pretrained.py                   +6 -2
tests/models/test_gpt_generation.py              +3 -3
tests/models/test_gpt_generation_parallel.py     +2 -0
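The change in a nutshell, as a minimal stand-alone sketch (toy module, made-up checkpoint file name, CUDA device assumed; this is not code from the repo): the model is instantiated on the GPU in the target dtype, the checkpoint is loaded onto the CPU, and load_state_dict then copies the CPU tensors into the existing CUDA parameters, so peak GPU memory stays near a single copy of the weights.

import torch
import torch.nn as nn

# Toy stand-in for a large LM checkpoint; sizes and file name are illustrative only.
torch.save(nn.Linear(1024, 1024).state_dict(), 'checkpoint.pt')

# The model is already instantiated on the GPU in the target dtype.
model = nn.Linear(1024, 1024, device='cuda', dtype=torch.float16)

# Load the checkpoint onto host RAM instead of the GPU,
state_dict = torch.load('checkpoint.pt', map_location='cpu')
# convert dtype while the tensors are still on the CPU,
state_dict = {k: v.to(dtype=torch.float16) for k, v in state_dict.items()}
# and let load_state_dict copy the CPU tensors into the existing CUDA parameters.
model.load_state_dict(state_dict)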
flash_attn/models/gpt.py

@@ -166,9 +166,10 @@ class GPTPreTrainedModel(nn.Module):
         """
         # Instantiate model.
         model = cls(config, *args, device=device, dtype=dtype, **kwargs)
-        # If we're going to shard the model, then don't load fp32 weights to GPU.
+        # Load state_dict in cpu because we already initialized the model in GPU, and we don't
+        # want extra stuff taking up more GPU memory
         state_dict = state_dict_from_pretrained(
-            model_name, device=device if world_size == 1 else None, dtype=dtype
+            model_name, device='cpu', dtype=dtype
         )
         if model_name.startswith('gpt2'):
             state_dict = remap_state_dict_gpt2(state_dict, config)

@@ -178,7 +179,6 @@ class GPTPreTrainedModel(nn.Module):
             raise NotImplementedError(f'Model {model_name} not supported')
         if world_size > 1:
             state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
-        state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
         load_return = model.load_state_dict(state_dict, strict=strict)
         logger.info(load_return)
         return model
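A note on the deleted state_dict = {k: v.to(device=device) ...} line: nn.Module.load_state_dict copies values into the parameters wherever those parameters already live, so a CPU state dict can be loaded straight into a CUDA-resident model. A small sketch with a toy layer (sizes are made up, a CUDA device is assumed):

import torch
import torch.nn as nn

# Toy module resident on the GPU; purely illustrative.
layer = nn.Linear(8, 8, device='cuda', dtype=torch.float16)

# A state dict whose tensors live on the CPU, e.g. fresh from torch.load(..., map_location='cpu').
cpu_state = {k: v.to(dtype=torch.float16) for k, v in nn.Linear(8, 8).state_dict().items()}

# load_state_dict copies the values into the existing CUDA parameters;
# no explicit .to(device=...) on the state dict is needed beforehand.
layer.load_state_dict(cpu_state)
print(layer.weight.device)  # cuda:0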
flash_attn/models/opt.py

@@ -43,6 +43,8 @@ def remap_state_dict_opt(state_dict, config):
     # LayerNorm
     def key_mapping_ln(key):
         key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
+        # The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm'
+        key = re.sub(r'^transformer.layer_norm.', r'transformer.ln_f.', key)
         key = re.sub(r'^transformer.layers.(\d+).self_attn_layer_norm.',
                      r'transformer.layers.\1.norm1.', key)
         key = re.sub(r'^transformer.layers.(\d+).final_layer_norm.',
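The added substitution covers checkpoints (per the comment above, OPT-175B) that name the final LayerNorm 'layer_norm' rather than 'final_layer_norm'. A short stand-alone demonstration of the remapping on made-up keys:

import re

def key_mapping_ln(key):
    # Both spellings of the final LayerNorm end up as transformer.ln_f.*
    key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
    key = re.sub(r'^transformer.layer_norm.', r'transformer.ln_f.', key)
    return key

# Made-up example keys, not taken from a real checkpoint.
for k in ['transformer.final_layer_norm.weight', 'transformer.layer_norm.bias']:
    print(k, '->', key_mapping_ln(k))
# transformer.final_layer_norm.weight -> transformer.ln_f.weight
# transformer.layer_norm.bias -> transformer.ln_f.bias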
flash_attn/utils/generation.py

@@ -196,7 +196,7 @@ class DecodingCGCache:
 @torch.inference_mode()
 def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1,
-                       dtype=None):
+                       dtype=None, n_warmups=2):
     if cache is None:
         cache = DecodingCGCache()
     param_example = next(iter(model.parameters()))

@@ -228,7 +228,8 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
         if s_type not in cache.callables:
             seqlen = min(max(seqlen_og, seqlen_type_to_seqlen(s_type)), max_seqlen)
             cache.callables[s_type] = capture_graph(
-                model, cache.inference_params, batch_size, seqlen_og, seqlen, mempool=cache.mempool
+                model, cache.inference_params, batch_size, seqlen_og, seqlen, mempool=cache.mempool,
+                n_warmups=n_warmups
             )

     def dispatch(input_ids, position_ids, seqlen):

@@ -239,7 +240,8 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
     return cache


-def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, mempool=None):
+def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, mempool=None,
+                  n_warmups=2):
     assert max_seqlen >= seqlen_og
     device = next(iter(model.parameters())).device
     input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)

@@ -250,10 +252,15 @@ def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, me
     s = torch.cuda.Stream()
     s.wait_stream(torch.cuda.current_stream())
     with torch.cuda.stream(s):
-        for _ in range(2):
+        for _ in range(n_warmups):
             logits = model(input_ids, position_ids=position_ids,
                            inference_params=inference_params).logits[:, -1]
         s.synchronize()
+        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
+        # which requires that graph launch and non-captured launch to not overlap (I think,
+        # that's how I interpret the documentation). I'm not sure if this is required.
+        if torch.distributed.is_initialized():
+            torch.distributed.barrier()
     torch.cuda.current_stream().wait_stream(s)
     # Captures the graph
     # To allow capture, automatically sets a side stream as the current stream in the context
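The new n_warmups parameter controls how many eager iterations run on a side stream before the CUDA graph is recorded, so that lazy initialization (cuBLAS handles, autotuning, allocator pools) happens outside of capture. A minimal stand-alone sketch of that warmup-then-capture pattern, with a made-up function standing in for a decoding step (CUDA device assumed; this is not the repo's implementation):

import torch

def capture(fn, static_input, n_warmups=2):
    # Warm up on a side stream so one-time setup work does not end up inside the graph.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            static_output = fn(static_input)
        s.synchronize()
    torch.cuda.current_stream().wait_stream(s)

    # Record the kernels into a CUDA graph; replaying it reruns them on the same
    # static input/output buffers with almost no launch overhead.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = fn(static_input)

    def run(new_input):
        static_input.copy_(new_input)  # refresh the captured input buffer
        graph.replay()
        return static_output

    return run

# Made-up usage: a single matmul standing in for one decoding step.
x = torch.randn(16, 1024, device='cuda', dtype=torch.float16)
w = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
step = capture(lambda inp: inp @ w, x.clone(), n_warmups=2)
out = step(x)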
flash_attn/utils/pretrained.py

@@ -7,6 +7,8 @@ from transformers.utils.hub import cached_file, get_checkpoint_shard_files
 def state_dict_from_pretrained(model_name, device=None, dtype=None):
+    # If not fp32, then we don't want to load directly to the GPU
+    mapped_device = 'cpu' if dtype not in [torch.float32, None] else device
     is_sharded = False
     resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
                                         _raise_exceptions_for_missing_entries=False)

@@ -25,9 +27,11 @@ def state_dict_from_pretrained(model_name, device=None, dtype=None):
         )
         state_dict = {}
         for sharded_file in resolved_archive_file:
-            state_dict.update(torch.load(sharded_file, map_location=device))
+            state_dict.update(torch.load(sharded_file, map_location=mapped_device))
     else:
         state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
+    # Convert dtype before moving to GPU to save memory
     if dtype is not None:
-        state_dict = {k: v.to(dtype) for k, v in state_dict.items()}
+        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
+    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
     return state_dict
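The ordering matters here: converting dtype while the tensors are still on the CPU and only then moving them means the GPU never holds the higher-precision copy. A tiny sketch of the same two-step on a toy state dict (sizes and the CUDA device are assumptions):

import torch

# Toy fp32 "checkpoint" sitting in host memory; sizes are illustrative.
state_dict = {'weight': torch.randn(4096, 4096), 'bias': torch.randn(4096)}

dtype, device = torch.float16, 'cuda'

# Cast on the CPU first, then move: the fp32 tensors never touch GPU memory.
state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
state_dict = {k: v.to(device=device) for k, v in state_dict.items()}

# Moving first and casting second would briefly hold both the fp32 and the
# fp16 copy of each tensor on the GPU.
print({k: (v.dtype, v.device) for k, v in state_dict.items()})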
tests/models/test_gpt_generation.py

@@ -114,7 +114,7 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):
 @pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b", "facebook/opt-2.7b", "facebook/opt-6.7b"])
-# @pytest.mark.parametrize('model_name', ["facebook/opt-6.7b"])
+# @pytest.mark.parametrize('model_name', ["facebook/opt-125m"])
 def test_greedy_decode_opt(model_name):
     """Check that our implementation of OPT generation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to

@@ -145,7 +145,7 @@ def test_greedy_decode_opt(model_name):
     input_ids = tokenizer("Hello, my dog is cute and",
                           return_tensors="pt").input_ids.to(device=device)
-    max_length = 30
+    max_length = 60
     # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
     # max_length = input_ids.shape[1] + 40

@@ -192,7 +192,7 @@ def test_greedy_decode_opt(model_name):
     print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
     if verbose:
         print(out_cg.sequences)
-        print(tokenizer.batch_decode(out.sequences.tolist()))
+        print(tokenizer.batch_decode(out_cg.sequences.tolist()))

     del model
tests/models/test_gpt_generation_parallel.py

@@ -129,3 +129,5 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
     assert torch.all(out.sequences == out_hf.sequences)
     assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()
+
+    parallel_state.destroy_model_parallel()
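The second assertion above bounds this implementation's deviation from the fp32 reference by a multiple of how far the HF fp16 run already deviates from that same reference, rather than by a fixed absolute tolerance. A self-contained illustration of the pattern with made-up score tensors:

import torch

# Stand-ins for stacked per-step scores (batch, steps, vocab); values are made up.
ref_scores = torch.randn(2, 10, 1000)                             # fp32 reference
hf_scores = ref_scores + 1e-3 * torch.randn_like(ref_scores)      # baseline fp16 run
our_scores = ref_scores + 1e-3 * torch.randn_like(ref_scores)     # implementation under test

# Our worst-case error vs. the reference must stay within a small multiple of the
# error the baseline fp16 implementation already incurs.
assert (our_scores - ref_scores).abs().max().item() < 3 * (
    hf_scores - ref_scores).abs().max().item()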