chenpangpang / transformers · Commits

Commit de627f5a (unverified)
Authored Mar 21, 2024 by Matt · Committed by GitHub on Mar 21, 2024
Parent: 73a73b41

Cast bfloat16 to float32 for Numpy conversions (#29755)

* Cast bfloat16 to float32 for Numpy conversions
* Add test
Showing 2 changed files with 15 additions and 1 deletion:

* src/transformers/modeling_tf_pytorch_utils.py (+4, −1)
* tests/test_modeling_tf_utils.py (+11, −0)
src/transformers/modeling_tf_pytorch_utils.py

```diff
@@ -249,7 +249,10 @@ def load_pytorch_weights_in_tf2_model(
         )
         raise

-    pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
+    # Numpy doesn't understand bfloat16, so upcast to a dtype that doesn't lose precision
+    pt_state_dict = {
+        k: v.numpy() if v.dtype != torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()
+    }
     return load_pytorch_state_dict_in_tf2_model(
         tf_model,
         pt_state_dict,
```
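NumPy has no bfloat16 dtype, so calling `.numpy()` on a `torch.bfloat16` tensor raises a `TypeError`; upcasting to float32 first is lossless, since every bfloat16 value is exactly representable in float32. A minimal standalone sketch of the same conversion pattern (the helper name and example state dict are illustrative, not part of the commit):

```python
import torch

def state_dict_to_numpy(pt_state_dict):
    # NumPy cannot represent bfloat16, so upcast those tensors to float32
    # before conversion; all other dtypes convert directly.
    return {
        k: (v.float().numpy() if v.dtype == torch.bfloat16 else v.numpy())
        for k, v in pt_state_dict.items()
    }

# Illustrative usage with a hypothetical tiny state dict
state_dict = {
    "dense.weight": torch.randn(4, 4, dtype=torch.bfloat16),
    "dense.bias": torch.zeros(4, dtype=torch.float32),
}
np_state_dict = state_dict_to_numpy(state_dict)
print({k: v.dtype for k, v in np_state_dict.items()})
# {'dense.weight': dtype('float32'), 'dense.bias': dtype('float32')}
```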
tests/test_modeling_tf_utils.py

```diff
@@ -63,6 +63,7 @@ if is_tf_available():
         PreTrainedModel,
         PushToHubCallback,
         RagRetriever,
+        TFAutoModel,
         TFBertForMaskedLM,
         TFBertForSequenceClassification,
         TFBertModel,
@@ -435,6 +436,16 @@ class TFModelUtilsTest(unittest.TestCase):
             for p1, p2 in zip(model.weights, new_model.weights):
                 self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))

+    @is_pt_tf_cross_test
+    @require_safetensors
+    def test_bfloat16_torch_loading(self):
+        # Assert that neither of these raise an error - both repos contain bfloat16 tensors
+        model1 = TFAutoModel.from_pretrained("Rocketknight1/tiny-random-gpt2-bfloat16-pt", from_pt=True)
+        model2 = TFAutoModel.from_pretrained("Rocketknight1/tiny-random-gpt2-bfloat16")  # PT-format safetensors
+        # Check that PT and safetensors loading paths end up with the same values
+        for weight1, weight2 in zip(model1.weights, model2.weights):
+            self.assertTrue(tf.reduce_all(weight1 == weight2))
+
     @slow
     def test_save_pretrained_signatures(self):
         model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
```
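The new test checks that the PyTorch and safetensors loading paths produce identical TF weights. A quick standalone sketch (not part of the commit) illustrating why the float32 upcast cannot change any value, and why the direct conversion needed fixing in the first place:

```python
import torch

# Every bfloat16 value is exactly representable in float32 (same exponent
# range, strictly more mantissa bits), so the upcast cannot alter weights.
x = torch.randn(1024, dtype=torch.bfloat16)
roundtrip = x.float().to(torch.bfloat16)
assert torch.equal(x, roundtrip)

# Direct conversion fails because NumPy has no bfloat16 dtype.
try:
    x.numpy()
except TypeError as e:
    print("direct .numpy() raises:", e)

# The upcast path works and yields float32 arrays.
print(x.float().numpy().dtype)  # float32
```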