Commit a2c8e516 authored by Patrick von Platen's avatar Patrick von Platen
Browse files

fix torch to tf translation

parent ca2047bc
......@@ -641,7 +641,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
# create attention mask if necessary
# TODO (PVP): this should later be handled by the forward fn() in each model in the future see PR 3140
if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids):
import ipdb
ipdb.set_trace()
if (attention_mask is None) and (pad_token_id is not None) and (pad_token_id in input_ids.numpy()):
attention_mask = tf.cast(tf.math.not_equal(input_ids, pad_token_id), dtype=tf.int32)
elif attention_mask is None:
attention_mask = tf.ones_like(input_ids)
......
Markdown is supported
0% loading — or attach a file by clicking here.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment