"test/vscode:/vscode.git/clone" did not exist on "0513330a853fd2bb3196b89a7727f95c36d7335f"
Unverified commit f8eda599 authored by Kristian Holsheimer, committed by GitHub

[FlaxBert] Fix non-broadcastable attention mask for batched forward-passes (#8791)

* [FlaxBert] Fix non-broadcastable attention mask for batched forward-passes

* [FlaxRoberta] Fix non-broadcastable attention mask

* Use jax.numpy instead of ordinary numpy (otherwise not jit-able)

* Partially revert "Use jax.numpy ..."

* Add tests for batched forward passes

* Avoid unnecessary OOMs due to preallocation of GPU memory by XLA

* Auto-fix style

* Re-enable GPU memory preallocation but with mem fraction < 1/parallelism
parent cb7602b3
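
For context on the first two hunks: Flax's nn.attention.SelfAttention applies the mask to attention weights of shape (*batch_sizes, num_heads, q_length, kv_length), while the tokenizer hands the model a mask of shape (*batch_sizes, kv_length), which only broadcasts when the batch size is 1. A minimal sketch of the failure and of the expand_dims fix, using made-up sizes (batch 3, 12 heads, sequence length 7) rather than anything from the commit:

import jax.numpy as jnp

# Made-up sizes for illustration only.
attn_weights = jnp.zeros((3, 12, 7, 7))    # (*batch_sizes, num_heads, q_length, kv_length)
attention_mask = jnp.ones((3, 7))          # (*batch_sizes, kv_length), as returned by the tokenizer

# (3, 7) does not broadcast against (3, 12, 7, 7), so masking fails for batch sizes > 1.
# Inserting two singleton axes lines the mask up with the head and query dimensions:
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))   # -> (3, 1, 1, 7)
print((attn_weights + attention_mask).shape)                      # (3, 12, 7, 7)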
@@ -183,6 +183,10 @@ class FlaxBertAttention(nn.Module):
     @nn.compact
     def __call__(self, hidden_state, attention_mask):
+        # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
+        # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
+        # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
+        attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
         self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")(
             hidden_state, attention_mask
         )
...
@@ -186,6 +186,10 @@ class FlaxRobertaAttention(nn.Module):
     @nn.compact
     def __call__(self, hidden_state, attention_mask):
+        # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
+        # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
+        # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
+        attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
         self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")(
             hidden_state, attention_mask
         )
@@ -416,8 +420,8 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
             token_type_ids = jnp.ones_like(input_ids)
 
         if position_ids is None:
-            position_ids = np.arange(
-                self.config.pad_token_id + 1, np.atleast_2d(input_ids).shape[-1] + self.config.pad_token_id + 1
+            position_ids = jnp.arange(
+                self.config.pad_token_id + 1, jnp.atleast_2d(input_ids).shape[-1] + self.config.pad_token_id + 1
             )
 
         if attention_mask is None:
...
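
The hunk above is what the "Use jax.numpy instead of ordinary numpy (otherwise not jit-able)" commit refers to: np.atleast_2d forces a traced JAX array through np.asarray, which fails inside jax.jit, while the jnp equivalents trace fine because shapes (and hence the arange bounds) are static. A hedged toy sketch of the difference, not the model code itself (the start offset is hard-coded to 1 here, i.e. pad_token_id assumed to be 0, purely for illustration):

import numpy as np
import jax
import jax.numpy as jnp

def position_ids_np(input_ids):
    # np.atleast_2d tries to convert the traced array to a NumPy array and raises under jit
    return np.arange(1, np.atleast_2d(input_ids).shape[-1] + 1)

def position_ids_jnp(input_ids):
    # shapes are static during tracing, so the arange bounds are plain Python ints
    return jnp.arange(1, jnp.atleast_2d(input_ids).shape[-1] + 1)

ids = jnp.zeros((2, 7), dtype=jnp.int32)
print(jax.jit(position_ids_jnp)(ids))   # works: [1 2 3 4 5 6 7]
# jax.jit(position_ids_np)(ids)         # raises a tracer-to-NumPy conversion error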
 import unittest
 
+import pytest
 from numpy import ndarray
 
 from transformers import BertTokenizerFast, TensorType, is_flax_available, is_torch_available
@@ -7,6 +8,11 @@ from transformers.testing_utils import require_flax, require_torch
 
 if is_flax_available():
+    import os
+
+    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12"  # assumed parallelism: 8
+
+    import jax
     from transformers.models.bert.modeling_flax_bert import FlaxBertModel
 
 if is_torch_available():
@@ -39,3 +45,26 @@ class FlaxBertModelTest(unittest.TestCase):
     def assert_almost_equals(self, a: ndarray, b: ndarray, tol: float):
         diff = (a - b).sum()
         self.assertLessEqual(diff, tol, "Difference between torch and flax is {} (>= {})".format(diff, tol))
+
+
+@require_flax
+@pytest.mark.parametrize("jit", ["disable_jit", "enable_jit"])
+def test_multiple_sentences(jit):
+    tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
+    model = FlaxBertModel.from_pretrained("bert-base-cased")
+
+    sentences = ["this is an example sentence", "this is another", "and a third one"]
+    encodings = tokenizer(sentences, return_tensors=TensorType.JAX, padding=True, truncation=True)
+
+    @jax.jit
+    def model_jitted(input_ids, attention_mask, token_type_ids):
+        return model(input_ids, attention_mask, token_type_ids)
+
+    if jit == "disable_jit":
+        with jax.disable_jit():
+            tokens, pooled = model_jitted(**encodings)
+    else:
+        tokens, pooled = model_jitted(**encodings)
+
+    assert tokens.shape == (3, 7, 768)
+    assert pooled.shape == (3, 768)
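
The os.environ line addresses the commit's OOM note: by default XLA preallocates most of the GPU memory for each process, so several test processes sharing one GPU can run out of memory immediately. Setting the fraction to 0.12 caps each process at roughly 12%, on the assumption of up to 8 parallel workers. The variable only takes effect if it is set before JAX initializes its backend, which is why it sits before import jax. A small standalone sketch of the same pattern:

import os

# Must be exported before the first JAX/XLA backend initialization to take effect.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12"   # roughly 12% of GPU memory per process

import jax
import jax.numpy as jnp

_ = jnp.ones(1).block_until_ready()   # first op initializes the backend and triggers preallocation
print(jax.devices())                  # on a CPU-only machine the variable is simply ignored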
 import unittest
 
+import pytest
 from numpy import ndarray
 
 from transformers import RobertaTokenizerFast, TensorType, is_flax_available, is_torch_available
@@ -7,6 +8,11 @@ from transformers.testing_utils import require_flax, require_torch
 
 if is_flax_available():
+    import os
+
+    os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12"  # assumed parallelism: 8
+
+    import jax
     from transformers.models.roberta.modeling_flax_roberta import FlaxRobertaModel
 
 if is_torch_available():
@@ -39,3 +45,26 @@ class FlaxRobertaModelTest(unittest.TestCase):
     def assert_almost_equals(self, a: ndarray, b: ndarray, tol: float):
         diff = (a - b).sum()
         self.assertLessEqual(diff, tol, "Difference between torch and flax is {} (>= {})".format(diff, tol))
+
+
+@require_flax
+@pytest.mark.parametrize("jit", ["disable_jit", "enable_jit"])
+def test_multiple_sentences(jit):
+    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
+    model = FlaxRobertaModel.from_pretrained("roberta-base")
+
+    sentences = ["this is an example sentence", "this is another", "and a third one"]
+    encodings = tokenizer(sentences, return_tensors=TensorType.JAX, padding=True, truncation=True)
+
+    @jax.jit
+    def model_jitted(input_ids, attention_mask):
+        return model(input_ids, attention_mask)
+
+    if jit == "disable_jit":
+        with jax.disable_jit():
+            tokens, pooled = model_jitted(**encodings)
+    else:
+        tokens, pooled = model_jitted(**encodings)
+
+    assert tokens.shape == (3, 7, 768)
+    assert pooled.shape == (3, 768)
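
Both new tests are parametrized over jit, so each one runs once through the compiled path and once inside jax.disable_jit(), which executes the same code eagerly and is much easier to step through. The expected (3, 7, 768) presumably follows from three sentences padded to the length of the longest one (7 ids once special tokens are added) and the 768-dimensional hidden size of the base checkpoints. A minimal, model-free sketch of the same disable_jit/enable_jit pattern:

import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.tanh(x) * 2.0

x = jnp.arange(3.0)

with jax.disable_jit():   # runs eagerly, op by op; convenient for debugging
    eager = f(x)

jitted = f(x)             # compiled path
assert jnp.allclose(eager, jitted)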