Unverified Commit 790cdc2e authored by nbertagnolli, committed by GitHub

Raise exceptions instead of using asserts in modeling_openai #12789 (#14386)

* Raise exceptions instead of using asserts for control flow in modeling_openai #12789

* reformatted file
parent 2e60276b
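The pattern applied throughout this commit is worth seeing in isolation: `assert` statements are compiled away when Python runs with the `-O` flag, and a bare `AssertionError` carries no message unless one is attached by hand. A minimal sketch of the before/after (illustrative names, not code from `modeling_openai.py`):

```python
def check_with_assert(n_state: int, n_head: int) -> None:
    # Disappears entirely under `python -O`; otherwise raises a bare
    # AssertionError with no diagnostic message.
    assert n_state % n_head == 0


def check_with_exception(n_state: int, n_head: int) -> None:
    # Always runs, and the ValueError carries the offending values.
    if n_state % n_head != 0:
        raise ValueError(f"n_state: {n_state} must be divisible by n_head {n_head}")


check_with_exception(768, 12)    # passes silently
# check_with_exception(768, 10)  # would raise: 768 is not divisible by 10
```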
@@ -83,13 +83,16 @@ def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
     # del init_params[1]
     init_params = [arr.squeeze() for arr in init_params]
 
-    try:
-        assert model.tokens_embed.weight.shape == init_params[1].shape
-        assert model.positions_embed.weight.shape == init_params[0].shape
-    except AssertionError as e:
-        e.args += (model.tokens_embed.weight.shape, init_params[1].shape)
-        e.args += (model.positions_embed.weight.shape, init_params[0].shape)
-        raise
+    # Check that the token and position embeddings weight dimensions map those of the init parameters.
+    if model.tokens_embed.weight.shape != init_params[1].shape:
+        raise ValueError(
+            f"tokens_embed.weight.shape: {model.tokens_embed.weight.shape} does not match init_param[1].shape: {init_params[1].shape}"
+        )
+
+    if model.positions_embed.weight.shape != init_params[0].shape:
+        raise ValueError(
+            f"positions_embed.weight.shape: {model.positions_embed.weight.shape} does not match init_param[0].shape: {init_params[0].shape}"
+        )
 
     model.tokens_embed.weight.data = torch.from_numpy(init_params[1])
     model.positions_embed.weight.data = torch.from_numpy(init_params[0])
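What the first hunk buys in practice: the shapes of both tensors now land directly in the exception message instead of being appended to `e.args` after the fact. A hedged repro with dummy tensors in place of a real TF checkpoint (the embedding dimensions are the GPT-1 defaults; the mismatch is manufactured):

```python
import numpy as np
import torch

tokens_embed = torch.nn.Embedding(40478, 768)           # GPT-1 vocab size / hidden size
init_param = np.zeros((40478, 512), dtype=np.float32)   # deliberately wrong width

if tokens_embed.weight.shape != init_param.shape:
    # Mirrors the new ValueError: both shapes land in the message, so the
    # traceback alone is enough to diagnose a bad checkpoint.
    raise ValueError(
        f"tokens_embed.weight.shape: {tokens_embed.weight.shape} does not match init_param.shape: {init_param.shape}"
    )
```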
@@ -100,7 +103,8 @@ def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
     for name, array in zip(names, init_params):  # names[1:n_transfer], init_params[1:n_transfer]):
         name = name[6:]  # skip "model/"
-        assert name[-2:] == ":0"
+        if name[-2:] != ":0":
+            raise ValueError(f"Layer {name} does not end with :0")
         name = name[:-2]
         name = name.split("/")
         pointer = model
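The `:0` check in the second hunk reflects TensorFlow's naming convention, where a variable's tensor name carries an output-index suffix. An illustrative walkthrough of the parsing that surrounds the new check (the layer name here is made up):

```python
# TensorFlow tensor names carry a ":0" output-index suffix. The loader strips
# the "model/" prefix, requires the ":0" suffix, then splits the remaining
# path into attribute names used to walk the PyTorch module tree.
name = "model/h0/attn/c_attn/w:0"
name = name[6:]              # -> "h0/attn/c_attn/w:0" (skip "model/")
if name[-2:] != ":0":
    raise ValueError(f"Layer {name} does not end with :0")
name = name[:-2]             # -> "h0/attn/c_attn/w"
scopes = name.split("/")     # -> ["h0", "attn", "c_attn", "w"]
```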
@@ -120,20 +124,11 @@ def load_tf_weights_in_openai_gpt(model, config, openai_checkpoint_folder_path):
             if len(scope_names) >= 2:
                 num = int(scope_names[1])
                 pointer = pointer[num]
-        try:
-            assert (
-                pointer.shape == array.shape
-            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
-        except AssertionError as e:
-            e.args += (pointer.shape, array.shape)
-            raise
-        try:
-            assert (
-                pointer.shape == array.shape
-            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
-        except AssertionError as e:
-            e.args += (pointer.shape, array.shape)
-            raise
+
+        # Ensure that the pointer and array have compatible shapes.
+        if pointer.shape != array.shape:
+            raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
+
         logger.info(f"Initialize PyTorch weight {name}")
         pointer.data = torch.from_numpy(array)
     return model
@@ -147,7 +142,8 @@ class Attention(nn.Module):
         super().__init__()
         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
         # [switch nx => n_state from Block to Attention to keep identical to TF implementation]
-        assert n_state % config.n_head == 0
+        if n_state % config.n_head != 0:
+            raise ValueError(f"Attention n_state shape: {n_state} must be divisible by config.n_head {config.n_head}")
         self.register_buffer(
             "bias", torch.tril(torch.ones(n_positions, n_positions)).view(1, 1, n_positions, n_positions)
         )
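The divisibility check added to `Attention.__init__` exists because multi-head attention splits the hidden state into equal per-head slices. A small sketch using the diff's own `n_state=768` comment (12 heads is the GPT-1 configuration, assumed here):

```python
# Each attention head gets n_state // n_head channels; an uneven split
# would silently drop dimensions, so the constructor rejects it up front.
n_state, n_head = 768, 12
if n_state % n_head != 0:
    raise ValueError(f"Attention n_state shape: {n_state} must be divisible by config.n_head {n_head}")
head_dim = n_state // n_head
print(head_dim)  # 64 channels per attention head
```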
@@ -804,9 +800,10 @@ class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel):
         else:
             batch_size, sequence_length = inputs_embeds.shape[:2]
 
-        assert (
-            self.config.pad_token_id is not None or batch_size == 1
-        ), "Cannot handle batch sizes > 1 if no padding token is defined."
+        # Ensure the batch size is > 1 if there is no padding.
+        if self.config.pad_token_id is None and batch_size != 1:
+            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+
         if self.config.pad_token_id is None:
             sequence_lengths = -1
         else:
...
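The last hunk guards `OpenAIGPTForSequenceClassification`, which pools a classification logit from the last non-padding token of each row; without a pad token the model cannot tell where shorter rows end, so only batch size 1 is accepted. A standalone sketch of the guard and the branch that follows it (variable names simplified; the `torch.ne(...)` line is assumed from the pooling logic surrounding this hunk):

```python
import torch

# pad_token_id=None mirrors GPT-1's tokenizer, which defines no padding
# token by default, so any batch larger than one is rejected up front.
pad_token_id = None
input_ids = torch.tensor([[5, 9, 2], [7, 1, 3]])  # batch_size = 2
batch_size = input_ids.shape[0]

if pad_token_id is None and batch_size != 1:
    # This is the path the example takes: a ValueError instead of a bare AssertionError.
    raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")

if pad_token_id is None:
    sequence_lengths = -1  # single row: pool the logit at the last position
else:
    # With a pad token defined, each row pools at its last non-pad position.
    sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1
```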