Unverified Commit 790cdc2e authored by nbertagnolli's avatar nbertagnolli Committed by GitHub
Browse files

Raise exceptions instead of using asserts in modeling_openai #12789 (#14386)

* Raise exceptions instead of using asserts for control flow in modeling_openai #12789

* reformatted file
parent 2e60276b
......@@ -83,13 +83,16 @@ def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
# del init_params[1]
init_params = [arr.squeeze() for arr in init_params]
try:
assert model.tokens_embed.weight.shape == init_params[1].shape
assert model.positions_embed.weight.shape == init_params[0].shape
except AssertionError as e:
e.args += (model.tokens_embed.weight.shape, init_params[1].shape)
e.args += (model.positions_embed.weight.shape, init_params[0].shape)
raise
# Check that the token and position embeddings weight dimensions map those of the init parameters.
if model.tokens_embed.weight.shape != init_params[1].shape:
raise ValueError(
f"tokens_embed.weight.shape: {model.tokens_embed.weight.shape} does not match init_param[1].shape: {init_params[1].shape}"
)
if model.positions_embed.weight.shape != init_params[0].shape:
raise ValueError(
f"positions_embed.weight.shape: {model.positions_embed.weight.shape} does not match init_param[0].shape: {init_params[0].shape}"
)
model.tokens_embed.weight.data = torch.from_numpy(init_params[1])
model.positions_embed.weight.data = torch.from_numpy(init_params[0])
......@@ -100,7 +103,8 @@ def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
for name, array in zip(names, init_params): # names[1:n_transfer], init_params[1:n_transfer]):
name = name[6:] # skip "model/"
assert name[-2:] == ":0"
if name[-2:] != ":0":
raise ValueError(f"Layer {name} does not end with :0")
name = name[:-2]
name = name.split("/")
pointer = model
......@@ -120,20 +124,11 @@ def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
if len(scope_names) >= 2:
num = int(scope_names[1])
pointer = pointer[num]
try:
assert (
pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
try:
assert (
pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
# Ensure that the pointer and array have compatible shapes.
if pointer.shape != array.shape:
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
logger.info(f"Initialize PyTorch weight {name}")
pointer.data = torch.from_numpy(array)
return model
......@@ -147,7 +142,8 @@ class Attention(nn.Module):
super().__init__()
n_state = nx # in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implementation]
assert n_state % config.n_head == 0
if n_state % config.n_head != 0:
raise ValueError(f"Attention n_state shape: {n_state} must be divisible by config.n_head {config.n_head}")
self.register_buffer(
"bias", torch.tril(torch.ones(n_positions, n_positions)).view(1, 1, n_positions, n_positions)
)
......@@ -804,9 +800,10 @@ class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel):
else:
batch_size, sequence_length = inputs_embeds.shape[:2]
assert (
self.config.pad_token_id is not None or batch_size == 1
), "Cannot handle batch sizes > 1 if no padding token is defined."
# Ensure the batch size is > 1 if there is no padding.
if self.config.pad_token_id is None and batch_size != 1:
raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
if self.config.pad_token_id is None:
sequence_lengths = -1
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment