Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
a8c35b4f
"src/lib/vscode:/vscode.git/clone" did not exist on "3f53abb2335170614858c7540c51c0e620b666e0"
Unverified
Commit
a8c35b4f
authored
Aug 22, 2023
by
GAOXinyu
Committed by
GitHub
Aug 21, 2023
Browse files
FEAT: add codes which supporting for baichuan-inc/Baichuan-7B (#425)
parent
25d6b1db
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
669 additions
and
0 deletions
+669
-0
flash_attn/models/baichuan.py
flash_attn/models/baichuan.py
+161
-0
tests/models/test_baichuan.py
tests/models/test_baichuan.py
+508
-0
No files found.
flash_attn/models/baichuan.py
0 → 100644
View file @
a8c35b4f
# Copyright (c) 2023, GGGGGGXY.
import
math
import
json
import
re
from
pathlib
import
Path
from
collections
import
OrderedDict
import
torch
import
torch.nn.functional
as
F
from
einops
import
rearrange
from
transformers
import
GPT2Config
,
AutoConfig
,
PretrainedConfig
# only support Baichuan-7B now
def
remap_state_dict_hf_baichuan
(
state_dict
,
config
):
def
key_mapping_layers
(
key
):
return
re
.
sub
(
r
"^model."
,
"transformer."
,
key
)
state_dict
=
OrderedDict
((
key_mapping_layers
(
k
),
v
)
for
k
,
v
in
state_dict
.
items
())
# Word embedding
def
key_mapping_emb
(
key
):
return
re
.
sub
(
r
"^transformer.embed_tokens."
,
"transformer.embeddings.word_embeddings."
,
key
,
)
state_dict
=
OrderedDict
((
key_mapping_emb
(
k
),
v
)
for
k
,
v
in
state_dict
.
items
())
word_embeddings
=
state_dict
.
pop
(
"transformer.embeddings.word_embeddings.weight"
)
# It's possible that vocab_size is padded to be a multiple of 8, for example.
pad_vocab_size_multiple
=
getattr
(
config
,
"pad_vocab_size_multiple"
,
1
)
vocab_size
=
(
math
.
ceil
(
word_embeddings
.
shape
[
0
]
/
pad_vocab_size_multiple
)
*
pad_vocab_size_multiple
)
state_dict
[
"transformer.embeddings.word_embeddings.weight"
]
=
F
.
pad
(
word_embeddings
,
(
0
,
0
,
0
,
vocab_size
-
word_embeddings
.
shape
[
0
])
)
if
getattr
(
config
,
"tie_word_embeddings"
):
state_dict
[
"lm_head.weight"
]
=
state_dict
[
"transformer.embeddings.word_embeddings.weight"
]
else
:
output_embeddings
=
state_dict
.
pop
(
"lm_head.weight"
)
# Need to recompute vocab_size since Baichuan shards the word embeddings and output embeddings
# differently.
vocab_size
=
(
math
.
ceil
(
output_embeddings
.
shape
[
0
]
/
pad_vocab_size_multiple
)
*
pad_vocab_size_multiple
)
# It's possible that vocab_size is padded to be a multiple of 8, for example.
state_dict
[
"lm_head.weight"
]
=
F
.
pad
(
output_embeddings
,
(
0
,
0
,
0
,
vocab_size
-
output_embeddings
.
shape
[
0
])
)
# LayerNorm
def
key_mapping_ln
(
key
):
key
=
re
.
sub
(
r
"^transformer.norm."
,
r
"transformer.ln_f."
,
key
)
key
=
re
.
sub
(
r
"^transformer.layers.(\d+).input_layernorm."
,
r
"transformer.layers.\1.norm1."
,
key
,
)
key
=
re
.
sub
(
r
"^transformer.layers.(\d+).post_attention_layernorm."
,
r
"transformer.layers.\1.norm2."
,
key
,
)
return
key
state_dict
=
OrderedDict
((
key_mapping_ln
(
k
),
v
)
for
k
,
v
in
state_dict
.
items
())
# MLP
for
l
in
range
(
config
.
n_layer
):
w1
=
state_dict
.
pop
(
f
"transformer.layers.
{
l
}
.mlp.gate_proj.weight"
)
w3
=
state_dict
.
pop
(
f
"transformer.layers.
{
l
}
.mlp.up_proj.weight"
)
# Our ordering is different
state_dict
[
f
"transformer.layers.
{
l
}
.mlp.fc1.weight"
]
=
torch
.
cat
(
[
w3
,
w1
],
dim
=
0
)
def
key_mapping_mlp
(
key
):
return
re
.
sub
(
r
"^transformer.layers.(\d+).mlp.down_proj."
,
r
"transformer.layers.\1.mlp.fc2."
,
key
,
)
state_dict
=
OrderedDict
((
key_mapping_mlp
(
k
),
v
)
for
k
,
v
in
state_dict
.
items
())
# Attention
def
key_mapping_attn
(
key
):
key
=
re
.
sub
(
r
"^transformer.layers.(\d+).self_attn.W_pack."
,
r
"transformer.layers.\1.mixer.Wqkv."
,
key
,
)
key
=
re
.
sub
(
r
"^transformer.layers.(\d+).self_attn.o_proj."
,
r
"transformer.layers.\1.mixer.out_proj."
,
key
,
)
return
key
state_dict
=
OrderedDict
((
key_mapping_attn
(
k
),
v
)
for
k
,
v
in
state_dict
.
items
())
for
l
in
range
(
config
.
n_layer
):
# pop rotary_emb.inv_freq from state dict
state_dict
.
pop
(
f
"transformer.layers.
{
l
}
.self_attn.rotary_emb.inv_freq"
)
return
state_dict
def
config_from_checkpoint
(
checkpoint_path
:
str
,
model_name
:
str
)
->
PretrainedConfig
:
"""Load a BaiChuanConfig from a checkpoint path."""
config
=
AutoConfig
.
from_pretrained
(
Path
(
checkpoint_path
)
/
model_name
,
trust_remote_code
=
True
)
return
config
def
state_dicts_from_checkpoint
(
checkpoint_path
:
str
,
model_name
:
str
)
->
dict
:
# Need to sort, otherwise we mess up the ordering and the weights are wrong
return
[
torch
.
load
(
path
,
map_location
=
"cpu"
)
for
path
in
sorted
(
(
Path
(
checkpoint_path
)
/
model_name
).
glob
(
"pytorch_model*.bin"
)
)
]
def
baichuan_config_to_gpt2_config
(
baichuan_config
:
PretrainedConfig
)
->
GPT2Config
:
return
GPT2Config
(
vocab_size
=
baichuan_config
.
vocab_size
,
n_positions
=
0
,
# No absolute position embedding
n_embd
=
baichuan_config
.
hidden_size
,
n_layer
=
baichuan_config
.
num_hidden_layers
,
n_head
=
baichuan_config
.
num_attention_heads
,
n_inner
=
baichuan_config
.
intermediate_size
,
activation_function
=
"swiglu"
,
# Hardcode since HF calls it 'silu'
# baichuan doesn't have dropout, idk if it's because they only release the inference code
resid_pdrop
=
0.0
,
embd_pdrop
=
0.0
,
attn_pdrop
=
0.0
,
layer_norm_epsilon
=
baichuan_config
.
rms_norm_eps
,
initializer_range
=
baichuan_config
.
initializer_range
,
bos_token_id
=
baichuan_config
.
bos_token_id
,
eos_token_id
=
baichuan_config
.
eos_token_id
,
# These are new arguments not in the original GPT2Config
pad_token_id
=
baichuan_config
.
pad_token_id
,
# Idk if this does anything
rms_norm
=
True
,
rotary_emb_fraction
=
1.0
,
rotary_emb_interleaved
=
False
,
tie_word_embeddings
=
False
,
qkv_proj_bias
=
False
,
out_proj_bias
=
False
,
mlp_fc1_bias
=
False
,
mlp_fc2_bias
=
False
,
)
tests/models/test_baichuan.py
0 → 100644
View file @
a8c35b4f
import
os
import
time
from
pathlib
import
Path
current_dir
=
Path
(
__file__
).
parent
.
absolute
()
import
torch
import
pytest
from
einops
import
rearrange
from
transformers
import
AutoConfig
,
AutoTokenizer
,
AutoModelForCausalLM
from
flash_attn.models.gpt
import
(
GPTLMHeadModel
,
combine_state_dicts_tp
,
shard_state_dict_tp
,
)
from
flash_attn.models.baichuan
import
(
remap_state_dict_hf_baichuan
,
baichuan_config_to_gpt2_config
,
)
from
flash_attn.models.baichuan
import
(
config_from_checkpoint
,
state_dicts_from_checkpoint
,
)
from
flash_attn.utils.distributed
import
all_gather_raw
from
flash_attn.utils.pretrained
import
state_dict_from_pretrained
from
flash_attn.utils.generation
import
update_graph_cache
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"Baichuan-7B"
])
def
test_baichuan_state_dict
(
model_name
):
checkpoint_path
=
Path
(
os
.
environ
.
get
(
"CHECKPOINT_DIR"
,
current_dir
.
parent
.
parent
/
"checkpoints"
)
)
config
=
baichuan_config_to_gpt2_config
(
config_from_checkpoint
(
checkpoint_path
,
model_name
)
)
ckpt_state_dicts
=
state_dicts_from_checkpoint
(
checkpoint_path
,
model_name
)
pretrained_state_dict
=
remap_state_dict_hf_baichuan
(
ckpt_state_dicts
[
0
],
config
)
model
=
GPTLMHeadModel
(
config
,
device
=
"meta"
)
# Without device='meta' init is very slow
state_dict
=
model
.
state_dict
()
assert
len
(
state_dict
.
keys
())
==
len
(
pretrained_state_dict
.
keys
())
assert
state_dict
.
keys
()
==
pretrained_state_dict
.
keys
()
for
k
in
state_dict
.
keys
():
assert
state_dict
[
k
].
shape
==
pretrained_state_dict
[
k
].
shape
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"Baichuan-7B"
])
def
test_baichuan_optimized
(
model_name
):
"""Check that our implementation of Baichuan (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
checkpoint_path
=
Path
(
os
.
environ
.
get
(
"CHECKPOINT_DIR"
,
current_dir
.
parent
.
parent
/
"checkpoints"
)
)
dtype
=
torch
.
float16
device
=
"cuda"
config
=
baichuan_config_to_gpt2_config
(
config_from_checkpoint
(
checkpoint_path
,
model_name
)
)
config
.
use_flash_attn
=
True
config
.
fused_bias_fc
=
True
config
.
fused_mlp
=
False
# We don't have fused GatedMLP yet
config
.
fused_dropout_add_ln
=
True
config
.
residual_in_fp32
=
True
ckpt_state_dicts
=
state_dicts_from_checkpoint
(
checkpoint_path
,
model_name
)
pretrained_state_dicts
=
[
remap_state_dict_hf_baichuan
(
s
,
config
)
for
s
in
ckpt_state_dicts
]
pretrained_state_dict
=
combine_state_dicts_tp
(
pretrained_state_dicts
,
config
)
model
=
GPTLMHeadModel
(
config
,
device
=
device
,
dtype
=
dtype
)
model
.
load_state_dict
(
pretrained_state_dict
)
model
.
eval
()
torch
.
manual_seed
(
0
)
batch_size
=
2
max_seqlen
=
256
seqlens
=
torch
.
randint
(
max_seqlen
//
2
,
max_seqlen
+
1
,
(
batch_size
,),
device
=
device
)
input_ids
=
torch
.
randint
(
0
,
config
.
vocab_size
,
(
batch_size
,
max_seqlen
),
dtype
=
torch
.
long
,
device
=
device
)
with
torch
.
no_grad
():
out
=
model
.
transformer
(
input_ids
)
logits
=
model
(
input_ids
).
logits
del
model
# Without device_map, the model is loaded on the CPU, which is very slow
# Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
model_ref
=
AutoModelForCausalLM
.
from_pretrained
(
Path
(
checkpoint_path
)
/
model_name
,
device_map
=
"auto"
,
trust_remote_code
=
True
)
model_ref
.
eval
()
with
torch
.
no_grad
():
out_ref
=
model_ref
.
model
(
input_ids
).
last_hidden_state
.
to
(
device
=
device
)
logits_ref
=
model_ref
(
input_ids
).
logits
.
to
(
device
=
device
)
del
model_ref
model_hf
=
AutoModelForCausalLM
.
from_pretrained
(
Path
(
checkpoint_path
)
/
model_name
,
torch_dtype
=
dtype
,
device_map
=
{
""
:
device
},
trust_remote_code
=
True
,
)
model_hf
.
eval
()
with
torch
.
no_grad
():
out_hf
=
model_hf
.
model
(
input_ids
).
last_hidden_state
logits_hf
=
model_hf
(
input_ids
).
logits
del
model_hf
print
(
f
"Output max diff:
{
(
out
-
out_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"Output mean diff:
{
(
out
-
out_ref
).
abs
().
mean
().
item
()
}
"
)
print
(
f
"HF fp16 max diff:
{
(
out_hf
-
out_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"HF fp16 mean diff:
{
(
out_hf
-
out_ref
).
abs
().
mean
().
item
()
}
"
)
assert
(
out
-
out_ref
).
abs
().
max
().
item
()
<
3
*
(
out_hf
-
out_ref
).
abs
().
max
().
item
()
print
(
f
"Logits max diff:
{
(
logits
-
logits_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"Logits mean diff:
{
(
logits
-
logits_ref
).
abs
().
mean
().
item
()
}
"
)
print
(
f
"HF fp16 max diff:
{
(
logits_hf
-
logits_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"HF fp16 mean diff:
{
(
logits_hf
-
logits_ref
).
abs
().
mean
().
item
()
}
"
)
assert
(
logits
-
logits_ref
).
abs
().
max
().
item
()
<
3
*
(
logits_hf
-
logits_ref
).
abs
().
max
().
item
()
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "test_baichuan_parallel"
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"Baichuan-7B"
])
def
test_baichuan_parallel
(
model_name
,
world_size
):
"""Check that our implementation of Baichuan (with all optimizations enabled) matches the
HF implementation: the output of our forward pass in fp16 should be around the same as the HF
forward pass in fp16, when compared to the HF forward pass in fp32.
"""
from
apex.transformer
import
parallel_state
checkpoint_path
=
Path
(
os
.
environ
.
get
(
"CHECKPOINT_DIR"
,
current_dir
.
parent
.
parent
/
"checkpoints"
)
)
dtype
=
torch
.
float16
config
=
baichuan_config_to_gpt2_config
(
config_from_checkpoint
(
checkpoint_path
,
model_name
)
)
config
.
use_flash_attn
=
True
config
.
fused_bias_fc
=
True
config
.
fused_mlp
=
False
# We don't have fused GatedMLP yet
config
.
fused_dropout_add_ln
=
True
config
.
residual_in_fp32
=
True
if
not
torch
.
distributed
.
is_initialized
():
torch
.
distributed
.
init_process_group
(
backend
=
"nccl"
,
init_method
=
"env://"
)
device
=
f
"cuda:
{
torch
.
distributed
.
get_rank
()
}
"
assert
world_size
<=
torch
.
distributed
.
get_world_size
()
parallel_state
.
initialize_model_parallel
(
tensor_model_parallel_size_
=
world_size
)
rank
=
parallel_state
.
get_tensor_model_parallel_rank
()
process_group
=
parallel_state
.
get_tensor_model_parallel_group
()
ckpt_state_dicts
=
state_dicts_from_checkpoint
(
checkpoint_path
,
model_name
)
pretrained_state_dicts
=
[
remap_state_dict_hf_baichuan
(
s
,
config
)
for
s
in
ckpt_state_dicts
]
pretrained_state_dict
=
combine_state_dicts_tp
(
pretrained_state_dicts
,
config
)
model
=
GPTLMHeadModel
(
config
,
process_group
=
process_group
,
device
=
device
,
dtype
=
dtype
)
model
.
load_state_dict
(
shard_state_dict_tp
(
pretrained_state_dict
,
config
,
world_size
,
rank
)
)
model
.
eval
()
torch
.
manual_seed
(
0
)
batch_size
=
2
max_seqlen
=
256
seqlens
=
torch
.
randint
(
max_seqlen
//
2
,
max_seqlen
+
1
,
(
batch_size
,),
device
=
device
)
input_ids
=
torch
.
randint
(
0
,
config
.
vocab_size
,
(
batch_size
,
max_seqlen
),
dtype
=
torch
.
long
,
device
=
device
)
with
torch
.
no_grad
():
out
=
model
.
transformer
(
input_ids
)
out
,
_
=
all_gather_raw
(
out
,
process_group
=
process_group
)
out
=
rearrange
(
out
,
"(b s) d -> b s d"
,
b
=
batch_size
)
logits
=
model
(
input_ids
).
logits
logits
=
rearrange
(
logits
,
"(b s) d -> b s d"
,
b
=
batch_size
)
logits
,
_
=
all_gather_raw
(
logits
,
process_group
)
logits
=
rearrange
(
logits
,
"(n b) ... d -> b ... (n d)"
,
b
=
batch_size
)
del
model
if
rank
==
0
:
# Without device_map, the model is loaded on the CPU, which is very slow
model_ref
=
AutoModelForCausalLM
.
from_pretrained
(
Path
(
checkpoint_path
)
/
model_name
,
device_map
=
"auto"
,
trust_remote_code
=
True
,
)
model_ref
.
eval
()
with
torch
.
no_grad
():
out_ref
=
model_ref
.
model
(
input_ids
).
last_hidden_state
.
to
(
device
=
device
)
logits_ref
=
model_ref
(
input_ids
).
logits
.
to
(
device
=
device
)
del
model_ref
model_hf
=
AutoModelForCausalLM
.
from_pretrained
(
Path
(
checkpoint_path
)
/
model_name
,
torch_dtype
=
dtype
,
device_map
=
"auto"
,
trust_remote_code
=
True
,
)
model_hf
.
eval
()
with
torch
.
no_grad
():
out_hf
=
model_hf
.
model
(
input_ids
).
last_hidden_state
.
to
(
device
=
device
)
logits_hf
=
model_hf
(
input_ids
).
logits
.
to
(
device
=
device
)
del
model_hf
print
(
f
"Output max diff:
{
(
out
-
out_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"Output mean diff:
{
(
out
-
out_ref
).
abs
().
mean
().
item
()
}
"
)
print
(
f
"HF fp16 max diff:
{
(
out_hf
-
out_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"HF fp16 mean diff:
{
(
out_hf
-
out_ref
).
abs
().
mean
().
item
()
}
"
)
assert
(
out
-
out_ref
).
abs
().
max
().
item
()
<
2
*
(
out_hf
-
out_ref
).
abs
().
max
().
item
()
print
(
f
"Logits max diff:
{
(
logits
-
logits_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"Logits mean diff:
{
(
logits
-
logits_ref
).
abs
().
mean
().
item
()
}
"
)
print
(
f
"HF fp16 max diff:
{
(
logits_hf
-
logits_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"HF fp16 mean diff:
{
(
logits_hf
-
logits_ref
).
abs
().
mean
().
item
()
}
"
)
assert
(
logits
-
logits_ref
).
abs
().
max
().
item
()
<
2
*
(
logits_hf
-
logits_ref
).
abs
().
max
().
item
()
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"Baichuan-7B"
])
def
test_baichuan_generation
(
model_name
):
checkpoint_path
=
Path
(
os
.
environ
.
get
(
"CHECKPOINT_DIR"
,
current_dir
.
parent
.
parent
/
"checkpoints"
)
)
dtype
=
torch
.
float16
device
=
"cuda"
config
=
baichuan_config_to_gpt2_config
(
config_from_checkpoint
(
checkpoint_path
,
model_name
)
)
config
.
use_flash_attn
=
True
config
.
fused_bias_fc
=
True
config
.
fused_mlp
=
False
# We don't have fused GatedMLP yet
config
.
fused_dropout_add_ln
=
True
config
.
residual_in_fp32
=
True
tokenizer
=
AutoTokenizer
.
from_pretrained
(
Path
(
checkpoint_path
)
/
model_name
,
trust_remote_code
=
True
)
eos_token_id
=
tokenizer
.
eos_token_id
torch
.
manual_seed
(
0
)
batch_size
=
1
seqlen
=
100
max_length
=
150
input_ids
=
torch
.
randint
(
0
,
config
.
vocab_size
,
(
batch_size
,
seqlen
),
dtype
=
torch
.
long
,
device
=
device
)
model_hf
=
AutoModelForCausalLM
.
from_pretrained
(
Path
(
checkpoint_path
)
/
model_name
,
torch_dtype
=
dtype
,
device_map
=
{
""
:
device
},
trust_remote_code
=
True
,
)
model_hf
.
eval
()
print
(
"HF fp16"
)
torch
.
cuda
.
synchronize
()
start
=
time
.
time
()
out_hf
=
model_hf
.
generate
(
input_ids
=
input_ids
,
max_length
=
max_length
,
return_dict_in_generate
=
True
,
output_scores
=
True
,
)
torch
.
cuda
.
synchronize
()
print
(
f
"Prompt processing + decoding time:
{
(
time
.
time
()
-
start
)
*
1000
:.
0
f
}
ms"
)
del
model_hf
# Need auto here since the 13B fp32 model doesn't fit in memory on a A100 40GB
model_ref
=
AutoModelForCausalLM
.
from_pretrained
(
Path
(
checkpoint_path
)
/
model_name
,
device_map
=
"auto"
,
trust_remote_code
=
True
)
model_ref
.
eval
()
with
torch
.
no_grad
():
logits_ref
=
(
model_ref
(
out_hf
.
sequences
).
logits
[:,
(
seqlen
-
1
)
:
-
1
].
to
(
device
=
device
)
)
del
model_ref
ckpt_state_dicts
=
state_dicts_from_checkpoint
(
checkpoint_path
,
model_name
)
pretrained_state_dicts
=
[
remap_state_dict_hf_baichuan
(
s
,
config
)
for
s
in
ckpt_state_dicts
]
pretrained_state_dict
=
combine_state_dicts_tp
(
pretrained_state_dicts
,
config
)
model
=
GPTLMHeadModel
(
config
,
device
=
device
,
dtype
=
dtype
)
model
.
load_state_dict
(
pretrained_state_dict
)
model
.
eval
()
print
(
"Without CUDA graph"
)
torch
.
cuda
.
synchronize
()
start
=
time
.
time
()
out
=
model
.
generate
(
input_ids
=
input_ids
,
max_length
=
max_length
,
eos_token_id
=
eos_token_id
,
fused_ft_kernel
=
True
,
return_dict_in_generate
=
True
,
output_scores
=
True
,
timing
=
True
,
teacher_outputs
=
out_hf
.
sequences
,
)
torch
.
cuda
.
synchronize
()
print
(
f
"Prompt processing + decoding time:
{
(
time
.
time
()
-
start
)
*
1000
:.
0
f
}
ms"
)
# Capture graph outside the timing loop
batch_size
,
seqlen_og
=
input_ids
.
shape
model
.
_decoding_cache
=
update_graph_cache
(
model
,
None
,
batch_size
,
seqlen_og
,
max_length
)
print
(
"With CUDA graph"
)
torch
.
cuda
.
synchronize
()
start
=
time
.
time
()
out_cg
=
model
.
generate
(
input_ids
=
input_ids
,
max_length
=
max_length
,
fused_ft_kernel
=
True
,
cg
=
True
,
return_dict_in_generate
=
True
,
output_scores
=
True
,
timing
=
True
,
teacher_outputs
=
out_hf
.
sequences
,
)
torch
.
cuda
.
synchronize
()
print
(
f
"Prompt processing + decoding time:
{
(
time
.
time
()
-
start
)
*
1000
:.
0
f
}
ms"
)
with
torch
.
no_grad
():
logits_parallel
=
model
(
out_hf
.
sequences
).
logits
[:,
(
seqlen
-
1
)
:
-
1
]
logits_hf
=
torch
.
stack
(
out_hf
.
scores
,
dim
=
1
)
logits
=
torch
.
stack
(
out
.
scores
,
dim
=
1
)
logits_cg
=
torch
.
stack
(
out_cg
.
scores
,
dim
=
1
)
del
model
hf_error
=
(
logits_hf
-
logits_ref
).
abs
().
max
().
item
()
print
(
f
"HF fp16 logits max diff:
{
hf_error
}
"
)
print
(
f
"Logits max diff:
{
(
logits
-
logits_ref
).
abs
().
max
().
item
()
}
"
)
print
(
f
"Logits CG max diff:
{
(
logits_cg
-
logits_ref
).
abs
().
max
().
item
()
}
"
)
assert
(
logits_parallel
-
logits_ref
).
abs
().
max
().
item
()
<
2
*
hf_error
assert
(
logits
-
logits_ref
).
abs
().
max
().
item
()
<
2
*
hf_error
assert
torch
.
equal
(
logits_cg
,
logits
)
# torchrun --no_python --nproc_per_node=2 pytest -q -s tests/models/test_baichuan.py -k "baichuan_parallel_generation"
@
pytest
.
mark
.
parametrize
(
"world_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"Baichuan-7B"
])
def
test_baichuan_parallel_generation
(
model_name
,
world_size
):
"""Check that our implementation matches the HF implementation:
the scores in fp16 should be around the same as the HF scores in fp16, when compared to
the HF scores in fp32.
"""
from
apex.transformer
import
parallel_state
checkpoint_path
=
Path
(
os
.
environ
.
get
(
"CHECKPOINT_DIR"
,
current_dir
.
parent
.
parent
/
"checkpoints"
)
)
dtype
=
torch
.
float16
config
=
baichuan_config_to_gpt2_config
(
config_from_checkpoint
(
checkpoint_path
,
model_name
)
)
config
.
use_flash_attn
=
False
config
.
fused_bias_fc
=
True
config
.
fused_mlp
=
False
# We don't have fused GatedMLP yet
config
.
fused_dropout_add_ln
=
False
config
.
residual_in_fp32
=
True
config
.
pad_vocab_size_multiple
=
8
*
world_size
config
.
sequence_parallel
=
False
# Need to set this to False for generation
os
.
environ
[
"NCCL_ASYNC_ERROR_HANDLING"
]
=
"0"
if
not
torch
.
distributed
.
is_initialized
():
torch
.
distributed
.
init_process_group
(
backend
=
"nccl"
,
init_method
=
"env://"
)
device
=
f
"cuda:
{
torch
.
distributed
.
get_rank
()
}
"
assert
world_size
<=
torch
.
distributed
.
get_world_size
()
parallel_state
.
initialize_model_parallel
(
tensor_model_parallel_size_
=
world_size
)
rank
=
parallel_state
.
get_tensor_model_parallel_rank
()
process_group
=
parallel_state
.
get_tensor_model_parallel_group
()
torch
.
manual_seed
(
0
)
batch_size
=
1
seqlen
=
100
max_length
=
150
input_ids
=
torch
.
randint
(
0
,
config
.
vocab_size
,
(
batch_size
,
seqlen
),
dtype
=
torch
.
long
,
device
=
device
)
# Need this, otherwise when we capture the graph the process for GPU 1 would run on both
# GPU0 and GPU1 and things would hang
torch
.
cuda
.
set_device
(
device
)
ckpt_state_dicts
=
state_dicts_from_checkpoint
(
checkpoint_path
,
model_name
)
pretrained_state_dicts
=
[
remap_state_dict_hf_baichuan
(
s
,
config
)
for
s
in
ckpt_state_dicts
]
pretrained_state_dict
=
combine_state_dicts_tp
(
pretrained_state_dicts
,
config
)
model
=
GPTLMHeadModel
(
config
,
process_group
=
process_group
,
device
=
device
,
dtype
=
dtype
)
model
.
load_state_dict
(
shard_state_dict_tp
(
pretrained_state_dict
,
config
,
world_size
,
rank
)
)
model
.
eval
()
print
(
"Without CUDA graph"
)
out
=
model
.
generate
(
input_ids
=
input_ids
,
max_length
=
max_length
,
tensor_parallel
=
world_size
,
vocab_size
=
config
.
vocab_size
,
fused_ft_kernel
=
True
,
# teacher_outputs=out_hf.sequences,
return_dict_in_generate
=
True
,
output_scores
=
True
,
timing
=
True
,
)
# Capture graph outside the timing loop
batch_size
,
seqlen_og
=
input_ids
.
shape
model
.
_decoding_cache
=
update_graph_cache
(
model
,
None
,
batch_size
,
seqlen_og
,
max_length
)
print
(
"With CUDA graph"
)
out_cg
=
model
.
generate
(
input_ids
=
input_ids
,
max_length
=
max_length
,
tensor_parallel
=
world_size
,
vocab_size
=
config
.
vocab_size
,
fused_ft_kernel
=
True
,
cg
=
True
,
# teacher_outputs=out_hf.sequences,
return_dict_in_generate
=
True
,
output_scores
=
True
,
timing
=
True
,
)
del
model
parallel_state
.
destroy_model_parallel
()
if
rank
==
0
:
# Without device_map, the model is loaded on the CPU, which is very slow
model_hf
=
AutoModelForCausalLM
.
from_pretrained
(
Path
(
checkpoint_path
)
/
model_name
,
torch_dtype
=
dtype
,
device_map
=
"auto"
,
trust_remote_code
=
True
,
)
model_hf
.
eval
()
print
(
"HF fp16"
)
torch
.
cuda
.
synchronize
()
start
=
time
.
time
()
with
torch
.
inference_mode
():
out_hf
=
model_hf
.
generate
(
input_ids
=
input_ids
,
max_length
=
max_length
,
return_dict_in_generate
=
True
,
output_scores
=
True
,
)
torch
.
cuda
.
synchronize
()
print
(
f
"Prompt processing + decoding time:
{
(
time
.
time
()
-
start
)
*
1000
:.
0
f
}
ms"
)
del
model_hf
model_ref
=
AutoModelForCausalLM
.
from_pretrained
(
Path
(
checkpoint_path
)
/
model_name
,
device_map
=
"auto"
,
trust_remote_code
=
True
,
)
model_ref
.
eval
()
with
torch
.
inference_mode
():
logits_ref
=
model_ref
(
out_hf
.
sequences
).
logits
[:,
(
seqlen
-
1
)
:
-
1
]
del
model_ref
logits_hf
=
torch
.
stack
(
out_hf
.
scores
,
dim
=
1
)
logits
=
torch
.
stack
(
out
.
scores
,
dim
=
1
)
logits_cg
=
torch
.
stack
(
out_cg
.
scores
,
dim
=
1
)
hf_error
=
(
logits_hf
-
logits_ref
).
abs
().
max
().
item
()
print
(
f
"HF fp16 logits max diff:
{
hf_error
}
"
)
print
(
f
"Logits max diff:
{
(
logits
-
logits_ref
).
abs
().
max
().
item
()
}
"
)
assert
(
logits
-
logits_ref
).
abs
().
max
().
item
()
<
2
*
hf_error
print
(
f
"Logits CG max diff:
{
(
logits_cg
-
logits_ref
).
abs
().
max
().
item
()
}
"
)
assert
torch
.
equal
(
logits_cg
,
logits
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment