Unverified Commit d4eb52d1 authored by Sylvain Gugger, committed by GitHub

Refactor conversion function (#19799)

* Refactor conversion function

* Remove dupe line

* Fixes

* Fixes

* Use the right variable...

* Fix last test
parent 9ecb13d6
@@ -21,7 +21,8 @@ import re
import numpy
from .utils import ExplicitEnum, logging
from .utils import ExplicitEnum, expand_dims, is_numpy_array, is_torch_tensor, logging, reshape, squeeze
from .utils import transpose as transpose_func
logger = logging.get_logger(__name__)
@@ -66,10 +67,12 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="",
if len(tf_name) > 1:
tf_name = tf_name[1:] # Remove level zero
tf_weight_shape = list(tf_weight_shape)
# When should we transpose the weights
if tf_name[-1] == "kernel" and tf_weight_shape is not None and tf_weight_shape.rank == 4:
if tf_name[-1] == "kernel" and tf_weight_shape is not None and len(tf_weight_shape) == 4:
transpose = TransposeType.CONV2D
elif tf_name[-1] == "kernel" and tf_weight_shape is not None and tf_weight_shape.rank == 3:
elif tf_name[-1] == "kernel" and tf_weight_shape is not None and len(tf_weight_shape) == 3:
transpose = TransposeType.CONV1D
elif bool(
tf_name[-1] in ["kernel", "pointwise_kernel", "depthwise_kernel"]
@@ -98,6 +101,43 @@ def convert_tf_weight_name_to_pt_weight_name(tf_name, start_prefix_to_remove="",
return tf_name, transpose
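For illustration, a minimal sketch of how the rank of the kernel shape picks the transpose type above. It assumes the function lives in transformers.modeling_tf_pytorch_utils and accepts a tf_weight_shape keyword, as suggested by the calls later in this diff; only the returned TransposeType is checked.

from transformers.modeling_tf_pytorch_utils import (
    TransposeType,
    convert_tf_weight_name_to_pt_weight_name,
)

# A 4D kernel (Conv2D) selects TransposeType.CONV2D, a 3D kernel (Conv1D)
# selects TransposeType.CONV1D, and a plain 2D dense kernel falls back to SIMPLE.
_, transpose = convert_tf_weight_name_to_pt_weight_name(
    "model/conv/kernel:0", tf_weight_shape=[3, 3, 16, 32]
)
assert transpose is TransposeType.CONV2D

_, transpose = convert_tf_weight_name_to_pt_weight_name(
    "model/dense/kernel:0", tf_weight_shape=[768, 768]
)
assert transpose is TransposeType.SIMPLE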
def apply_transpose(transpose: TransposeType, weight, match_shape=None, pt_to_tf=True):
"""
Apply a transpose to a weight, then try to reshape the weight to match a given shape, all in a
framework-agnostic way.
"""
if transpose is TransposeType.CONV2D:
# Conv2D weight:
# PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
# -> TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
axes = (2, 3, 1, 0) if pt_to_tf else (3, 2, 0, 1)
weight = transpose_func(weight, axes=axes)
elif transpose is TransposeType.CONV1D:
# Conv1D weight:
# PT: (num_out_channel, num_in_channel, kernel)
# -> TF: (kernel, num_in_channel, num_out_channel)
weight = transpose_func(weight, axes=(2, 1, 0))
elif transpose is TransposeType.SIMPLE:
weight = transpose_func(weight)
if match_shape is None:
return weight
if len(match_shape) < len(weight.shape):
weight = squeeze(weight)
elif len(match_shape) > len(weight.shape):
weight = expand_dims(weight, axis=0)
if list(match_shape) != list(weight.shape):
try:
weight = reshape(weight, match_shape)
except AssertionError as e:
e.args += (match_shape, weight.shape)
raise e
return weight
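As a quick illustration, a sketch assuming a 3x3 Conv2D kernel with 16 input and 32 output channels: with `pt_to_tf` flipped, `apply_transpose` undoes itself.

import numpy as np

from transformers.modeling_tf_pytorch_utils import TransposeType, apply_transpose

pt_kernel = np.random.randn(32, 16, 3, 3)  # PT layout: (out_channels, in_channels, kH, kW)

tf_kernel = apply_transpose(TransposeType.CONV2D, pt_kernel, match_shape=(3, 3, 16, 32))
assert tf_kernel.shape == (3, 3, 16, 32)   # TF layout: (kH, kW, in_channels, out_channels)

back = apply_transpose(TransposeType.CONV2D, tf_kernel, match_shape=(32, 16, 3, 3), pt_to_tf=False)
assert np.allclose(back, pt_kernel)        # the round trip recovers the original weight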
#####################
# PyTorch => TF 2.0 #
#####################
@@ -155,7 +195,6 @@ def load_pytorch_weights_in_tf2_model(
try:
import tensorflow as tf # noqa: F401
import torch # noqa: F401
from tensorflow.python.keras import backend as K
except ImportError:
logger.error(
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
@@ -163,6 +202,22 @@
)
raise
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}
return load_pytorch_state_dict_in_tf2_model(
tf_model,
pt_state_dict,
tf_inputs=tf_inputs,
allow_missing_keys=allow_missing_keys,
output_loading_info=output_loading_info,
)
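The split above means a raw state dict can now be pushed into a TF model without instantiating the PyTorch model first. A hedged sketch, where the model class, config and checkpoint path are placeholders for illustration:

import torch

from transformers import BertConfig, TFBertModel
from transformers.modeling_tf_pytorch_utils import load_pytorch_state_dict_in_tf2_model

tf_model = TFBertModel(BertConfig())

# e.g. a checkpoint saved with torch.save(model.state_dict(), ...) -- placeholder path
pt_state_dict = torch.load("pytorch_model.bin", map_location="cpu")
# Mirror the wrapper above: the new helper works on numpy values.
pt_state_dict = {k: v.numpy() for k, v in pt_state_dict.items()}

tf_model = load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict)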
def load_pytorch_state_dict_in_tf2_model(
tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False, output_loading_info=False
):
"""Load a pytorch state_dict in a TF 2.0 model."""
from tensorflow.python.keras import backend as K
if tf_inputs is None:
tf_inputs = tf_model.dummy_inputs
@@ -216,41 +271,9 @@ def load_pytorch_weights_in_tf2_model(
continue
raise AttributeError(f"{name} not found in PyTorch model")
array = pt_state_dict[name].numpy()
if transpose is TransposeType.CONV2D:
# Conv2D weight:
# PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
# -> TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
array = numpy.transpose(array, axes=(2, 3, 1, 0))
elif transpose is TransposeType.CONV1D:
# Conv1D weight:
# PT: (num_out_channel, num_in_channel, kernel)
# -> TF: (kernel, num_in_channel, num_out_channel)
array = numpy.transpose(array, axes=(2, 1, 0))
elif transpose is TransposeType.SIMPLE:
array = numpy.transpose(array)
if len(symbolic_weight.shape) < len(array.shape):
array = numpy.squeeze(array)
elif len(symbolic_weight.shape) > len(array.shape):
array = numpy.expand_dims(array, axis=0)
if list(symbolic_weight.shape) != list(array.shape):
try:
array = numpy.reshape(array, symbolic_weight.shape)
except AssertionError as e:
e.args += (symbolic_weight.shape, array.shape)
raise e
try:
assert list(symbolic_weight.shape) == list(array.shape)
except AssertionError as e:
e.args += (symbolic_weight.shape, array.shape)
raise e
array = apply_transpose(transpose, pt_state_dict[name], symbolic_weight.shape)
tf_loaded_numel += array.size
# logger.warning(f"Initialize TF weight {symbolic_weight.name}")
weight_value_tuples.append((symbolic_weight, array))
all_pytorch_weights.discard(name)
@@ -370,6 +393,15 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
)
raise
tf_state_dict = {tf_weight.name: tf_weight.numpy() for tf_weight in tf_weights}
return load_tf2_state_dict_in_pytorch_model(
pt_model, tf_state_dict, allow_missing_keys=allow_missing_keys, output_loading_info=output_loading_info
)
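In the other direction, anything that can be turned into a `name -> value` mapping with TF-style variable names can now be loaded into a PyTorch model. A rough sketch; the BERT classes and the dummy-input call are placeholders used only to obtain built TF weights:

from transformers import BertConfig, BertModel, TFBertModel
from transformers.modeling_tf_pytorch_utils import load_tf2_state_dict_in_pytorch_model

config = BertConfig()
tf_model = TFBertModel(config)
tf_model(tf_model.dummy_inputs)  # run once so the variables are created

# Same mapping the wrapper above builds: TF variable name -> numpy value.
tf_state_dict = {w.name: w.numpy() for w in tf_model.weights}

pt_model = BertModel(config)
pt_model = load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict)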
def load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict, allow_missing_keys=False, output_loading_info=False):
import torch
new_pt_params_dict = {}
current_pt_params_dict = dict(pt_model.named_parameters())
@@ -381,11 +413,11 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
# Build a map from potential PyTorch weight names to TF 2.0 Variables
tf_weights_map = {}
for tf_weight in tf_weights:
for name, tf_weight in tf_state_dict.items():
pt_name, transpose = convert_tf_weight_name_to_pt_weight_name(
tf_weight.name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=tf_weight.shape
name, start_prefix_to_remove=start_prefix_to_remove, tf_weight_shape=tf_weight.shape
)
tf_weights_map[pt_name] = (tf_weight.numpy(), transpose)
tf_weights_map[pt_name] = (tf_weight, transpose)
all_tf_weights = set(list(tf_weights_map.keys()))
loaded_pt_weights_data_ptr = {}
@@ -406,43 +438,18 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
array, transpose = tf_weights_map[pt_weight_name]
if transpose is TransposeType.CONV2D:
# Conv2D weight:
# TF: (kernel[0], kernel[1], num_in_channel, num_out_channel)
# -> PT: (num_out_channel, num_in_channel, kernel[0], kernel[1])
array = numpy.transpose(array, axes=(3, 2, 0, 1))
elif transpose is TransposeType.CONV1D:
# Conv1D weight:
# TF: (kernel, num_in_channel, num_out_channel)
# -> PT: (num_out_channel, num_in_channel, kernel)
array = numpy.transpose(array, axes=(2, 1, 0))
elif transpose is TransposeType.SIMPLE:
array = numpy.transpose(array)
if len(pt_weight.shape) < len(array.shape):
array = numpy.squeeze(array)
elif len(pt_weight.shape) > len(array.shape):
array = numpy.expand_dims(array, axis=0)
if list(pt_weight.shape) != list(array.shape):
try:
array = numpy.reshape(array, pt_weight.shape)
except AssertionError as e:
e.args += (pt_weight.shape, array.shape)
raise e
try:
assert list(pt_weight.shape) == list(array.shape)
except AssertionError as e:
e.args += (pt_weight.shape, array.shape)
raise e
array = apply_transpose(transpose, array, pt_weight.shape, pt_to_tf=False)
# logger.warning(f"Initialize PyTorch weight {pt_weight_name}")
# Make sure we have a proper numpy array
if numpy.isscalar(array):
array = numpy.array(array)
new_pt_params_dict[pt_weight_name] = torch.from_numpy(array)
loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = torch.from_numpy(array)
if not is_torch_tensor(array) and not is_numpy_array(array):
array = array.numpy()
if is_numpy_array(array):
# Convert to torch tensor
array = torch.from_numpy(array)
new_pt_params_dict[pt_weight_name] = array
loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = array
all_tf_weights.discard(pt_weight_name)
missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False)
......
@@ -38,6 +38,7 @@ from .generic import (
PaddingStrategy,
TensorType,
cached_property,
expand_dims,
find_labels,
flatten_dict,
is_jax_tensor,
@@ -46,8 +47,11 @@ from .generic import (
is_tf_tensor,
is_torch_device,
is_torch_tensor,
reshape,
squeeze,
to_numpy,
to_py_obj,
transpose,
working_or_temp_dir,
)
from .hub import (
......
@@ -29,6 +29,13 @@ import numpy as np
from .import_utils import is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy
if is_tf_available():
import tensorflow as tf
if is_flax_available():
import jax.numpy as jnp
class cached_property(property):
"""
Descriptor that mimics @property but caches output in member variable.
@@ -370,3 +377,71 @@ def working_or_temp_dir(working_dir, use_temp_dir: bool = False):
yield tmp_dir
else:
yield working_dir
def transpose(array, axes=None):
"""
Framework-agnostic version of `numpy.transpose` that will work on torch/TensorFlow/Jax tensors as well as NumPy
arrays.
"""
if is_numpy_array(array):
return np.transpose(array, axes=axes)
elif is_torch_tensor(array):
return array.T if axes is None else array.permute(*axes)
elif is_tf_tensor(array):
return tf.transpose(array, perm=axes)
elif is_jax_tensor(array):
return jnp.transpose(array, axes=axes)
else:
raise ValueError(f"Type not supported for transpose: {type(array)}.")
def reshape(array, newshape):
"""
Framework-agnostic version of `numpy.reshape` that will work on torch/TensorFlow/Jax tensors as well as NumPy
arrays.
"""
if is_numpy_array(array):
return np.reshape(array, newshape)
elif is_torch_tensor(array):
return array.reshape(*newshape)
elif is_tf_tensor(array):
return tf.reshape(array, newshape)
elif is_jax_tensor(array):
return jnp.reshape(array, newshape)
else:
raise ValueError(f"Type not supported for reshape: {type(array)}.")
def squeeze(array, axis=None):
"""
Framework-agnostic version of `numpy.squeeze` that will work on torch/TensorFlow/Jax tensors as well as NumPy
arrays.
"""
if is_numpy_array(array):
return np.squeeze(array, axis=axis)
elif is_torch_tensor(array):
return array.squeeze() if axis is None else array.squeeze(dim=axis)
elif is_tf_tensor(array):
return tf.squeeze(array, axis=axis)
elif is_jax_tensor(array):
return jnp.squeeze(array, axis=axis)
else:
raise ValueError(f"Type not supported for squeeze: {type(array)}.")
def expand_dims(array, axis):
"""
Framework-agnostic version of `numpy.expand_dims` that will work on torch/TensorFlow/Jax tensors as well as NumPy
arrays.
"""
if is_numpy_array(array):
return np.expand_dims(array, axis)
elif is_torch_tensor(array):
return array.unsqueeze(dim=axis)
elif is_tf_tensor(array):
return tf.expand_dims(array, axis=axis)
elif is_jax_tensor(array):
return jnp.expand_dims(array, axis=axis)
else:
raise ValueError(f"Type not supported for expand_dims: {type(array)}.")
@@ -15,7 +15,29 @@
import unittest
from transformers.utils import flatten_dict
import numpy as np
from transformers.testing_utils import require_flax, require_tf, require_torch
from transformers.utils import (
expand_dims,
flatten_dict,
is_flax_available,
is_tf_available,
is_torch_available,
reshape,
squeeze,
transpose,
)
if is_flax_available():
import jax.numpy as jnp
if is_tf_available():
import tensorflow as tf
if is_torch_available():
import torch
class GenericTester(unittest.TestCase):
@@ -43,3 +65,136 @@ class GenericTester(unittest.TestCase):
}
self.assertEqual(flatten_dict(input_dict), expected_dict)
def test_transpose_numpy(self):
x = np.random.randn(3, 4)
self.assertTrue(np.allclose(transpose(x), x.transpose()))
x = np.random.randn(3, 4, 5)
self.assertTrue(np.allclose(transpose(x, axes=(1, 2, 0)), x.transpose((1, 2, 0))))
@require_torch
def test_transpose_torch(self):
x = np.random.randn(3, 4)
t = torch.tensor(x)
self.assertTrue(np.allclose(transpose(x), transpose(t).numpy()))
x = np.random.randn(3, 4, 5)
t = torch.tensor(x)
self.assertTrue(np.allclose(transpose(x, axes=(1, 2, 0)), transpose(t, axes=(1, 2, 0)).numpy()))
@require_tf
def test_transpose_tf(self):
x = np.random.randn(3, 4)
t = tf.constant(x)
self.assertTrue(np.allclose(transpose(x), transpose(t).numpy()))
x = np.random.randn(3, 4, 5)
t = tf.constant(x)
self.assertTrue(np.allclose(transpose(x, axes=(1, 2, 0)), transpose(t, axes=(1, 2, 0)).numpy()))
@require_flax
def test_transpose_flax(self):
x = np.random.randn(3, 4)
t = jnp.array(x)
self.assertTrue(np.allclose(transpose(x), np.asarray(transpose(t))))
x = np.random.randn(3, 4, 5)
t = jnp.array(x)
self.assertTrue(np.allclose(transpose(x, axes=(1, 2, 0)), np.asarray(transpose(t, axes=(1, 2, 0)))))
def test_reshape_numpy(self):
x = np.random.randn(3, 4)
self.assertTrue(np.allclose(reshape(x, (4, 3)), np.reshape(x, (4, 3))))
x = np.random.randn(3, 4, 5)
self.assertTrue(np.allclose(reshape(x, (12, 5)), np.reshape(x, (12, 5))))
@require_torch
def test_reshape_torch(self):
x = np.random.randn(3, 4)
t = torch.tensor(x)
self.assertTrue(np.allclose(reshape(x, (4, 3)), reshape(t, (4, 3)).numpy()))
x = np.random.randn(3, 4, 5)
t = torch.tensor(x)
self.assertTrue(np.allclose(reshape(x, (12, 5)), reshape(t, (12, 5)).numpy()))
@require_tf
def test_reshape_tf(self):
x = np.random.randn(3, 4)
t = tf.constant(x)
self.assertTrue(np.allclose(reshape(x, (4, 3)), reshape(t, (4, 3)).numpy()))
x = np.random.randn(3, 4, 5)
t = tf.constant(x)
self.assertTrue(np.allclose(reshape(x, (12, 5)), reshape(t, (12, 5)).numpy()))
@require_flax
def test_reshape_flax(self):
x = np.random.randn(3, 4)
t = jnp.array(x)
self.assertTrue(np.allclose(reshape(x, (4, 3)), np.asarray(reshape(t, (4, 3)))))
x = np.random.randn(3, 4, 5)
t = jnp.array(x)
self.assertTrue(np.allclose(reshape(x, (12, 5)), np.asarray(reshape(t, (12, 5)))))
def test_squeeze_numpy(self):
x = np.random.randn(1, 3, 4)
self.assertTrue(np.allclose(squeeze(x), np.squeeze(x)))
x = np.random.randn(1, 4, 1, 5)
self.assertTrue(np.allclose(squeeze(x, axis=2), np.squeeze(x, axis=2)))
@require_torch
def test_squeeze_torch(self):
x = np.random.randn(1, 3, 4)
t = torch.tensor(x)
self.assertTrue(np.allclose(squeeze(x), squeeze(t).numpy()))
x = np.random.randn(1, 4, 1, 5)
t = torch.tensor(x)
self.assertTrue(np.allclose(squeeze(x, axis=2), squeeze(t, axis=2).numpy()))
@require_tf
def test_squeeze_tf(self):
x = np.random.randn(1, 3, 4)
t = tf.constant(x)
self.assertTrue(np.allclose(squeeze(x), squeeze(t).numpy()))
x = np.random.randn(1, 4, 1, 5)
t = tf.constant(x)
self.assertTrue(np.allclose(squeeze(x, axis=2), squeeze(t, axis=2).numpy()))
@require_flax
def test_squeeze_flax(self):
x = np.random.randn(1, 3, 4)
t = jnp.array(x)
self.assertTrue(np.allclose(squeeze(x), np.asarray(squeeze(t))))
x = np.random.randn(1, 4, 1, 5)
t = jnp.array(x)
self.assertTrue(np.allclose(squeeze(x, axis=2), np.asarray(squeeze(t, axis=2))))
def test_expand_dims_numpy(self):
x = np.random.randn(3, 4)
self.assertTrue(np.allclose(expand_dims(x, axis=1), np.expand_dims(x, axis=1)))
@require_torch
def test_expand_dims_torch(self):
x = np.random.randn(3, 4)
t = torch.tensor(x)
self.assertTrue(np.allclose(expand_dims(x, axis=1), expand_dims(t, axis=1).numpy()))
@require_tf
def test_expand_dims_tf(self):
x = np.random.randn(3, 4)
t = tf.constant(x)
self.assertTrue(np.allclose(expand_dims(x, axis=1), expand_dims(t, axis=1).numpy()))
@require_flax
def test_expand_dims_flax(self):
x = np.random.randn(3, 4)
t = jnp.array(x)
self.assertTrue(np.allclose(expand_dims(x, axis=1), np.asarray(expand_dims(t, axis=1))))