"git@developer.sourcefind.cn:OpenDAS/rodnet.git" did not exist on "81df4598b6c022eba30ed8d86ea1bc86977d41e6"
Commit beb59080 authored by VictorSanh's avatar VictorSanh
Browse files

Fix size compatibility for model.forward

Error was coming from "modeling_pytorch.py", line 484, in forward: start_loss = loss_fct(start_logits, start_positions) --> ValueError: Expected target size (12, 1), got torch.Size([12])
parent 8cbe7d6a
...@@ -840,6 +840,9 @@ def main(): ...@@ -840,6 +840,9 @@ def main():
#label_ids = label_ids.to(device) #label_ids = label_ids.to(device)
start_positions = start_positions.to(device) start_positions = start_positions.to(device)
end_positions = start_positions.to(device) end_positions = start_positions.to(device)
start_positions = start_positions.view(-1, 1)
end_positions = end_positions.view(-1, 1)
loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions) loss, _ = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
loss.backward() loss.backward()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment