gaoqiong / flash-attention · Commit 9d797d88

Support loading GPT2 weights from Huggingface

Authored Dec 27, 2022 by Tri Dao
Parent: c6ecd40a

Showing 2 changed files with 239 additions and 6 deletions:
    flash_attn/models/gpt.py    +106 -6
    tests/models/test_gpt.py    +133 -0
flash_attn/models/gpt.py  (view file @ 9d797d88)

 # Copyright (c) 2022, Tri Dao.
 
+import logging
 import math
+import re
 from functools import partial
-from collections import namedtuple
+from collections import namedtuple, OrderedDict
 from collections.abc import Sequence
 
 import torch
@@ -17,6 +19,7 @@ from flash_attn.modules.mlp import Mlp, FusedDenseGeluDense, ParallelFusedDenseG
 from flash_attn.modules.block import Block
 from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings
 from flash_attn.utils.distributed import sync_sequence_parallel_params
+from flash_attn.utils.pretrained import state_dict_from_pretrained
 
 try:
     from flash_attn.ops.fused_dense import ColumnParallelLinear
@@ -34,6 +37,9 @@ except ImportError:
     FusedDenseSqreluDense = None
 
 
+logger = logging.getLogger(__name__)
+
+
 def create_mixer_cls(config, layer_idx=None, process_group=None, device=None, dtype=None):
     factory_kwargs = {'device': device, 'dtype': dtype}
     head_dim = getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads)
@@ -66,13 +72,20 @@ def create_mlp_cls(config, layer_idx=None, process_group=None, device=None, dtyp
     factory_kwargs = {'device': device, 'dtype': dtype}
     inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
     fused_dense_gelu_dense = getattr(config, 'fused_dense_gelu_dense', False)
+    if fused_dense_gelu_dense:
+        assert config.activation_function in ['gelu_new', 'gelu_fast'], ('fused_dense_gelu_dense only '
+                                                                         'supports approximate gelu')
     fused_dense_sqrelu_dense = getattr(config, 'fused_dense_sqrelu_dense', False)
+    if fused_dense_sqrelu_dense:
+        assert config.activation_function == 'sqrelu', ('fused_dense_sqrelu_dense only '
+                                                        'supports approximate activation_function sqrelu')
     assert not (fused_dense_sqrelu_dense and fused_dense_gelu_dense)
     if process_group is not None:
         assert fused_dense_gelu_dense, 'Tensor Parallel is only implemented for FusedDenseGeluDense'
     if not fused_dense_gelu_dense and not fused_dense_sqrelu_dense:
+        approximate = 'tanh' if config.activation_function in ['gelu_new', 'gelu_fast'] else 'none'
         mlp_cls = partial(Mlp, hidden_features=inner_dim,
-                          activation=partial(F.gelu, approximate='tanh'), **factory_kwargs)
+                          activation=partial(F.gelu, approximate=approximate), **factory_kwargs)
     else:
         mlp_checkpoint_lvl = getattr(config, 'mlp_checkpoint_lvl', 0)
         # mlp_checkpoint_lvl could be a list, which contains the checkpoint_lvl for each layer
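These added checks tie the fused MLP options to specific activations: at most one of fused_dense_gelu_dense / fused_dense_sqrelu_dense may be enabled, the former requires an approximate gelu ('gelu_new' or 'gelu_fast'), and the latter requires 'sqrelu'. A minimal sketch of a config that satisfies them, using the same attribute names as the getattr calls above (this mirrors how test_gpt2_optimized below sets up its config):

from transformers import GPT2Config

config = GPT2Config.from_pretrained("gpt2")  # activation_function defaults to 'gelu_new'
config.fused_dense_gelu_dense = True         # allowed: 'gelu_new' is an approximate gelu
# config.fused_dense_sqrelu_dense = True     # would fail: only one fused path may be enabled,
#                                            # and it also requires activation_function == 'sqrelu'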
@@ -108,6 +121,34 @@ def create_block(config, layer_idx=None, process_group=None, device=None, dtype=
     return block
 
 
+class GPTPreTrainedModel(nn.Module):
+    """ An abstract class to handle weights initialization and
+        a simple interface for downloading and loading pretrained models.
+    """
+
+    def __init__(self, config, *inputs, **kwargs):
+        super().__init__()
+        if not isinstance(config, GPT2Config):
+            raise ValueError(
+                "Parameter config in `{}(config)` should be an instance of class `GPT2Config`. "
+                "To create a model from a Google pretrained model use "
+                "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
+                    self.__class__.__name__, self.__class__.__name__
+                ))
+        self.config = config
+
+    @classmethod
+    def from_pretrained(cls, model_name, config, *inputs, **kwargs):
+        """
+        Instantiate a GPTPreTrainedModel from a pre-trained model file or a pytorch state dict.
+        Download and cache the pre-trained model file if needed.
+        """
+        # Instantiate model.
+        model = cls(config, *inputs, **kwargs)
+        load_return = model.load_state_dict(
+            remap_state_dict_gpt2(state_dict_from_pretrained(model_name), config))
+        logger.info(load_return)
+        return model
+
+
 # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
 def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True):
     if isinstance(module, nn.Linear):
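With GPTPreTrainedModel.from_pretrained in place, a Huggingface GPT2 checkpoint can be pulled straight into this implementation. A minimal usage sketch (it mirrors the tests added below; it assumes a CUDA device and that the 'gpt2' weights can be downloaded):

import torch
from transformers import GPT2Config
from flash_attn.models.gpt import GPTLMHeadModel

config = GPT2Config.from_pretrained("gpt2")
model = GPTLMHeadModel.from_pretrained("gpt2", config)  # remaps HF weights into this layout
model = model.cuda().to(dtype=torch.float16).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 16), dtype=torch.long, device="cuda")
logits = model(input_ids).logits  # CausalLMOutput namedtuple with a single 'logits' field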
@@ -130,12 +171,13 @@ def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_resid
                 nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer))
 
 
-class GPTModel(nn.Module):
+class GPTModel(GPTPreTrainedModel):
 
     def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
-        super().__init__()
+        super().__init__(config)
         factory_kwargs = {'device': device, 'dtype': dtype}
         self.process_group = process_group
+        assert config.activation_function in ['gelu', 'gelu_new', 'gelu_fast', 'sqrelu']
         self.pad_vocab_size_multiple = getattr(config, 'pad_vocab_size_multiple', 1)
         if config.vocab_size % self.pad_vocab_size_multiple != 0:
             config.vocab_size += (self.pad_vocab_size_multiple
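pad_vocab_size_multiple rounds the vocabulary up so the embedding and LM-head dimensions are a multiple of the configured value; the extra rows are later filled in by F.pad inside remap_state_dict_gpt2. The exact expression is cut off in the rendered diff above, so the following only illustrates the intended rounding:

# Hypothetical illustration, not the literal expression from the diff.
vocab_size, multiple = 50257, 8  # GPT2's vocab size, padded to a multiple of 8 as in test_gpt2_optimized
padded = ((vocab_size + multiple - 1) // multiple) * multiple
assert padded == 50264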
@@ -201,11 +243,11 @@ class GPTModel(nn.Module):
         return hidden_states
 
 
-class GPTLMHeadModel(nn.Module):
+class GPTLMHeadModel(GPTPreTrainedModel):
 
     def __init__(self, config: GPT2Config, process_group=None, device=None, dtype=None):
         factory_kwargs = {'device': device, 'dtype': dtype}
-        super().__init__()
+        super().__init__(config)
         self.process_group = process_group
         self.transformer = GPTModel(config, process_group=process_group, **factory_kwargs)
         if process_group is None:
@@ -230,3 +272,61 @@ class GPTLMHeadModel(nn.Module):
         lm_logits = self.lm_head(hidden_states)
         CausalLMOutput = namedtuple('CausalLMOutput', ['logits'])
         return CausalLMOutput(logits=lm_logits)
+
+
+def remap_state_dict_gpt2(state_dict, config):
+    # Word embedding and position embedding
+    def key_mapping_pos_emb(key):
+        return re.sub(r'^wpe.', 'transformer.embeddings.position_embeddings.', key)
+    state_dict = OrderedDict((key_mapping_pos_emb(k), v) for k, v in state_dict.items())
+    word_embeddings = state_dict.pop('wte.weight')
+    # It's possible that vocab_size is padded to be a multiple of 8, for example.
+    state_dict['transformer.embeddings.word_embeddings.weight'] = F.pad(
+        word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
+    )
+    state_dict['lm_head.weight'] = state_dict['transformer.embeddings.word_embeddings.weight']
+
+    # LayerNorm
+    ln_weight, ln_bias = state_dict.pop('ln_f.weight'), state_dict.pop('ln_f.bias')
+    state_dict[f'transformer.layers.{config.num_hidden_layers - 1}.norm2.weight'] = ln_weight
+    state_dict[f'transformer.layers.{config.num_hidden_layers - 1}.norm2.bias'] = ln_bias
+    ln_weight, ln_bias = state_dict.pop('h.0.ln_1.weight'), state_dict.pop('h.0.ln_1.bias')
+    state_dict['transformer.ln_0.weight'] = ln_weight
+    state_dict['transformer.ln_0.bias'] = ln_bias
+    for d in range(config.num_hidden_layers):
+        ln_weight = state_dict.pop(f'h.{d}.ln_2.weight')
+        ln_bias = state_dict.pop(f'h.{d}.ln_2.bias')
+        state_dict[f'transformer.layers.{d}.norm1.weight'] = ln_weight
+        state_dict[f'transformer.layers.{d}.norm1.bias'] = ln_bias
+        if d > 0:
+            ln_weight = state_dict.pop(f'h.{d}.ln_1.weight')
+            ln_bias = state_dict.pop(f'h.{d}.ln_1.bias')
+            state_dict[f'transformer.layers.{d - 1}.norm2.weight'] = ln_weight
+            state_dict[f'transformer.layers.{d - 1}.norm2.bias'] = ln_bias
+
+    # MLP
+    for d in range(config.num_hidden_layers):
+        W1 = state_dict.pop(f'h.{d}.mlp.c_fc.weight')
+        state_dict[f'transformer.layers.{d}.mlp.fc1.weight'] = W1.t()
+        W2 = state_dict.pop(f'h.{d}.mlp.c_proj.weight')
+        state_dict[f'transformer.layers.{d}.mlp.fc2.weight'] = W2.t()
+    def key_mapping_mlp(key):
+        key = re.sub(r'^h.(\d+).mlp.c_fc.bias', r'transformer.layers.\1.mlp.fc1.bias', key)
+        key = re.sub(r'^h.(\d+).mlp.c_proj.bias', r'transformer.layers.\1.mlp.fc2.bias', key)
+        return key
+    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
+
+    # Attention
+    for d in range(config.num_hidden_layers):
+        state_dict.pop(f'h.{d}.attn.bias')  # We don't store this bias
+        Wqkv = state_dict.pop(f'h.{d}.attn.c_attn.weight')
+        state_dict[f'transformer.layers.{d}.mixer.Wqkv.weight'] = Wqkv.t()
+        Wout = state_dict.pop(f'h.{d}.attn.c_proj.weight')
+        state_dict[f'transformer.layers.{d}.mixer.out_proj.weight'] = Wout.t()
+    def key_mapping_attn(key):
+        key = re.sub(r'^h.(\d+).attn.c_attn.bias', r'transformer.layers.\1.mixer.Wqkv.bias', key)
+        key = re.sub(r'^h.(\d+).attn.c_proj.bias', r'transformer.layers.\1.mixer.out_proj.bias', key)
+        return key
+    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
+
+    return state_dict
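remap_state_dict_gpt2 renames Huggingface GPT2 keys into this repo's layout (mixer.Wqkv / mixer.out_proj for attention, mlp.fc1 / mlp.fc2 for the MLP); the .t() calls are needed because HF's GPT2 stores these weights in Conv1D modules with (in_features, out_features) shape, whereas nn.Linear expects (out_features, in_features). The LayerNorm keys also shift by one block: h.{d}.ln_2 becomes layers.{d}.norm1, h.{d}.ln_1 becomes layers.{d-1}.norm2, h.0.ln_1 becomes transformer.ln_0, and ln_f ends up as the last layer's norm2. A tiny sanity-check sketch of just the regex renames (hypothetical keys; the real coverage is test_gpt2_state_dict in the new test file below):

import re

def rename(key):
    # Same substitutions as key_mapping_mlp / key_mapping_attn above.
    key = re.sub(r'^h.(\d+).mlp.c_fc.bias', r'transformer.layers.\1.mlp.fc1.bias', key)
    key = re.sub(r'^h.(\d+).attn.c_attn.bias', r'transformer.layers.\1.mixer.Wqkv.bias', key)
    return key

print(rename('h.3.mlp.c_fc.bias'))     # transformer.layers.3.mlp.fc1.bias
print(rename('h.3.attn.c_attn.bias'))  # transformer.layers.3.mixer.Wqkv.bias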
tests/models/test_gpt.py  (new file, 0 → 100644, view file @ 9d797d88)

import re

import torch
import pytest

from transformers import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel as GPT2LMHeadModelHF

from flash_attn.models.gpt import GPTLMHeadModel
from flash_attn.models.gpt import remap_state_dict_gpt2
from flash_attn.utils.pretrained import state_dict_from_pretrained
@pytest.mark.parametrize('model_name', ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_state_dict(model_name):
    config = GPT2Config.from_pretrained(model_name)
    pretrained_state_dict = remap_state_dict_gpt2(state_dict_from_pretrained(model_name), config)
    model = GPTLMHeadModel(config)
    state_dict = model.state_dict()
    assert state_dict.keys() == pretrained_state_dict.keys()
    for k in state_dict.keys():
        assert state_dict[k].shape == pretrained_state_dict[k].shape
def get_hf_models(model_name, config, dtype):
    pretrained_state_dict = state_dict_from_pretrained(model_name)
    model_hf = GPT2LMHeadModelHF(config)
    # Missing key(s) in state_dict: "bert.embeddings.position_ids", "cls.predictions.decoder.bias"
    # position_ids is a buffer, and predictions.decoder.bias is tied to predictions.bias.
    model_hf.load_state_dict(pretrained_state_dict, strict=False)
    model_hf.cuda().to(dtype=dtype)
    return model_hf
@pytest.mark.parametrize('model_name', ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_non_optimized(model_name):
    """Check that our implementation of GPT2 (without any optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    config = GPT2Config.from_pretrained(model_name)

    model = GPTLMHeadModel.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
    model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
                              device='cuda')
    out = model.transformer(input_ids)
    out_hf = model_hf.transformer(input_ids).last_hidden_state
    out_ref = model_ref.transformer(input_ids).last_hidden_state
    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    logits = model(input_ids).logits
    logits_hf = model_hf(input_ids).logits
    logits_ref = model_ref(input_ids).logits
    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
    print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
    assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
@pytest.mark.parametrize('model_name', ["gpt2", "gpt2-medium"])
# @pytest.mark.parametrize('model_name', ["gpt2"])
def test_gpt2_optimized(model_name):
    """Check that our implementation of GPT2 (with all optimizations enabled) matches the
    HF implementation: the output of our forward pass in fp16 should be around the same as the HF
    forward pass in fp16, when compared to the HF forward pass in fp32.
    """
    dtype = torch.float16
    config = GPT2Config.from_pretrained(model_name)
    vocab_size_og = config.vocab_size
    config.use_flash_attn = True
    config.fused_bias_fc = True
    config.fused_dense_gelu_dense = True
    config.fused_dropout_add_ln = True
    config.pad_vocab_size_multiple = 8

    model = GPTLMHeadModel.from_pretrained(model_name, config)
    model = model.cuda().to(dtype=dtype)

    model_ref = GPT2LMHeadModelHF.from_pretrained(model_name).cuda()
    model_hf = GPT2LMHeadModelHF.from_pretrained(model_name).cuda().to(dtype=dtype)

    model.eval()
    model_ref.eval()
    model_hf.eval()

    torch.manual_seed(0)
    batch_size = 4
    max_seqlen = 512
    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device='cuda')
    input_ids = torch.randint(0, vocab_size_og, (batch_size, max_seqlen), dtype=torch.long,
                              device='cuda')
    out = model.transformer(input_ids)
    out_hf = model_hf.transformer(input_ids).last_hidden_state
    out_ref = model_ref.transformer(input_ids).last_hidden_state
    print(f'Output max diff: {(out - out_ref).abs().max().item()}')
    print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()

    logits = model(input_ids).logits[..., :vocab_size_og]
    logits_hf = model_hf(input_ids).logits
    logits_ref = model_ref(input_ids).logits
    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
    print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
    print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
    print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
    assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
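Both equivalence tests use the same acceptance criterion: the fp16 output of this implementation, compared against HF's fp32 reference, must have a max absolute error below 3x the error that HF's own fp16 model shows against that same reference. They assume a CUDA GPU and network access to download the 'gpt2' / 'gpt2-medium' checkpoints; an invocation along the lines of "pytest tests/models/test_gpt.py" (not part of this commit) would run all three tests.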