"vscode:/vscode.git/clone" did not exist on "44bb72ca0a909c44849355666db8e088c4591c5c"
Unverified Commit 858b7d58, authored by Patrick von Platen, committed by GitHub

[TF Longformer] Improve Speed for TF Longformer (#6447)

* add tf graph compile tests

* fix conflict

* remove more tf transpose statements

* fix conflicts

* fix comment typos

* move function to class function

* fix black

* fix black

* make style
parent a75c64d8
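
As a rough illustration of what the "remove more tf transpose statements" bullet refers to: a transpose-plus-matmul pair can usually be folded into a single tf.einsum call, so no explicitly transposed copy appears in the graph. This is a hedged sketch with illustrative tensor names and shapes, not code taken from the Longformer model.

import tensorflow as tf

# Illustrative shapes only: (batch, heads, seq_len, head_dim)
query = tf.random.normal((2, 4, 16, 8))
key = tf.random.normal((2, 4, 16, 8))

# Transpose-based attention scores: inserts an explicit transpose op.
scores_transpose = tf.matmul(query, tf.transpose(key, perm=(0, 1, 3, 2)))

# Equivalent einsum formulation with no explicit transpose.
scores_einsum = tf.einsum("bhxd,bhyd->bhxy", query, key)

tf.debugging.assert_near(scores_transpose, scores_einsum, rtol=1e-5)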
@@ -1088,7 +1088,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel, TFSequenceClassific
             config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier"
         )
 
-    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING)
+    @add_start_docstrings_to_callable(BERT_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
         checkpoint="bert-base-cased",
@@ -677,7 +677,7 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
         self.electra = TFElectraMainLayer(config, name="electra")
         self.classifier = TFElectraClassificationHead(config, name="classifier")
 
-    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("(batch_size, num_choices, sequence_length)"))
+    @add_start_docstrings_to_callable(ELECTRA_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
         checkpoint="google/electra-small-discriminator",
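
Both hunks above fix only the shape string substituted into the inputs docstring; the decorator machinery is unchanged. A minimal sketch of the pattern, assuming a much-simplified stand-in for transformers' real add_start_docstrings_to_callable:

INPUTS_DOCSTRING = r"""
    input_ids of shape {0}:
        Indices of input sequence tokens in the vocabulary.
"""

def add_start_docstrings_to_callable(*docstr):
    # Prepend the given fragments to the decorated function's docstring.
    def decorator(fn):
        fn.__doc__ = "".join(docstr) + (fn.__doc__ or "")
        return fn
    return decorator

@add_start_docstrings_to_callable(INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
def call(inputs):
    """Forward pass."""

print(call.__doc__)  # the {0} placeholder now reads (batch_size, sequence_length)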
@@ -110,15 +110,6 @@ class TFModelTesterMixin:
     def test_initialization(self):
         pass
 
-        # config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        # configs_no_init = _config_zero_init(config)
-        # for model_class in self.all_model_classes:
-        #     model = model_class(config=configs_no_init)
-        #     for name, param in model.named_parameters():
-        #         if param.requires_grad:
-        #             self.assertIn(param.data.mean().item(), [0.0, 1.0],
-        #                           msg="Parameter {} of model {} seems not properly initialized".format(name, model_class))
-
     def test_save_load(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
@@ -134,6 +125,19 @@ class TFModelTesterMixin:
         self.assert_outputs_same(after_outputs, outputs)
 
+    def test_graph_mode(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            inputs = self._prepare_for_class(inputs_dict, model_class)
+            model = model_class(config)
+
+            @tf.function
+            def run_in_graph_mode():
+                return model(inputs)
+
+            outputs = run_in_graph_mode()
+            self.assertIsNotNone(outputs)
+
     @slow
     def test_saved_model_with_hidden_states_output(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
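
The new test_graph_mode above simply compiles each model with tf.function and runs it once; code that only ever ran eagerly tends to fail here, typically where a Python int is read off a possibly-dynamic shape. A minimal sketch of the graph-safe pattern, with hypothetical names not taken from the PR:

import tensorflow as tf

def pad_to_window_multiple(x, window):
    # tf.shape() returns a runtime tensor, so this also traces correctly
    # under tf.function when the sequence length is unknown statically;
    # reading x.shape[1] there would yield None and crash the arithmetic.
    seq_len = tf.shape(x)[1]
    pad_len = (window - seq_len % window) % window
    return tf.pad(x, [[0, 0], [0, pad_len], [0, 0]])

compiled = tf.function(pad_to_window_multiple)
print(compiled(tf.random.normal((1, 6, 8)), 4).shape)  # (1, 8, 8)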
@@ -385,15 +385,16 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
         self.assertTrue(shape_list(hidden_states), [1, 8, 4])
 
         # pad along seq length dim
-        paddings = tf.constant([[0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
+        paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)
 
+        hidden_states = TFLongformerSelfAttention._chunk(hidden_states, window_overlap=2)
         padded_hidden_states = TFLongformerSelfAttention._pad_and_transpose_last_two_dims(hidden_states, paddings)
-        self.assertTrue(shape_list(padded_hidden_states) == [1, 8, 5])
+        self.assertTrue(shape_list(padded_hidden_states) == [1, 1, 8, 5])
 
         expected_added_dim = tf.zeros((5,), dtype=tf.dtypes.float32)
-        tf.debugging.assert_near(expected_added_dim, padded_hidden_states[0, -1, :], rtol=1e-6)
+        tf.debugging.assert_near(expected_added_dim, padded_hidden_states[0, 0, -1, :], rtol=1e-6)
         tf.debugging.assert_near(
-            hidden_states[0, -1, :], tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6
+            hidden_states[0, 0, -1, :], tf.reshape(padded_hidden_states, (1, -1))[0, 24:32], rtol=1e-6
         )
 
     def test_mask_invalid_locations(self):
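
Note on the paddings change above: each row of the constant is a (pad_before, pad_after) pair for one axis, so going from three rows to four matches the extra dimension that the newly inserted _chunk call introduces. A standalone check of just the tf.pad shape arithmetic, with an illustrative tensor rather than the test's real values:

import tensorflow as tf

x = tf.ones((1, 1, 8, 4))  # 4-D input, so `paddings` needs four rows
paddings = tf.constant([[0, 0], [0, 0], [0, 1], [0, 0]], dtype=tf.dtypes.int32)

# Only the third axis grows, by one trailing slot.
print(tf.pad(x, paddings).shape)  # (1, 1, 9, 4)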
@@ -437,10 +438,16 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
         hidden_states = self._get_hidden_states()
         batch_size, seq_length, hidden_size = hidden_states.shape
 
-        attention_mask = tf.zeros((batch_size, 1, 1, seq_length), dtype=tf.dtypes.float32)
-        attention_mask = tf.where(tf.range(4)[None, None, None, :] > 1, -10000.0, attention_mask)
+        attention_mask = tf.zeros((batch_size, seq_length), dtype=tf.dtypes.float32)
+        is_index_global_attn = tf.math.greater(attention_mask, 1)
+        is_global_attn = tf.math.reduce_any(is_index_global_attn)
 
-        output_hidden_states = layer([hidden_states, attention_mask, None])[0]
+        attention_mask = tf.where(tf.range(4)[None, :, None, None] > 1, -10000.0, attention_mask[:, :, None, None])
+        is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
+
+        output_hidden_states = layer(
+            [hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, None]
+        )[0]
 
         expected_slice = tf.convert_to_tensor(
             [0.00188, 0.012196, -0.017051, -0.025571, -0.02996, 0.017297, -0.011521, 0.004848], dtype=tf.dtypes.float32
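
For readers skimming the hunk above: the layer now receives precomputed boolean masks alongside the additive float mask. A tiny standalone illustration of how those tf.math comparisons derive the three flags, using toy values rather than the test's tensors:

import tensorflow as tf

# Toy additive mask: 0.0 = normal local attention, -10000.0 = padding;
# values greater than 1 would mark global-attention positions (none here).
attention_mask = tf.constant([[0.0, 0.0, -10000.0, -10000.0]])

is_index_masked = tf.math.less(attention_mask, 0)          # [[False, False, True, True]]
is_index_global_attn = tf.math.greater(attention_mask, 1)  # all False
is_global_attn = tf.math.reduce_any(is_index_global_attn)  # scalar False

print(is_index_masked.numpy(), is_global_attn.numpy())

Computing the scalar is_global_attn once up front presumably lets the compiled graph skip the global-attention branch entirely when no position requests it.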
@@ -461,12 +468,18 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
         attention_mask_1 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
         attention_mask_2 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
 
-        attention_mask_1 = tf.where(tf.range(4)[None, None, None, :] > 1, 10000.0, attention_mask_1)
-        attention_mask_1 = tf.where(tf.range(4)[None, None, None, :] > 2, -10000.0, attention_mask_1)
-        attention_mask_2 = tf.where(tf.range(4)[None, None, None, :] > 0, 10000.0, attention_mask_2)
+        attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 1, 10000.0, attention_mask_1)
+        attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 2, -10000.0, attention_mask_1)
+        attention_mask_2 = tf.where(tf.range(4)[None, :, None, None] > 0, 10000.0, attention_mask_2)
         attention_mask = tf.concat([attention_mask_1, attention_mask_2], axis=0)
 
-        output_hidden_states = layer([hidden_states, attention_mask, None])[0]
+        is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
+        is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)
+        is_global_attn = tf.math.reduce_any(is_index_global_attn)
+
+        output_hidden_states = layer(
+            [hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn, None]
+        )[0]
 
         self.assertTrue(output_hidden_states.shape, (2, 4, 8))
         expected_slice_0 = tf.convert_to_tensor(
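
One detail worth flagging in the hunk above: the mask is built with +10000.0 markers so that tf.math.greater can pick out global positions, but the layer is handed -tf.math.abs(attention_mask), so every nonzero entry becomes a large negative additive bias and the global information survives only in the boolean flags. A toy trace of that sign-folding, with made-up values:

import tensorflow as tf

attention_mask = tf.constant([[0.0, 10000.0, -10000.0]])

is_index_global_attn = tf.math.greater(attention_mask, 0)  # [[False, True, False]]
additive_mask = -tf.math.abs(attention_mask)               # [[0., -10000., -10000.]]

print(is_index_global_attn.numpy(), additive_mask.numpy())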