Unverified Commit b65c3897 authored by Lukas Weiner's avatar Lukas Weiner Committed by GitHub
Browse files

Raise exceptions instead of asserts in...

Raise exceptions instead of asserts in src/transformers/models/bart/modeling_flax_[bart, marian, mbart, pegasus].py (#13939)

* Raise exceptions instead of asserts

* fix: fixed failing quality check with copies

* fix: fixed max line length

* rerun github ci, failed to install dependencies
parent 7fb2a8b3
......@@ -237,9 +237,11 @@ class FlaxBartAttention(nn.Module):
def setup(self) -> None:
self.head_dim = self.embed_dim // self.num_heads
assert (
self.head_dim * self.num_heads == self.embed_dim
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {self.num_heads})."
)
dense = partial(
nn.Dense,
......
......@@ -241,9 +241,11 @@ class FlaxMarianAttention(nn.Module):
def setup(self) -> None:
self.head_dim = self.embed_dim // self.num_heads
assert (
self.head_dim * self.num_heads == self.embed_dim
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {self.num_heads})."
)
dense = partial(
nn.Dense,
......
......@@ -248,9 +248,11 @@ class FlaxMBartAttention(nn.Module):
def setup(self) -> None:
self.head_dim = self.embed_dim // self.num_heads
assert (
self.head_dim * self.num_heads == self.embed_dim
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {self.num_heads})."
)
dense = partial(
nn.Dense,
......
......@@ -241,9 +241,11 @@ class FlaxPegasusAttention(nn.Module):
def setup(self) -> None:
self.head_dim = self.embed_dim // self.num_heads
assert (
self.head_dim * self.num_heads == self.embed_dim
), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
if self.head_dim * self.num_heads != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
f" and `num_heads`: {self.num_heads})."
)
dense = partial(
nn.Dense,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment