Unverified Commit 04cddaf4 authored by Jake Tae's avatar Jake Tae Committed by GitHub
Browse files

refactor: replace `assert` with `ValueError` (#14970)

parent 600496fa
...@@ -114,13 +114,8 @@ def load_tf_weights_in_bert_generation( ...@@ -114,13 +114,8 @@ def load_tf_weights_in_bert_generation(
else: else:
model_pointer = model_pointer.weight model_pointer = model_pointer.weight
try: if model_pointer.shape != array.shape:
assert ( raise ValueError(f"Pointer shape {model_pointer.shape} and array shape {array.shape} mismatched")
model_pointer.shape == array.shape
), f"Pointer shape {model_pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e:
e.args += (model_pointer.shape, array.shape)
raise
logger.info(f"Initialize PyTorch weight {key}") logger.info(f"Initialize PyTorch weight {key}")
model_pointer.data = torch.from_numpy(array.astype(np.float32)) model_pointer.data = torch.from_numpy(array.astype(np.float32))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment