Unverified Commit 79743ced authored by Susnato Dhar, committed by GitHub

replaced assert with raise ValueError for t5, switch_transformers, pix2struct, mt5, longt5, gptsan_japanese. (#23273)

* replaced assert with raise ValueError

* one liner

* reverse one liner and cache-decoder check
parent 291c5e9b
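
The diffs below all apply the same pattern: a bare `assert` is replaced by an explicit `if ... raise ValueError(...)` check. A minimal sketch of the idea follows; the `TinyConfig` class and `check_decoder_start` function are illustrative stand-ins, not code from the repository. The motivation: `assert` statements are compiled away when Python runs with `-O`, and they raise `AssertionError`, whereas an explicit check always executes and surfaces configuration problems as a `ValueError` that callers can reasonably catch.

```python
# Illustrative sketch of the pattern applied throughout this commit; the
# TinyConfig class and check_decoder_start function are hypothetical, not
# part of transformers.
from dataclasses import dataclass
from typing import Optional


@dataclass
class TinyConfig:
    decoder_start_token_id: Optional[int] = None
    pad_token_id: Optional[int] = 0


def check_decoder_start(config: TinyConfig) -> int:
    # Before: `assert config.decoder_start_token_id is not None, "..."`.
    # Under `python -O` that assert is compiled away, so a bad config slips
    # through; an explicit raise always runs and uses a catchable ValueError
    # instead of AssertionError.
    if config.decoder_start_token_id is None:
        raise ValueError(
            "decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id."
        )
    return config.decoder_start_token_id


try:
    check_decoder_start(TinyConfig())
except ValueError as err:
    print(f"caught: {err}")
```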
@@ -767,10 +767,11 @@ class GPTSanJapanesePreTrainedModel(PreTrainedModel):
         decoder_start_token_id = self.config.decoder_start_token_id
         pad_token_id = self.config.pad_token_id
 
-        assert decoder_start_token_id is not None, (
-            "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id."
-            " See T5 docs for more information"
-        )
+        if decoder_start_token_id is None:
+            raise ValueError(
+                "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id."
+                "See T5 docs for more information."
+            )
 
         # shift inputs to the right
         if is_torch_fx_proxy(input_ids):
@@ -782,7 +783,8 @@ class GPTSanJapanesePreTrainedModel(PreTrainedModel):
             shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
             shifted_input_ids[..., 0] = decoder_start_token_id
 
-        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
+        if pad_token_id is None:
+            raise ValueError("self.model.config.pad_token_id has to be defined.")
         # replace possible -100 values in labels by `pad_token_id`
         shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
......
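
For context on what the guarded code above does, here is a standalone sketch of the "shift inputs to the right" step; it is a simplified stand-in, not the library method itself. Decoder inputs are built from the labels by prepending `decoder_start_token_id`, dropping the last token, and mapping the `-100` ignore index back to `pad_token_id`, which is why both ids must be defined.

```python
import torch


# Standalone sketch (a simplified stand-in, not the library method) of the
# "shift inputs to the right" step guarded by the checks above.
def shift_right(input_ids: torch.Tensor, decoder_start_token_id: int, pad_token_id: int) -> torch.Tensor:
    if decoder_start_token_id is None:
        raise ValueError("decoder_start_token_id has to be defined.")
    if pad_token_id is None:
        raise ValueError("pad_token_id has to be defined.")

    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()  # move every token one slot right
    shifted_input_ids[..., 0] = decoder_start_token_id        # decoder always starts from this id
    # labels use -100 for ignored positions; replace them with pad_token_id
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
    return shifted_input_ids


labels = torch.tensor([[13, 7, 42, -100]])
print(shift_right(labels, decoder_start_token_id=0, pad_token_id=0))
# tensor([[ 0, 13,  7, 42]])
```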
@@ -451,9 +451,10 @@ class LongT5Attention(nn.Module):
         real_seq_length = seq_length
 
         if past_key_value is not None:
-            assert (
-                len(past_key_value) == 2
-            ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+            if len(past_key_value) != 2:
+                raise ValueError(
+                    f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+                )
             real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
 
         key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
@@ -1349,10 +1350,11 @@ class LongT5PreTrainedModel(PreTrainedModel):
         decoder_start_token_id = self.config.decoder_start_token_id
         pad_token_id = self.config.pad_token_id
 
-        assert decoder_start_token_id is not None, (
-            "self.model.config.decoder_start_token_id has to be defined. In LongT5 it is usually set to the pad_token_id."
-            " See LongT5 docs for more information"
-        )
+        if decoder_start_token_id is None:
+            raise ValueError(
+                "self.model.config.decoder_start_token_id has to be defined. In LongT5 it is usually set to the pad_token_id."
+                "See LongT5 docs for more information."
+            )
 
         # shift inputs to the right
         if is_torch_fx_proxy(input_ids):
@@ -1364,7 +1366,8 @@ class LongT5PreTrainedModel(PreTrainedModel):
             shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
             shifted_input_ids[..., 0] = decoder_start_token_id
 
-        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
+        if pad_token_id is None:
+            raise ValueError("self.model.config.pad_token_id has to be defined.")
         # replace possible -100 values in labels by `pad_token_id`
         shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
......
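
The attention hunks above validate the key/value cache before using it. Below is a sketch with hypothetical tensor shapes of what the check expects: `past_key_value` is a `(key_states, value_states)` pair, and the cached length extends the effective sequence length used downstream.

```python
import torch

# Hypothetical shapes illustrating the (key_states, value_states) pair the
# attention modules expect in `past_key_value`, and how the cached length
# feeds into the effective sequence length.
batch, heads, cached_len, head_dim = 2, 8, 5, 64
past_key_value = (
    torch.zeros(batch, heads, cached_len, head_dim),  # cached key states
    torch.zeros(batch, heads, cached_len, head_dim),  # cached value states
)

seq_length = 1  # during incremental decoding only the newest token is fed in
if len(past_key_value) != 2:
    raise ValueError(
        f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
    )
real_seq_length = seq_length + past_key_value[0].shape[2]
print(real_seq_length)  # 6: the new token attends over 5 cached positions plus itself
```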
@@ -333,9 +333,10 @@ class MT5Attention(nn.Module):
         real_seq_length = seq_length
 
         if past_key_value is not None:
-            assert (
-                len(past_key_value) == 2
-            ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+            if len(past_key_value) != 2:
+                raise ValueError(
+                    f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+                )
             real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
 
         key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
@@ -818,10 +819,11 @@ class MT5PreTrainedModel(PreTrainedModel):
         decoder_start_token_id = self.config.decoder_start_token_id
         pad_token_id = self.config.pad_token_id
 
-        assert decoder_start_token_id is not None, (
-            "self.model.config.decoder_start_token_id has to be defined. In MT5 it is usually set to the pad_token_id."
-            " See MT5 docs for more information"
-        )
+        if decoder_start_token_id is None:
+            raise ValueError(
+                "self.model.config.decoder_start_token_id has to be defined. In MT5 it is usually set to the pad_token_id."
+                "See MT5 docs for more information."
+            )
 
         # shift inputs to the right
         if is_torch_fx_proxy(input_ids):
@@ -833,7 +835,8 @@ class MT5PreTrainedModel(PreTrainedModel):
             shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
             shifted_input_ids[..., 0] = decoder_start_token_id
 
-        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
+        if pad_token_id is None:
+            raise ValueError("self.model.config.pad_token_id has to be defined.")
         # replace possible -100 values in labels by `pad_token_id`
         shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
@@ -952,7 +955,8 @@ class MT5Stack(MT5PreTrainedModel):
             raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
 
         if inputs_embeds is None:
-            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
+            if self.embed_tokens is None:
+                raise ValueError("You have to initialize the model with valid token embeddings")
             inputs_embeds = self.embed_tokens(input_ids)
 
         batch_size, seq_length = input_shape
@@ -961,7 +965,8 @@ class MT5Stack(MT5PreTrainedModel):
         mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
 
         if use_cache is True:
-            assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
+            if not self.is_decoder:
+                raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
 
         if attention_mask is None:
             attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
@@ -1852,8 +1857,14 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel):
                     layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
                 )
 
-            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
-            assert len(reordered_layer_past_states) == len(layer_past_states)
+            if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
+                raise ValueError(
+                    f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
+                )
+            if len(reordered_layer_past_states) != len(layer_past_states):
+                raise ValueError(
+                    f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
+                )
 
             reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
         return reordered_decoder_past
......
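
The new checks in the cache-reordering method (`_reorder_cache` in these models) guard the beam-search cache reordering. Below is a sketch with hypothetical sizes of that reordering: each cached tensor is re-indexed along its batch (beam) dimension with `index_select`, which should leave both the tensor shapes and the number of states unchanged.

```python
import torch

# Sketch with hypothetical sizes of the beam-search cache reordering that the
# new shape/length checks validate: every cached state is re-indexed along its
# batch (beam) dimension so each beam keeps the history of the hypothesis it
# was forked from.
num_beams, heads, past_len, head_dim = 4, 8, 3, 64
layer_past_states = (
    torch.randn(num_beams, heads, past_len, head_dim),  # self-attention keys
    torch.randn(num_beams, heads, past_len, head_dim),  # self-attention values
)
beam_idx = torch.tensor([0, 0, 2, 1])  # which previous beam each new beam continues

reordered_layer_past_states = ()
for layer_past_state in layer_past_states:
    reordered_layer_past_states = reordered_layer_past_states + (
        layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
    )

# Reordering must preserve both the tensor shapes and the number of states.
if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
    raise ValueError("reordered and original past state shapes mismatched")
if len(reordered_layer_past_states) != len(layer_past_states):
    raise ValueError("reordered and original past state counts mismatched")
print(reordered_layer_past_states[0].shape)  # torch.Size([4, 8, 3, 64])
```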
@@ -479,10 +479,11 @@ class Pix2StructPreTrainedModel(PreTrainedModel):
         decoder_start_token_id = self.config.decoder_start_token_id
         pad_token_id = self.config.pad_token_id
 
-        assert decoder_start_token_id is not None, (
-            "self.model.config.decoder_start_token_id has to be defined. In Pix2Struct it is usually set to the pad_token_id."
-            " See Pix2Struct docs for more information"
-        )
+        if decoder_start_token_id is None:
+            raise ValueError(
+                "self.model.config.decoder_start_token_id has to be defined. In Pix2Struct it is usually set to the pad_token_id."
+                "See Pix2Struct docs for more information."
+            )
 
         # shift inputs to the right
         if is_torch_fx_proxy(input_ids):
@@ -494,7 +495,8 @@ class Pix2StructPreTrainedModel(PreTrainedModel):
             shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
             shifted_input_ids[..., 0] = decoder_start_token_id
 
-        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
+        if pad_token_id is None:
+            raise ValueError("self.model.config.pad_token_id has to be defined.")
         # replace possible -100 values in labels by `pad_token_id`
         shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
@@ -1356,8 +1358,14 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
                     layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
                 )
 
-            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
-            assert len(reordered_layer_past_states) == len(layer_past_states)
+            if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
+                raise ValueError(
+                    f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
+                )
+            if len(reordered_layer_past_states) != len(layer_past_states):
+                raise ValueError(
+                    f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
+                )
 
             reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
         return reordered_decoder_past
......
@@ -515,9 +515,10 @@ class SwitchTransformersAttention(nn.Module):
         real_seq_length = seq_length
 
         if past_key_value is not None:
-            assert (
-                len(past_key_value) == 2
-            ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+            if len(past_key_value) != 2:
+                raise ValueError(
+                    f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+                )
             real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
 
         key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
......
@@ -163,9 +163,8 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
             logger.info(f"Transposing numpy weight of shape {array.shape} for {name}")
             array = np.transpose(array)
         try:
-            assert (
-                pointer.shape == array.shape
-            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
+            if pointer.shape != array.shape:
+                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
         except AssertionError as e:
             e.args += (pointer.shape, array.shape)
             raise
@@ -473,9 +472,10 @@ class T5Attention(nn.Module):
         real_seq_length = seq_length
 
         if past_key_value is not None:
-            assert (
-                len(past_key_value) == 2
-            ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+            if len(past_key_value) != 2:
+                raise ValueError(
+                    f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
+                )
             real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
 
         key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
@@ -848,10 +848,11 @@ class T5PreTrainedModel(PreTrainedModel):
         decoder_start_token_id = self.config.decoder_start_token_id
         pad_token_id = self.config.pad_token_id
 
-        assert decoder_start_token_id is not None, (
-            "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id."
-            " See T5 docs for more information"
-        )
+        if decoder_start_token_id is None:
+            raise ValueError(
+                "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id."
+                "See T5 docs for more information."
+            )
 
         # shift inputs to the right
         if is_torch_fx_proxy(input_ids):
@@ -863,7 +864,8 @@ class T5PreTrainedModel(PreTrainedModel):
             shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
             shifted_input_ids[..., 0] = decoder_start_token_id
 
-        assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined."
+        if pad_token_id is None:
+            raise ValueError("self.model.config.pad_token_id has to be defined.")
         # replace possible -100 values in labels by `pad_token_id`
         shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
@@ -981,7 +983,8 @@ class T5Stack(T5PreTrainedModel):
             raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
 
         if inputs_embeds is None:
-            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
+            if self.embed_tokens is None:
+                raise ValueError("You have to initialize the model with valid token embeddings")
             inputs_embeds = self.embed_tokens(input_ids)
 
         batch_size, seq_length = input_shape
@@ -990,7 +993,8 @@ class T5Stack(T5PreTrainedModel):
         mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
 
         if use_cache is True:
-            assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
+            if not self.is_decoder:
+                raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
 
         if attention_mask is None:
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
@@ -1817,8 +1821,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
                     layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
                 )
 
-            assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
-            assert len(reordered_layer_past_states) == len(layer_past_states)
+            if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
+                raise ValueError(
+                    f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
+                )
+            if len(reordered_layer_past_states) != len(layer_past_states):
+                raise ValueError(
+                    f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
+                )
 
             reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
         return reordered_decoder_past
......
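
The `use_cache` guard in the `T5Stack` (and `MT5Stack`) hunks sits next to the logic that sizes the attention mask from the cache. Below is a small sketch with hypothetical shapes of how that mask length is derived when cached key/value states are present; the same `past_key_values[0][0].shape[2] + seq_length` expression appears in the stacks above.

```python
import torch

# Hypothetical shapes showing how the stacks above size the attention mask
# when a key/value cache is present: the mask must cover the cached positions
# plus the tokens fed in this step.
batch_size, seq_length, heads, head_dim = 2, 1, 8, 64
cached_len = 7  # tokens already processed in earlier decoding steps
past_key_values = [
    (
        torch.zeros(batch_size, heads, cached_len, head_dim),  # layer 0 keys
        torch.zeros(batch_size, heads, cached_len, head_dim),  # layer 0 values
    )
]

mask_seq_length = (
    past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
)
attention_mask = torch.ones(batch_size, mask_seq_length)
print(attention_mask.shape)  # torch.Size([2, 8]): 7 cached positions + 1 new token
```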