Unverified Commit 9e4ea251 authored by Weston King-Leatham's avatar Weston King-Leatham Committed by GitHub
Browse files

Change asserts in src/transformers/models/xlnet/ to raise ValueError (#14088)



* Change asserts in src/transformers/models/xlnet/ to raise ValueError

* Update src/transformers/models/xlnet/modeling_tf_xlnet.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent e9d2a639
...@@ -180,11 +180,13 @@ class XLNetConfig(PretrainedConfig): ...@@ -180,11 +180,13 @@ class XLNetConfig(PretrainedConfig):
self.d_model = d_model self.d_model = d_model
self.n_layer = n_layer self.n_layer = n_layer
self.n_head = n_head self.n_head = n_head
assert d_model % n_head == 0 if d_model % n_head != 0:
raise ValueError(f"'d_model % n_head' ({d_model % n_head}) should be equal to 0")
if "d_head" in kwargs: if "d_head" in kwargs:
assert ( if kwargs["d_head"] != d_model // n_head:
kwargs["d_head"] == d_model // n_head raise ValueError(
), f"`d_head` ({kwargs['d_head']}) should be equal to `d_model // n_head` ({d_model // n_head})" f"`d_head` ({kwargs['d_head']}) should be equal to `d_model // n_head` ({d_model // n_head})"
)
self.d_head = d_model // n_head self.d_head = d_model // n_head
self.ff_activation = ff_activation self.ff_activation = ff_activation
self.d_inner = d_inner self.d_inner = d_inner
......
...@@ -561,7 +561,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -561,7 +561,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -self.clamp_len, self.clamp_len) bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -self.clamp_len, self.clamp_len)
if bsz is not None: if bsz is not None:
assert bsz % 2 == 0, f"With bi_data, the batch size {bsz} should be divisible by 2" if bsz % 2 != 0:
raise ValueError(f"With bi_data, the batch size {bsz} should be divisible by 2")
fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz // 2) fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz // 2)
bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz // 2) bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz // 2)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment