gaoqiong / flash-attention

Commit 78b7a1dc, authored Jan 22, 2023 by Tri Dao
Parent: 33e0860c

[OPT] Load fp16 weights on CPU before moving to GPU
Showing 6 changed files with 27 additions and 12 deletions:
flash_attn/models/gpt.py                        +3 -3
flash_attn/models/opt.py                        +2 -0
flash_attn/utils/generation.py                  +11 -4
flash_attn/utils/pretrained.py                  +6 -2
tests/models/test_gpt_generation.py             +3 -3
tests/models/test_gpt_generation_parallel.py    +2 -0
flash_attn/models/gpt.py
@@ -166,9 +166,10 @@ class GPTPreTrainedModel(nn.Module):
         """
         # Instantiate model.
         model = cls(config, *args, device=device, dtype=dtype, **kwargs)
-        # If we're going to shard the model, then don't load fp32 weights to GPU.
+        # Load state_dict in cpu because we already initialized the model in GPU, and we don't
+        # want extra stuff taking up more GPU memory
         state_dict = state_dict_from_pretrained(
-            model_name, device=device if world_size == 1 else None, dtype=dtype
+            model_name, device='cpu', dtype=dtype
         )
         if model_name.startswith('gpt2'):
             state_dict = remap_state_dict_gpt2(state_dict, config)
@@ -178,7 +179,6 @@ class GPTPreTrainedModel(nn.Module):
             raise NotImplementedError(f'Model {model_name} not supported')
         if world_size > 1:
             state_dict = shard_state_dict_tp(state_dict, config, world_size, rank)
-        state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
         load_state = model.load_state_dict(state_dict, strict=strict)
         logger.info(load_state)
         return model
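The net effect in `from_pretrained`: the model is instantiated directly on the GPU in the target dtype, the checkpoint is now always materialized on CPU, and `load_state_dict` copies the CPU tensors into the already-allocated GPU parameters, so the GPU never holds a second full copy of the weights. A minimal sketch of this pattern, with a hypothetical `MyModel` standing in for the real `GPTLMHeadModel` (not the repo's exact code):

    import torch
    import torch.nn as nn

    # Hypothetical stand-in for GPTLMHeadModel.
    class MyModel(nn.Module):
        def __init__(self, device=None, dtype=None):
            super().__init__()
            self.proj = nn.Linear(1024, 1024, device=device, dtype=dtype)

    # 1. Initialize the model directly on GPU in fp16.
    model = MyModel(device='cuda', dtype=torch.float16)
    # 2. Load the checkpoint on CPU so it never occupies GPU memory.
    state_dict = torch.load('checkpoint.pt', map_location='cpu')
    # 3. Convert dtype on CPU; load_state_dict copies into the GPU parameters in place.
    state_dict = {k: v.to(dtype=torch.float16) for k, v in state_dict.items()}
    model.load_state_dict(state_dict, strict=True)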
flash_attn/models/opt.py
@@ -43,6 +43,8 @@ def remap_state_dict_opt(state_dict, config):
     # LayerNorm
     def key_mapping_ln(key):
         key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
+        # The OPT-175B checkpoint calls this 'decoder.layer_norm' instead of 'decoder.final_layer_norm'
+        key = re.sub(r'^transformer.layer_norm.', r'transformer.ln_f.', key)
         key = re.sub(r'^transformer.layers.(\d+).self_attn_layer_norm.',
                      r'transformer.layers.\1.norm1.', key)
         key = re.sub(r'^transformer.layers.(\d+).final_layer_norm.',
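For context, `key_mapping_ln` is one of several regex-based renamers that `remap_state_dict_opt` applies to every checkpoint key to translate Hugging Face OPT names into this repo's layout. A self-contained sketch of that remapping pattern (toy keys and values, not the full mapping):

    import re
    from collections import OrderedDict

    def key_mapping_ln(key):
        # Both spellings used by OPT checkpoints map to this repo's 'ln_f'.
        key = re.sub(r'^transformer.final_layer_norm.', r'transformer.ln_f.', key)
        key = re.sub(r'^transformer.layer_norm.', r'transformer.ln_f.', key)
        return key

    state_dict = {'transformer.layer_norm.weight': 'w',
                  'transformer.final_layer_norm.bias': 'b'}
    # Apply the renamer to every key, keeping the values.
    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
    print(list(state_dict))  # ['transformer.ln_f.weight', 'transformer.ln_f.bias']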
flash_attn/utils/generation.py
@@ -196,7 +196,7 @@ class DecodingCGCache:

 @torch.inference_mode()
 def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_parallel=1,
-                       dtype=None):
+                       dtype=None, n_warmups=2):
     if cache is None:
         cache = DecodingCGCache()
     param_example = next(iter(model.parameters()))
@@ -228,7 +228,8 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
         if s_type not in cache.callables:
             seqlen = min(max(seqlen_og, seqlen_type_to_seqlen(s_type)), max_seqlen)
             cache.callables[s_type] = capture_graph(
-                model, cache.inference_params, batch_size, seqlen_og, seqlen, mempool=cache.mempool
+                model, cache.inference_params, batch_size, seqlen_og, seqlen, mempool=cache.mempool,
+                n_warmups=n_warmups
             )

     def dispatch(input_ids, position_ids, seqlen):
@@ -239,7 +240,8 @@ def update_graph_cache(model, cache, batch_size, seqlen_og, max_seqlen, tensor_p
     return cache


-def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, mempool=None):
+def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, mempool=None,
+                  n_warmups=2):
     assert max_seqlen >= seqlen_og
     device = next(iter(model.parameters())).device
     input_ids = torch.full((batch_size, 1), 0, dtype=torch.long, device=device)
@@ -250,10 +252,15 @@ def capture_graph(model, inference_params, batch_size, seqlen_og, max_seqlen, me
     s = torch.cuda.Stream()
     s.wait_stream(torch.cuda.current_stream())
     with torch.cuda.stream(s):
-        for _ in range(2):
+        for _ in range(n_warmups):
             logits = model(input_ids, position_ids=position_ids,
                            inference_params=inference_params).logits[:, -1]
         s.synchronize()
+        # This might be needed for correctness if we run with NCCL_GRAPH_MIXING_SUPPORT=0,
+        # which requires that graph launch and non-captured launch to not overlap (I think,
+        # that's how I interpret the documentation). I'm not sure if this is required.
+        if torch.distributed.is_initialized():
+            torch.distributed.barrier()
     torch.cuda.current_stream().wait_stream(s)

     # Captures the graph
     # To allow capture, automatically sets a side stream as the current stream in the context
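The new `n_warmups` parameter controls how many eager iterations run on the side stream before CUDA graph capture; warmup lets lazy initialization (cuBLAS/cuDNN workspaces, autotuning) happen outside the graph so it isn't baked into the recording. A minimal sketch of the warmup-capture-replay pattern for a generic callable `fn` over a fixed-shape CUDA buffer; this is an illustration under those assumptions, not the repo's `capture_graph`, which additionally manages `inference_params` and a shared memory pool:

    import torch

    def warmup_and_capture(fn, static_input, n_warmups=2):
        # Warmup on a side stream so one-time initialization isn't captured.
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(n_warmups):
                static_output = fn(static_input)
            s.synchronize()
        torch.cuda.current_stream().wait_stream(s)

        # Capture: kernels launched inside this context are recorded, not run.
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            static_output = fn(static_input)

        def run(new_input):
            # Replay with new data by copying into the captured input buffer;
            # shapes and addresses must match the capture.
            static_input.copy_(new_input)
            graph.replay()
            return static_output

        return run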
flash_attn/utils/pretrained.py
@@ -7,6 +7,8 @@ from transformers.utils.hub import cached_file, get_checkpoint_shard_files


 def state_dict_from_pretrained(model_name, device=None, dtype=None):
+    # If not fp32, then we don't want to load directly to the GPU
+    mapped_device = 'cpu' if dtype not in [torch.float32, None] else device
     is_sharded = False
     resolved_archive_file = cached_file(model_name, WEIGHTS_NAME,
                                         _raise_exceptions_for_missing_entries=False)
@@ -25,9 +27,11 @@ def state_dict_from_pretrained(model_name, device=None, dtype=None):
         )
         state_dict = {}
         for sharded_file in resolved_archive_file:
-            state_dict.update(torch.load(sharded_file, map_location=device))
+            state_dict.update(torch.load(sharded_file, map_location=mapped_device))
     else:
         state_dict = torch.load(cached_file(model_name, WEIGHTS_NAME), map_location=device)
+    # Convert dtype before moving to GPU to save memory
     if dtype is not None:
-        state_dict = {k: v.to(dtype) for k, v in state_dict.items()}
+        state_dict = {k: v.to(dtype=dtype) for k, v in state_dict.items()}
+    state_dict = {k: v.to(device=device) for k, v in state_dict.items()}
     return state_dict
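Note the ordering at the end of `state_dict_from_pretrained`: the dtype conversion happens while the tensors are still on CPU, and only then are they moved to `device`. Converting first halves the bytes transferred for fp16 and keeps the GPU from ever briefly holding an fp32 copy. A small sketch of why the order matters (hypothetical tensor size):

    import torch

    # Hypothetical fp32 checkpoint tensor on CPU (64 MB in fp32).
    cpu_fp32 = torch.randn(4096, 4096)

    # Good: convert on CPU (to 32 MB fp16), then transfer 32 MB to the GPU.
    gpu_fp16 = cpu_fp32.to(dtype=torch.float16).to(device='cuda')

    # Worse: transfer 64 MB, and the GPU briefly holds fp32 and fp16 copies.
    gpu_fp16_wasteful = cpu_fp32.to(device='cuda').to(dtype=torch.float16)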
tests/models/test_gpt_generation.py
@@ -114,7 +114,7 @@ def test_greedy_decode_gpt2(model_name, rotary, optimized, fused_ft_kernel):

 @pytest.mark.parametrize('model_name', ["facebook/opt-125m", "facebook/opt-350m",
                                         "facebook/opt-1.3b", "facebook/opt-2.7b",
                                         "facebook/opt-6.7b"])
-# @pytest.mark.parametrize('model_name', ["facebook/opt-6.7b"])
+# @pytest.mark.parametrize('model_name', ["facebook/opt-125m"])
 def test_greedy_decode_opt(model_name):
     """Check that our implementation of OPT generation matches the HF implementation:
     the scores in fp16 should be around the same as the HF scores in fp16, when compared to
@@ -145,7 +145,7 @@ def test_greedy_decode_opt(model_name):
     input_ids = tokenizer("Hello, my dog is cute and",
                           return_tensors="pt").input_ids.to(device=device)
-    max_length = 30
+    max_length = 60
     # input_ids = torch.randint(0, 100, (2, 10), dtype=torch.long, device='cuda')
     # max_length = input_ids.shape[1] + 40
@@ -192,7 +192,7 @@ def test_greedy_decode_opt(model_name):
     print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
     if verbose:
         print(out_cg.sequences)
-        print(tokenizer.batch_decode(out.sequences.tolist()))
+        print(tokenizer.batch_decode(out_cg.sequences.tolist()))

     del model
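The last hunk fixes a copy-paste bug: the CUDA-graph run's output (`out_cg`) was printed from the wrong variable. For reference, the Hugging Face side of the comparison this test performs looks roughly like the following (a hedged sketch, not the test's exact code):

    import torch
    from transformers import AutoTokenizer, OPTForCausalLM

    device = 'cuda'
    tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
    model_hf = OPTForCausalLM.from_pretrained("facebook/opt-125m",
                                              torch_dtype=torch.float16).to(device)
    input_ids = tokenizer("Hello, my dog is cute and",
                          return_tensors="pt").input_ids.to(device)
    # Greedy decoding; keep per-step scores so logits can be compared too.
    out_hf = model_hf.generate(input_ids, max_length=60,
                               return_dict_in_generate=True, output_scores=True)
    print(tokenizer.batch_decode(out_hf.sequences.tolist()))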
tests/models/test_gpt_generation_parallel.py
@@ -129,3 +129,5 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
     assert torch.all(out.sequences == out_hf.sequences)
+    assert (torch.stack(out.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item() < 3 * (
+        torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)).abs().max().item()

     parallel_state.destroy_model_parallel()
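The added assertion is a relative-tolerance check: treating the fp32 run (`out_ref`) as ground truth, the tensor-parallel fp16 scores must be at most 3x farther from it than HF's own fp16 scores (`out_hf`) are. Factored out for clarity (a sketch; names follow the test):

    import torch

    def scores_close_enough(scores, scores_hf, scores_ref, factor=3.0):
        # Max absolute error of our fp16 implementation vs the fp32 reference.
        err = (torch.stack(scores, 1) - torch.stack(scores_ref, 1)).abs().max().item()
        # Max absolute error of HF's fp16 implementation vs the same reference.
        err_hf = (torch.stack(scores_hf, 1) - torch.stack(scores_ref, 1)).abs().max().item()
        # Accept if we are within `factor` of HF's own fp16 error.
        return err < factor * err_hf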