gaoqiong / flash-attention

Commit 184b992d, authored Jul 28, 2023 by Tri Dao
Parent: 840f7925

[GPT] Implement parallel LLaMa
Showing 3 changed files, with 176 additions and 43 deletions:

    flash_attn/models/gpt.py      +15   -2
    tests/models/test_falcon.py    +0   -2
    tests/models/test_llama.py   +161  -39
flash_attn/models/gpt.py

@@ -527,6 +527,15 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
             dim = x.shape[-1] // world_size
             state_dict[key] = x[..., rank * dim:(rank + 1) * dim]
 
+    def shard_gatedmlp_fc1_dim(state_dict, key):
+        if key in state_dict:
+            x = state_dict[key]
+            dim = x.shape[0] // world_size // 2
+            state_dict[key] = rearrange(
+                rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim:(rank + 1) * dim],
+                "two o ... -> (two o) ..."
+            )
+
     def shard_qkv_headdim(state_dict, key):
         if key in state_dict:
             n_head = config.n_head
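For a gated MLP (GLU / SwiGLU / GeGLU), fc1 holds the gate and "up" projections stacked along its output dimension, so a plain first-dim split would hand each rank an unbalanced mix of the two halves; the helper added above instead takes each rank's slice from both halves and re-packs them. A minimal sketch of that slicing on a toy tensor, using only torch and einops with made-up sizes (not the library's API):

import torch
from einops import rearrange

world_size, rank = 2, 1          # pretend we are TP rank 1 of 2
d_model, d_inner = 4, 8          # toy sizes
# Gated-MLP fc1 weight: the two projections are stacked along dim 0 -> (2 * d_inner, d_model)
fc1 = torch.arange(2 * d_inner * d_model, dtype=torch.float32).reshape(2 * d_inner, d_model)

dim = fc1.shape[0] // world_size // 2   # per-rank slice of each half
shard = rearrange(
    rearrange(fc1, "(two o) ... -> two o ...", two=2)[:, rank * dim:(rank + 1) * dim],
    "two o ... -> (two o) ..."
)
print(shard.shape)  # torch.Size([8, 4]): rank 1's slice of each of the two halves, re-stacked

Splitting this way keeps the gate and up rows paired on every rank, which is what a column-parallel gated MLP forward needs.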
@@ -559,8 +568,12 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
         shard_last_dim(state_dict, f'transformer.layers.{i}.mixer.out_proj.weight')
         if rank != 0:
             state_dict.pop(f'transformer.layers.{i}.mixer.out_proj.bias', None)
-        shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
-        shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias')
+        if config.activation_function in ["glu", "swiglu", "geglu"]:
+            shard_gatedmlp_fc1_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
+            shard_gatedmlp_fc1_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias')
+        else:
+            shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.weight')
+            shard_first_dim(state_dict, f'transformer.layers.{i}.mlp.fc1.bias')
         shard_last_dim(state_dict, f'transformer.layers.{i}.mlp.fc2.weight')
         if rank != 0:
             state_dict.pop(f'transformer.layers.{i}.mlp.fc2.bias', None)
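The other helpers referenced here (their bodies are outside this hunk) follow the usual Megatron-style layout: column-parallel weights such as mlp.fc1 and the attention Wqkv are split along the output (first) dimension, row-parallel weights such as mixer.out_proj and mlp.fc2 along the input (last) dimension, and only rank 0 keeps the row-parallel biases. A rough, self-contained sketch of those two splits on plain tensors (illustrative helper names, not the file's exact code):

import torch

def split_first_dim(w, world_size, rank):
    # Column-parallel (e.g. mlp.fc1.weight): each rank owns a block of output rows.
    dim = w.shape[0] // world_size
    return w[rank * dim:(rank + 1) * dim]

def split_last_dim(w, world_size, rank):
    # Row-parallel (e.g. mlp.fc2.weight): each rank owns a block of input columns.
    dim = w.shape[-1] // world_size
    return w[..., rank * dim:(rank + 1) * dim]

fc2 = torch.randn(16, 64)  # toy (d_model, d_inner) weight
shards = [split_last_dim(fc2, 2, r) for r in range(2)]
assert torch.equal(torch.cat(shards, dim=-1), fc2)  # shards concatenate back to the full weight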
tests/models/test_falcon.py

@@ -300,8 +300,6 @@ def test_falcon_parallel_generation(model_name, world_size):
     input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
                               device=device)
     torch.distributed.barrier()
     # Need this, otherwise when we capture the graph the process for GPU 1 would run on both
     # GPU0 and GPU1 and things would hang
     torch.cuda.set_device(device)
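The torch.cuda.set_device call kept here is what pins each test process to its own GPU before CUDA graph capture; without it every process defaults to cuda:0, so rank 1's capture and NCCL collectives also touch GPU 0 and the run hangs, as the comment notes. A minimal sketch of the usual per-rank setup under torchrun (standard PyTorch calls, not code from this repository):

import torch

def bind_local_device():
    # Run under torchrun so MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE are already set.
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend="nccl", init_method="env://")
    device = f"cuda:{torch.distributed.get_rank()}"
    # Make this GPU the current device *before* allocating tensors or capturing graphs.
    torch.cuda.set_device(device)
    return device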
tests/models/test_llama.py

@@ -13,12 +13,15 @@ current_dir = Path(__file__).parent.absolute()
 import torch
 import pytest
 from einops import rearrange
 from transformers import LlamaConfig, LlamaTokenizer
 from transformers.models.llama.modeling_llama import LlamaForCausalLM
-from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp
+from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
 from flash_attn.models.llama import remap_state_dict_meta_llama, llama_config_to_gpt2_config
 from flash_attn.models.llama import config_from_checkpoint, state_dicts_from_checkpoint
 from flash_attn.utils.distributed import all_gather_raw
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from flash_attn.utils.generation import update_graph_cache
@@ -38,6 +41,7 @@ def test_llama_state_dict(model_name):
 @pytest.mark.parametrize('model_name', ["7B", "13B"])
 # @pytest.mark.parametrize('model_name', ["7B"])
 def test_llama_optimized(model_name):
     """Check that our implementation of LLaMa (with all optimizations enabled) matches the
     HF implementation: the output of our forward pass in fp16 should be around the same as the HF
@@ -59,7 +63,7 @@ def test_llama_optimized(model_name):
     pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
     pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
-    model.load_state_dict(pretrained_state_dict, strict=False)
+    model.load_state_dict(pretrained_state_dict)
     model.eval()
     torch.manual_seed(0)
@@ -86,8 +90,9 @@ def test_llama_optimized(model_name):
     model_hf = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
                                                 torch_dtype=dtype, device_map={"": device})
     model_hf.eval()
-    out_hf = model_hf.model(input_ids).last_hidden_state
-    logits_hf = model_hf(input_ids).logits
+    with torch.no_grad():
+        out_hf = model_hf.model(input_ids).last_hidden_state
+        logits_hf = model_hf(input_ids).logits
     del model_hf
     print(f'Output max diff: {(out - out_ref).abs().max().item()}')
@@ -104,7 +109,6 @@ def test_llama_optimized(model_name):
 # torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "parallel"
-@pytest.mark.skip(reason="Tensor Parallel is not implemented for GatedMLP yet")
 @pytest.mark.parametrize('world_size', [2])
 @pytest.mark.parametrize('model_name', ["13B"])
 def test_llama_parallel(model_name, world_size):
@@ -118,7 +122,6 @@ def test_llama_parallel(model_name, world_size):
                                           current_dir.parent.parent / 'checkpoints')) / 'llama'
     dtype = torch.float16
     device = 'cuda'
     config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
     config.use_flash_attn = True
     config.fused_bias_fc = True
@@ -139,8 +142,7 @@ def test_llama_parallel(model_name, world_size):
     pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
     model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
-    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank),
-                          strict=False)
+    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
     model.eval()
     torch.manual_seed(0)
@@ -151,39 +153,49 @@ def test_llama_parallel(model_name, world_size):
                               device=device)
     with torch.no_grad():
         out = model.transformer(input_ids)
         out, _ = all_gather_raw(out, process_group=process_group)
         out = rearrange(out, "(b s) d -> b s d", b=batch_size)
         logits = model(input_ids).logits
         logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
         logits, _ = all_gather_raw(logits, process_group)
         logits = rearrange(logits, '(n b) ... d -> b ... (n d)', b=batch_size)
     del model
 
-    # Without device_map, the model is loaded on the CPU, which is very slow
-    model_ref = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
-                                                 device_map='auto')
-    model_ref.eval()
-    with torch.no_grad():
-        out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
-        logits_ref = model_ref(input_ids).logits.to(device=device)
-    del model_ref
-
-    model_hf = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
-                                                torch_dtype=dtype, device_map="auto")
-    model_hf.eval()
-    out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
-    logits_hf = model_hf(input_ids).logits.to(device=device)
-    del model_hf
-
-    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
-    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
-    print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
-    print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
-    assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
-    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
-    print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
-    print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
-    print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
-    assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()
-
-
-@pytest.mark.parametrize('model_name', ["7B", "13B"])
+    if rank == 0:
+        # Without device_map, the model is loaded on the CPU, which is very slow
+        model_ref = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
+                                                     device_map="auto")
+        model_ref.eval()
+        with torch.no_grad():
+            out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
+            logits_ref = model_ref(input_ids).logits.to(device=device)
+        del model_ref
+
+        model_hf = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
+                                                    torch_dtype=dtype, device_map="auto")
+        model_hf.eval()
+        with torch.no_grad():
+            out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
+            logits_hf = model_hf(input_ids).logits.to(device=device)
+        del model_hf
+
+        print(f'Output max diff: {(out - out_ref).abs().max().item()}')
+        print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
+        print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
+        print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
+        assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
+        print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
+        print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
+        print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
+        print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
+        assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()
+
+
+# @pytest.mark.parametrize('model_name', ["7B", "13B"])
 @pytest.mark.parametrize('model_name', ["7B"])
 def test_llama_generation(model_name):
     checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
                                           current_dir.parent.parent / 'checkpoints')) / 'llama'
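In the test_llama_parallel hunk above, each rank computes logits only for its slice of the (padded) vocabulary, so all_gather_raw stacks the rank-local results along the leading dimension and the '(n b) ... d -> b ... (n d)' rearrange concatenates the vocab shards back into full logits. A shape-only sketch of that reassembly with fake tensors (the all-gather is simulated with torch.cat along dim 0, matching the '(n b)' pattern used in the test):

import torch
from einops import rearrange

world_size, batch_size, seqlen, vocab_shard = 2, 1, 100, 16000
# Each rank holds logits over its own vocab shard: (batch, seqlen, vocab_shard)
per_rank = [torch.randn(batch_size, seqlen, vocab_shard) for _ in range(world_size)]

# Gathering stacks the shards along the first dim: (world_size * batch, seqlen, vocab_shard)
gathered = torch.cat(per_rank, dim=0)

# Fold the gathered dim back into the vocab dim: (batch, seqlen, world_size * vocab_shard)
logits = rearrange(gathered, "(n b) ... d -> b ... (n d)", b=batch_size)
print(logits.shape)  # torch.Size([1, 100, 32000])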
@@ -231,7 +243,7 @@ def test_llama_generation(model_name):
     pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
     pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)
     model = GPTLMHeadModel(config, device=device, dtype=dtype)
-    model.load_state_dict(pretrained_state_dict, strict=False)
+    model.load_state_dict(pretrained_state_dict)
     model.eval()
     print('Without CUDA graph')
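Dropping strict=False here makes the test stricter: load_state_dict now fails if the remapped checkpoint is missing any parameter or carries unexpected keys, instead of silently skipping them. For reference, this is standard torch.nn.Module behavior (toy example, not repo code):

import torch.nn as nn

model = nn.Linear(4, 4)
state_dict = {"weight": model.weight.detach().clone()}  # deliberately missing "bias"

# Non-strict loading just reports what was skipped.
result = model.load_state_dict(state_dict, strict=False)
print(result.missing_keys)  # ['bias']

# Strict loading (the default) raises on the same mismatch.
try:
    model.load_state_dict(state_dict)
except RuntimeError as e:
    print("strict load failed:", type(e).__name__)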
@@ -274,3 +286,113 @@ def test_llama_generation(model_name):
    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
    assert torch.equal(logits_cg, logits)


# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_llama.py -k "llama_parallel_generation"
@pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('model_name', ["13B"])
def test_llama_parallel_generation(model_name, world_size):
    """Check that our implementation matches the HF implementation:
    the scores in fp16 should be around the same as the HF scores in fp16, when compared to
    the HF scores in fp32.
    """
    from apex.transformer import parallel_state

    checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR',
                                          current_dir.parent.parent / 'checkpoints')) / 'llama'

    dtype = torch.float16
    config = llama_config_to_gpt2_config(config_from_checkpoint(checkpoint_path, model_name))
    config.use_flash_attn = False
    config.fused_bias_fc = True
    config.fused_mlp = False  # We don't have fused GatedMLP yet
    config.fused_dropout_add_ln = False
    config.residual_in_fp32 = True
    config.pad_vocab_size_multiple = 8 * world_size
    config.sequence_parallel = False  # Need to set this to False for generation

    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0"
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    device = f'cuda:{torch.distributed.get_rank()}'
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    process_group = parallel_state.get_tensor_model_parallel_group()

    torch.manual_seed(0)
    batch_size = 1
    seqlen = 100
    max_length = 150
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seqlen), dtype=torch.long,
                              device=device)
    # Need this, otherwise when we capture the graph the process for GPU 1 would run on both
    # GPU0 and GPU1 and things would hang
    torch.cuda.set_device(device)

    ckpt_state_dicts = state_dicts_from_checkpoint(checkpoint_path, model_name)
    pretrained_state_dicts = [remap_state_dict_meta_llama(s, config) for s in ckpt_state_dicts]
    pretrained_state_dict = combine_state_dicts_tp(pretrained_state_dicts, config)

    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
    model.eval()

    print('Without CUDA graph')
    out = model.generate(input_ids=input_ids, max_length=max_length,
                         tensor_parallel=world_size, vocab_size=config.vocab_size,
                         fused_ft_kernel=True,
                         # teacher_outputs=out_hf.sequences,
                         return_dict_in_generate=True, output_scores=True, timing=True)

    # Capture graph outside the timing loop
    batch_size, seqlen_og = input_ids.shape
    model._decoding_cache = update_graph_cache(model, None, batch_size, seqlen_og, max_length)
    print('With CUDA graph')
    out_cg = model.generate(input_ids=input_ids, max_length=max_length,
                            tensor_parallel=world_size, vocab_size=config.vocab_size,
                            fused_ft_kernel=True, cg=True,
                            # teacher_outputs=out_hf.sequences,
                            return_dict_in_generate=True, output_scores=True, timing=True)
    del model
    parallel_state.destroy_model_parallel()

    if rank == 0:
        # Without device_map, the model is loaded on the CPU, which is very slow
        model_hf = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
                                                    torch_dtype=dtype, device_map="auto")
        model_hf.eval()
        print("HF fp16")
        torch.cuda.synchronize()
        start = time.time()
        with torch.inference_mode():
            out_hf = model_hf.generate(input_ids=input_ids, max_length=max_length,
                                       return_dict_in_generate=True, output_scores=True)
        torch.cuda.synchronize()
        print(f'Prompt processing + decoding time: {(time.time() - start) * 1000:.0f}ms')
        del model_hf

        model_ref = LlamaForCausalLM.from_pretrained(Path(checkpoint_path) / f'{model_name}-hf',
                                                     device_map="auto")
        model_ref.eval()
        with torch.inference_mode():
            logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1]
        del model_ref

        logits_hf = torch.stack(out_hf.scores, dim=1)
        logits = torch.stack(out.scores, dim=1)
        logits_cg = torch.stack(out_cg.scores, dim=1)
        hf_error = (logits_hf - logits_ref).abs().max().item()
        print(f'HF fp16 logits max diff: {hf_error}')
        print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
        assert (logits - logits_ref).abs().max().item() < 2 * hf_error
        print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}')
        assert torch.equal(logits_cg, logits)
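The score comparison at the end lines up three sets of logits: out.scores and out_cg.scores from this repo's generate (one tensor per generated position, stacked along dim=1) and the HF reference logits, sliced with [:, (seqlen - 1):-1] because the logits at position t predict token t + 1, so positions seqlen-1 through max_length-2 are the ones that scored the generated tokens. A shape-only sketch of that alignment with dummy tensors and the same sizes as the test:

import torch

batch_size, seqlen, max_length, vocab_size = 1, 100, 150, 32000
num_generated = max_length - seqlen  # 50 generated tokens

# generate() returns one score tensor of shape (batch, vocab) per generated token
scores = [torch.randn(batch_size, vocab_size) for _ in range(num_generated)]
logits = torch.stack(scores, dim=1)  # (batch, num_generated, vocab)

# Reference: run the full generated sequence through the HF model once
full_logits = torch.randn(batch_size, max_length, vocab_size)
logits_ref = full_logits[:, (seqlen - 1):-1]  # logits that predicted tokens seqlen .. max_length-1
assert logits_ref.shape == logits.shape       # both (1, 50, 32000)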