Unverified commit 79743ced, authored by Susnato Dhar, committed by GitHub

replaced assert with raise ValueError for t5, switch_transformers, pix2struct, mt5, longt5, gptsan_japanese. (#23273)

* replaced assert with raise ValueError

* one liner

* reverse one liner and cache-decoder check
parent 291c5e9b
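
Every change in the diff below follows the same mechanical pattern: an `assert condition, message` becomes an explicit check that raises `ValueError`, so the validation still runs when Python is started with `-O` (which strips asserts) and callers get a conventional, catchable exception instead of an `AssertionError`. A minimal before/after sketch of the pattern (the bare `config` object here is hypothetical and only illustrates the shape; the real changes live inside the model classes shown below):

    # Before: stripped entirely under `python -O`, and raises AssertionError on failure.
    assert config.pad_token_id is not None, "self.model.config.pad_token_id has to be defined."

    # After: always executed, and raises a ValueError that callers can handle explicitly.
    if config.pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")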
@@ -767,10 +767,11 @@ class GPTSanJapanesePreTrainedModel(PreTrainedModel):
         decoder_start_token_id = self.config.decoder_start_token_id
         pad_token_id = self.config.pad_token_id
 
-        assert decoder_start_token_id is not None, (
-            "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id."
-            " See T5 docs for more information"
-        )
+        if decoder_start_token_id is None:
+            raise ValueError(
+                "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id."
+                "See T5 docs for more information."
+            )
 
         # shift inputs to the right
         if is_torch_fx_proxy(input_ids):
@@ -782,7 +783,8 @@ class GPTSanJapanesePreTrainedModel(PreTrainedModel):
             shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
             shifted_input_ids[..., 0] = decoder_start_token_id
 
-        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
+        if pad_token_id is None:
+            raise ValueError("self.model.config.pad_token_id has to be defined.")
         # replace possible -100 values in labels by `pad_token_id`
         shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
@@ -451,9 +451,10 @@ class LongT5Attention(nn.Module):
         real_seq_length = seq_length
 
         if past_key_value is not None:
-            assert (
-                len(past_key_value) == 2
-            ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+            if len(past_key_value) != 2:
+                raise ValueError(
+                    f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+                )
             real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
 
         key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
@@ -1349,10 +1350,11 @@ class LongT5PreTrainedModel(PreTrainedModel):
         decoder_start_token_id = self.config.decoder_start_token_id
         pad_token_id = self.config.pad_token_id
 
-        assert decoder_start_token_id is not None, (
-            "self.model.config.decoder_start_token_id has to be defined. In LongT5 it is usually set to the pad_token_id."
-            " See LongT5 docs for more information"
-        )
+        if decoder_start_token_id is None:
+            raise ValueError(
+                "self.model.config.decoder_start_token_id has to be defined. In LongT5 it is usually set to the pad_token_id."
+                "See LongT5 docs for more information."
+            )
 
         # shift inputs to the right
         if is_torch_fx_proxy(input_ids):
@@ -1364,7 +1366,8 @@ class LongT5PreTrainedModel(PreTrainedModel):
             shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
             shifted_input_ids[..., 0] = decoder_start_token_id
 
-        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
+        if pad_token_id is None:
+            raise ValueError("self.model.config.pad_token_id has to be defined.")
         # replace possible -100 values in labels by `pad_token_id`
         shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
@@ -333,9 +333,10 @@ class MT5Attention(nn.Module):
         real_seq_length = seq_length
 
         if past_key_value is not None:
-            assert (
-                len(past_key_value) == 2
-            ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+            if len(past_key_value) != 2:
+                raise ValueError(
+                    f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+                )
             real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
 
         key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
@@ -818,10 +819,11 @@ class MT5PreTrainedModel(PreTrainedModel):
         decoder_start_token_id = self.config.decoder_start_token_id
         pad_token_id = self.config.pad_token_id
 
-        assert decoder_start_token_id is not None, (
-            "self.model.config.decoder_start_token_id has to be defined. In MT5 it is usually set to the pad_token_id."
-            " See MT5 docs for more information"
-        )
+        if decoder_start_token_id is None:
+            raise ValueError(
+                "self.model.config.decoder_start_token_id has to be defined. In MT5 it is usually set to the pad_token_id."
+                "See MT5 docs for more information."
+            )
 
         # shift inputs to the right
         if is_torch_fx_proxy(input_ids):
@@ -833,7 +835,8 @@ class MT5PreTrainedModel(PreTrainedModel):
             shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
             shifted_input_ids[..., 0] = decoder_start_token_id
 
-        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
+        if pad_token_id is None:
+            raise ValueError("self.model.config.pad_token_id has to be defined.")
         # replace possible -100 values in labels by `pad_token_id`
         shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
@@ -952,7 +955,8 @@ class MT5Stack(MT5PreTrainedModel):
             raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
 
         if inputs_embeds is None:
-            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
+            if self.embed_tokens is None:
+                raise ValueError("You have to initialize the model with valid token embeddings")
             inputs_embeds = self.embed_tokens(input_ids)
 
         batch_size, seq_length = input_shape
@@ -961,7 +965,8 @@ class MT5Stack(MT5PreTrainedModel):
         mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
 
         if use_cache is True:
-            assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
+            if not self.is_decoder:
+                raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
 
         if attention_mask is None:
             attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
@@ -1852,8 +1857,14 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
                     layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
                 )
 
-            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
-            assert len(reordered_layer_past_states) == len(layer_past_states)
+            if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
+                raise ValueError(
+                    f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
+                )
+            if len(reordered_layer_past_states) != len(layer_past_states):
+                raise ValueError(
+                    f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
+                )
 
             reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
         return reordered_decoder_past
@@ -479,10 +479,11 @@ class Pix2StructPreTrainedModel(PreTrainedModel):
         decoder_start_token_id = self.config.decoder_start_token_id
         pad_token_id = self.config.pad_token_id
 
-        assert decoder_start_token_id is not None, (
-            "self.model.config.decoder_start_token_id has to be defined. In Pix2Struct it is usually set to the pad_token_id."
-            " See Pix2Struct docs for more information"
-        )
+        if decoder_start_token_id is None:
+            raise ValueError(
+                "self.model.config.decoder_start_token_id has to be defined. In Pix2Struct it is usually set to the pad_token_id."
+                "See Pix2Struct docs for more information."
+            )
 
         # shift inputs to the right
         if is_torch_fx_proxy(input_ids):
@@ -494,7 +495,8 @@ class Pix2StructPreTrainedModel(PreTrainedModel):
             shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
             shifted_input_ids[..., 0] = decoder_start_token_id
 
-        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
+        if pad_token_id is None:
+            raise ValueError("self.model.config.pad_token_id has to be defined.")
         # replace possible -100 values in labels by `pad_token_id`
         shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
@@ -1356,8 +1358,14 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
                     layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
                 )
 
-            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
-            assert len(reordered_layer_past_states) == len(layer_past_states)
+            if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
+                raise ValueError(
+                    f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
+                )
+            if len(reordered_layer_past_states) != len(layer_past_states):
+                raise ValueError(
+                    f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
+                )
 
             reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
         return reordered_decoder_past
@@ -515,9 +515,10 @@ class SwitchTransformersAttention(nn.Module):
         real_seq_length = seq_length
 
         if past_key_value is not None:
-            assert (
-                len(past_key_value) == 2
-            ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+            if len(past_key_value) != 2:
+                raise ValueError(
+                    f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+                )
             real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
 
         key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
@@ -163,9 +163,8 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
             logger.info(f"Transposing numpy weight of shape {array.shape} for {name}")
             array = np.transpose(array)
         try:
-            assert (
-                pointer.shape == array.shape
-            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
+            if pointer.shape != array.shape:
+                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
         except AssertionError as e:
             e.args += (pointer.shape, array.shape)
             raise
@@ -473,9 +472,10 @@ class T5Attention(nn.Module):
         real_seq_length = seq_length
 
         if past_key_value is not None:
-            assert (
-                len(past_key_value) == 2
-            ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+            if len(past_key_value) != 2:
+                raise ValueError(
+                    f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+                )
             real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
 
         key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
@@ -848,10 +848,11 @@ class T5PreTrainedModel(PreTrainedModel):
         decoder_start_token_id = self.config.decoder_start_token_id
         pad_token_id = self.config.pad_token_id
 
-        assert decoder_start_token_id is not None, (
-            "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id."
-            " See T5 docs for more information"
-        )
+        if decoder_start_token_id is None:
+            raise ValueError(
+                "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id."
+                "See T5 docs for more information."
+            )
 
         # shift inputs to the right
         if is_torch_fx_proxy(input_ids):
@@ -863,7 +864,8 @@ class T5PreTrainedModel(PreTrainedModel):
             shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
             shifted_input_ids[..., 0] = decoder_start_token_id
 
-        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
+        if pad_token_id is None:
+            raise ValueError("self.model.config.pad_token_id has to be defined.")
         # replace possible -100 values in labels by `pad_token_id`
         shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
@@ -981,7 +983,8 @@ class T5Stack(T5PreTrainedModel):
             raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
 
         if inputs_embeds is None:
-            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
+            if self.embed_tokens is None:
+                raise ValueError("You have to initialize the model with valid token embeddings")
             inputs_embeds = self.embed_tokens(input_ids)
 
         batch_size, seq_length = input_shape
@@ -990,7 +993,8 @@ class T5Stack(T5PreTrainedModel):
         mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
 
         if use_cache is True:
-            assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
+            if not self.is_decoder:
+                raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
 
         if attention_mask is None:
             attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
@@ -1817,8 +1821,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
                     layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
                 )
 
-            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
-            assert len(reordered_layer_past_states) == len(layer_past_states)
+            if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
+                raise ValueError(
+                    f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
+                )
+            if len(reordered_layer_past_states) != len(layer_past_states):
+                raise ValueError(
+                    f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
+                )
 
             reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
         return reordered_decoder_past