"...resnet50_tensorflow.git" did not exist on "de6f182c6b1d2f03a930fc9b81ddf70af2702e22"
Commit 22838f19 (unverified), authored by Thomas Wolf, committed by GitHub
Browse files

Merge pull request #1668 from tlkh/fix-tf-xlm

Fixed training for TF XLM
parents 7f84fc57 842f3bf0
@@ -84,7 +84,8 @@ def get_masks(slen, lengths, causal, padding_mask=None, dtype=tf.float32):
 attn_mask = mask
 # sanity check
-assert shape_list(mask) == [bs, slen]
+# assert shape_list(mask) == [bs, slen]
+tf.debugging.assert_equal(shape_list(mask), [bs, slen])
 assert causal is False or shape_list(attn_mask) == [bs, slen, slen]
 mask = tf.cast(mask, dtype=dtype)
@@ -318,7 +319,8 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
 # check inputs
 bs, slen = shape_list(input_ids)
-assert shape_list(lengths)[0] == bs
+# assert shape_list(lengths)[0] == bs
+tf.debugging.assert_equal(shape_list(lengths)[0], bs)
 # assert lengths.max().item() <= slen
 # input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
 # assert (src_enc is None) == (src_len is None)
@@ -335,12 +337,14 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
 if position_ids is None:
 position_ids = tf.expand_dims(tf.range(slen), axis=0)
 else:
-assert shape_list(position_ids) == [bs, slen] # (slen, bs)
+# assert shape_list(position_ids) == [bs, slen] # (slen, bs)
+tf.debugging.assert_equal(shape_list(position_ids), [bs, slen])
 # position_ids = position_ids.transpose(0, 1)
 # langs
 if langs is not None:
-assert shape_list(langs) == [bs, slen] # (slen, bs)
+# assert shape_list(langs) == [bs, slen] # (slen, bs)
+tf.debugging.assert_equal(shape_list(langs), [bs, slen])
 # langs = langs.transpose(0, 1)
 # Prepare head mask if needed
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment