Unverified Commit fba0b6a8 authored by mrbean's avatar mrbean Committed by GitHub
Browse files

convert assertion to raised exception in debertav2 (#17619)

* convert assertion to raised exception in debertav2

* change assert to raise exception in deberta

* fix messages
parent da0bed5f
...@@ -778,7 +778,8 @@ class TFDebertaEmbeddings(tf.keras.layers.Layer): ...@@ -778,7 +778,8 @@ class TFDebertaEmbeddings(tf.keras.layers.Layer):
Returns: Returns:
final_embeddings (`tf.Tensor`): output embedding tensor. final_embeddings (`tf.Tensor`): output embedding tensor.
""" """
assert not (input_ids is None and inputs_embeds is None) if input_ids is None and inputs_embeds is None:
raise ValueError("Need to provide either `input_ids` or `input_embeds`.")
if input_ids is not None: if input_ids is not None:
inputs_embeds = tf.gather(params=self.weight, indices=input_ids) inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
......
...@@ -876,7 +876,8 @@ class TFDebertaV2Embeddings(tf.keras.layers.Layer): ...@@ -876,7 +876,8 @@ class TFDebertaV2Embeddings(tf.keras.layers.Layer):
Returns: Returns:
final_embeddings (`tf.Tensor`): output embedding tensor. final_embeddings (`tf.Tensor`): output embedding tensor.
""" """
assert not (input_ids is None and inputs_embeds is None) if input_ids is None and inputs_embeds is None:
raise ValueError("Need to provide either `input_ids` or `input_embeds`.")
if input_ids is not None: if input_ids is not None:
inputs_embeds = tf.gather(params=self.weight, indices=input_ids) inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment