"vscode:/vscode.git/clone" did not exist on "3c6d818db574efaad360359eea314b4ef7f73111"
Unverified Commit 04cddaf4 authored by Jake Tae's avatar Jake Tae Committed by GitHub
Browse files

refactor: replace `assert` with `ValueError` (#14970)

parent 600496fa
......@@ -114,13 +114,8 @@ def load_tf_weights_in_bert_generation(
else:
model_pointer = model_pointer.weight
try:
assert (
model_pointer.shape == array.shape
), f"Pointer shape {model_pointer.shape} and array shape {array.shape} mismatched"
except AssertionError as e:
e.args += (model_pointer.shape, array.shape)
raise
if model_pointer.shape != array.shape:
raise ValueError(f"Pointer shape {model_pointer.shape} and array shape {array.shape} mismatched")
logger.info(f"Initialize PyTorch weight {key}")
model_pointer.data = torch.from_numpy(array.astype(np.float32))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment