Unverified Commit 008a6a22 authored by Lysandre Debut, committed by GitHub


Enable safetensors conversion from PyTorch to other frameworks without the torch requirement (#27599)

* Initial commit

* Requirements & tests

* Tests

* Tests

* Rogue import

* Rogue torch import

* Cleanup

* Apply suggestions from code review
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>

* bfloat16 management

* Sanchit's comments

* Import shield

* apply suggestions from code review

* correct bf16

* rebase

---------
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
parent 03986609
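
In practical terms, the change lets a torch-free Flax environment consume a checkpoint whose safetensors file was serialized from PyTorch. A minimal sketch of the intended usage, assuming only flax and safetensors are installed and using the `hf-internal-testing/tiny-bert-pt-safetensors` repo that the tests below rely on:

```python
# Sketch only: load a PyTorch-serialized safetensors checkpoint into Flax
# in an environment where torch is not installed.
from transformers import FlaxBertModel

# The checkpoint ships a model.safetensors file saved from PyTorch;
# after this change, from_pretrained converts it to Flax params directly.
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
print(type(model.params))  # nested dict of jax arrays
```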
@@ -158,7 +158,7 @@ _deps = [
     "ruff==0.1.5",
     "sacrebleu>=1.4.12,<2.0.0",
     "sacremoses",
-    "safetensors>=0.3.1",
+    "safetensors>=0.4.1",
     "sagemaker>=2.31.0",
     "scikit-learn",
     "sentencepiece>=0.1.91,!=0.1.92",
@@ -64,7 +64,7 @@ deps = {
     "ruff": "ruff==0.1.5",
     "sacrebleu": "sacrebleu>=1.4.12,<2.0.0",
     "sacremoses": "sacremoses",
-    "safetensors": "safetensors>=0.3.1",
+    "safetensors": "safetensors>=0.4.1",
     "sagemaker": "sagemaker>=2.31.0",
     "scikit-learn": "scikit-learn",
     "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
@@ -27,10 +27,13 @@ from flax.traverse_util import flatten_dict, unflatten_dict
 import transformers
 
-from . import is_safetensors_available
+from . import is_safetensors_available, is_torch_available
 from .utils import logging
 
 
+if is_torch_available():
+    import torch
+
 if is_safetensors_available():
     from safetensors import safe_open
     from safetensors.flax import load_file as safe_load_file
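
The module-level shield above is what lets later code in this file mention `torch` without reintroducing a hard dependency: every reference is short-circuited behind `is_torch_available()`. A small illustrative sketch of the pattern (the helper name is hypothetical, not part of the diff):

```python
from transformers.utils import is_torch_available

if is_torch_available():
    import torch  # only bound when the torch backend is present


def first_value_is_torch_tensor(state_dict):
    # Hypothetical helper: the `and` short-circuits before `torch.Tensor`
    # is evaluated, so this is safe even when torch is absent.
    return is_torch_available() and isinstance(next(iter(state_dict.values())), torch.Tensor)
```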
@@ -48,6 +51,17 @@ def load_pytorch_checkpoint_in_flax_state_dict(
     flax_model, pytorch_checkpoint_path, is_sharded, allow_missing_keys=False
 ):
     """Load pytorch checkpoints in a flax model"""
+
+    if not is_sharded:
+        pt_path = os.path.abspath(pytorch_checkpoint_path)
+        logger.info(f"Loading PyTorch weights from {pt_path}")
+
+        if pt_path.endswith(".safetensors"):
+            pt_state_dict = {}
+            with safe_open(pt_path, framework="flax") as f:
+                for k in f.keys():
+                    pt_state_dict[k] = f.get_tensor(k)
+        else:
     try:
         import torch  # noqa: F401
@@ -60,16 +74,6 @@ def load_pytorch_checkpoint_in_flax_state_dict(
         )
         raise
 
-    if not is_sharded:
-        pt_path = os.path.abspath(pytorch_checkpoint_path)
-        logger.info(f"Loading PyTorch weights from {pt_path}")
-
-        if pt_path.endswith(".safetensors"):
-            pt_state_dict = {}
-            with safe_open(pt_path, framework="pt") as f:
-                for k in f.keys():
-                    pt_state_dict[k] = f.get_tensor(k)
-        else:
     pt_state_dict = torch.load(pt_path, map_location="cpu", weights_only=is_torch_greater_or_equal_than_1_13)
     logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")
@@ -149,21 +153,17 @@ def rename_key_and_reshape_tensor(
 def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
     # convert pytorch tensor to numpy
-    # numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision
-    try:
-        import torch  # noqa: F401
-    except (ImportError, ModuleNotFoundError):
-        logger.error(
-            "Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
-            " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for installation"
-            " instructions."
-        )
-        raise
+    from_bin = is_torch_available() and isinstance(next(iter(pt_state_dict.values())), torch.Tensor)
+    bfloat16 = torch.bfloat16 if from_bin else "bfloat16"
 
     weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
-    pt_state_dict = {
-        k: v.numpy() if not v.dtype == torch.bfloat16 else v.float().numpy() for k, v in pt_state_dict.items()
-    }
+
+    if from_bin:
+        for k, v in pt_state_dict.items():
+            # numpy currently does not support bfloat16, need to go over float32 in this case to not lose precision
+            if v.dtype == bfloat16:
+                v = v.float()
+            pt_state_dict[k] = v.numpy()
 
     model_prefix = flax_model.base_model_prefix
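
The bf16 branch above exists because NumPy has no native bfloat16 dtype: tensors coming from a torch `.bin` checkpoint are upcast to float32 before `.numpy()`, and the original dtype is kept in `weight_dtypes` so the Flax side can cast back to `jnp.bfloat16` (see the following hunks). A small sketch of that round trip, assuming both torch and jax are installed:

```python
import jax.numpy as jnp
import torch

pt_tensor = torch.ones(2, 2, dtype=torch.bfloat16)

# torch cannot hand bf16 directly to numpy, so go through float32 first...
as_numpy = pt_tensor.float().numpy()

# ...and restore the original precision when building the Flax parameter.
flax_tensor = jnp.asarray(as_numpy, dtype=jnp.bfloat16)
print(flax_tensor.dtype)  # bfloat16
```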
@@ -191,7 +191,7 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
     # Need to change some parameters name to match Flax names
     for pt_key, pt_tensor in pt_state_dict.items():
         pt_tuple_key = tuple(pt_key.split("."))
-        is_bfloat_16 = weight_dtypes[pt_key] == torch.bfloat16
+        is_bfloat_16 = weight_dtypes[pt_key] == bfloat16
 
         # remove base model prefix if necessary
         has_base_model_prefix = pt_tuple_key[0] == model_prefix
@@ -229,7 +229,6 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
             flax_state_dict[("params",) + flax_key] = (
                 jnp.asarray(flax_tensor) if not is_bfloat_16 else jnp.asarray(flax_tensor, dtype=jnp.bfloat16)
             )
-
         else:
             # also add unexpected weight so that warning is thrown
             flax_state_dict[flax_key] = (
@@ -23,13 +23,14 @@ from transformers import BertConfig, BertModel, is_flax_available, is_torch_available
 from transformers.testing_utils import (
     TOKEN,
     USER,
+    CaptureLogger,
     is_pt_flax_cross_test,
     is_staging_test,
     require_flax,
     require_safetensors,
     require_torch,
 )
-from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME
+from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME, logging
 
 
 if is_flax_available():
@@ -42,6 +43,9 @@ if is_flax_available():
 os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12"  # assumed parallelism: 8
 
+if is_torch_available():
+    import torch
+
 
 @require_flax
 @is_staging_test
@@ -251,7 +255,6 @@ class FlaxModelUtilsTest(unittest.TestCase):
         self.assertTrue(check_models_equal(flax_model, safetensors_model))
 
-    @require_torch
     @require_safetensors
     @is_pt_flax_cross_test
     def test_safetensors_load_from_hub_from_safetensors_pt(self):
@@ -265,57 +268,44 @@
         safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
 
         self.assertTrue(check_models_equal(flax_model, safetensors_model))
 
+    @require_torch
     @require_safetensors
-    @require_torch
     @is_pt_flax_cross_test
-    def test_safetensors_load_from_local_from_safetensors_pt(self):
+    def test_safetensors_load_from_hub_from_safetensors_pt_bf16(self):
         """
         This test checks that we can load safetensors from a checkpoint that only has those on the Hub.
         saved in the "pt" format.
         """
-        with tempfile.TemporaryDirectory() as tmp:
-            location = snapshot_download("hf-internal-testing/tiny-bert-msgpack", cache_dir=tmp)
-            flax_model = FlaxBertModel.from_pretrained(location)
+        import torch
+
+        model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
+        model.to(torch.bfloat16)
 
-        # Can load from the PyTorch-formatted checkpoint
         with tempfile.TemporaryDirectory() as tmp:
-            location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
-            safetensors_model = FlaxBertModel.from_pretrained(location)
+            model.save_pretrained(tmp)
+            flax_model = FlaxBertModel.from_pretrained(tmp)
 
+        # Can load from the PyTorch-formatted checkpoint
+        safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
         self.assertTrue(check_models_equal(flax_model, safetensors_model))
 
     @require_safetensors
-    def test_safetensors_load_from_hub_from_safetensors_pt_without_torch_installed(self):
-        """
-        This test checks that we cannot load safetensors from a checkpoint that only has safetensors
-        saved in the "pt" format if torch isn't installed.
-        """
-        if is_torch_available():
-            # This test verifies that a correct error message is shown when loading from a pt safetensors
-            # PyTorch shouldn't be installed for this to work correctly.
-            return
-
-        # Cannot load from the PyTorch-formatted checkpoint without PyTorch installed
-        with self.assertRaises(ModuleNotFoundError):
-            _ = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
-
-    @require_safetensors
-    def test_safetensors_load_from_local_from_safetensors_pt_without_torch_installed(self):
+    @is_pt_flax_cross_test
+    def test_safetensors_load_from_local_from_safetensors_pt(self):
         """
-        This test checks that we cannot load safetensors from a checkpoint that only has safetensors
-        saved in the "pt" format if torch isn't installed.
+        This test checks that we can load safetensors from a checkpoint that only has those on the Hub.
+        saved in the "pt" format.
         """
-        if is_torch_available():
-            # This test verifies that a correct error message is shown when loading from a pt safetensors
-            # PyTorch shouldn't be installed for this to work correctly.
-            return
+        with tempfile.TemporaryDirectory() as tmp:
+            location = snapshot_download("hf-internal-testing/tiny-bert-msgpack", cache_dir=tmp)
+            flax_model = FlaxBertModel.from_pretrained(location)
 
+        # Can load from the PyTorch-formatted checkpoint
         with tempfile.TemporaryDirectory() as tmp:
             location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
+            safetensors_model = FlaxBertModel.from_pretrained(location)
 
-            # Cannot load from the PyTorch-formatted checkpoint without PyTorch installed
-            with self.assertRaises(ModuleNotFoundError):
-                _ = FlaxBertModel.from_pretrained(location)
+        self.assertTrue(check_models_equal(flax_model, safetensors_model))
 
     @require_safetensors
     def test_safetensors_load_from_hub_msgpack_before_safetensors(self):
@@ -347,6 +337,7 @@ class FlaxModelUtilsTest(unittest.TestCase):
 
     @require_safetensors
     @require_torch
+    @is_pt_flax_cross_test
     def test_safetensors_flax_from_torch(self):
         hub_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
         model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
@@ -372,3 +363,41 @@ class FlaxModelUtilsTest(unittest.TestCase):
         # This should not raise even if there are two types of sharded weights
         # This should discard the safetensors weights in favor of the msgpack sharded weights
         FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-msgpack-sharded")
+
+    @require_safetensors
+    def test_safetensors_from_pt_bf16(self):
+        # This should not raise; should be able to load bf16-serialized torch safetensors without issue
+        # and without torch.
+        logger = logging.get_logger("transformers.modeling_flax_utils")
+
+        with CaptureLogger(logger) as cl:
+            FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
+
+        self.assertTrue(
+            "Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint"
+            in cl.out
+        )
+
+    @require_torch
+    @require_safetensors
+    @is_pt_flax_cross_test
+    def test_from_pt_bf16(self):
+        model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
+        model.to(torch.bfloat16)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir, safe_serialization=False)
+
+            logger = logging.get_logger("transformers.modeling_flax_utils")
+
+            with CaptureLogger(logger) as cl:
+                new_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
+
+            self.assertTrue(
+                "Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint"
+                in cl.out
+            )
+
+            flat_params_1 = flatten_dict(new_model.params)
+            for value in flat_params_1.values():
+                self.assertEqual(value.dtype, "bfloat16")