"...composable_kernel_rocm.git" did not exist on "7409674a8987f832a6c9069515f531b29a241c31"
Commit f4f7618a authored by Jon Tow's avatar Jon Tow
Browse files

Fixes

parent 93ebce43
...@@ -132,7 +132,7 @@ class DROP(Task): ...@@ -132,7 +132,7 @@ class DROP(Task):
for pred_index, pred_bag in enumerate(pred_bags): for pred_index, pred_bag in enumerate(pred_bags):
print(self._is_number_match(gold_bag, pred_bag)) print(self._is_number_match(gold_bag, pred_bag))
if self._is_number_match(gold_bag, pred_bag): if self._is_number_match(gold_bag, pred_bag):
scores[gold_index, pred_index] = self._bag_f1(pred_bag, gold_bag) scores[gold_index, pred_index] = self._bag_f1(gold_bag, pred_bag)
print(scores) print(scores)
row_ind, col_ind = linear_sum_assignment(-scores) row_ind, col_ind = linear_sum_assignment(-scores)
max_scores = np.zeros([max(len(gold_bags), len(pred_bags))]) max_scores = np.zeros([max(len(gold_bags), len(pred_bags))])
...@@ -158,6 +158,7 @@ class DROP(Task): ...@@ -158,6 +158,7 @@ class DROP(Task):
def tokenize(text): def tokenize(text):
return re.split(" |-", text) return re.split(" |-", text)
tokens = [squad_metrics.normalize_answer(token) for token in tokenize(answer)] tokens = [squad_metrics.normalize_answer(token) for token in tokenize(answer)]
tokens = [token for token in tokens if token.strip()]
normalized = " ".join(tokens).strip() normalized = " ".join(tokens).strip()
return normalized return normalized
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment