"ts/webui/src/static/function.ts" did not exist on "995f625963a9e6bf76033e1cc8e7dcd4df3dbf65"
Commit af3cccc8 authored by Tian Yun

Merge branch 'master' of https://github.com/cjlovering/ps-eh

parents 716c87d6 94218002
 import abc
-from typing import Iterable
+from typing import Iterable, Optional
 import promptsource
 import numpy as np
@@ -348,17 +348,25 @@ class BaseLM(LM):
         for context, until in tqdm(reord.get_reordered()):
             if isinstance(until, str):
                 until = [until]
+                max_length = None
+            elif isinstance(until, list) and len(until) == 2:
+                until, max_length = [until[0]], until[1]
+            elif isinstance(until, list):
+                max_length = None
+            # TODO: Come back to for generation `eos`.
 
             primary_until = self.tok_encode(until[0])
 
             context_enc = torch.tensor(
                 [self.tok_encode(context)[self.max_gen_toks - self.max_length :]]
             ).to(self.device)
 
+            if max_length is not None:
+                max_length = min(max_length, context_enc.shape[1] + self.max_gen_toks)
+            else:
+                max_length = context_enc.shape[1] + self.max_gen_toks
+
             cont = self._model_generate(
                 context_enc,
-                context_enc.shape[1] + self.max_gen_toks,
+                max_length,
                 torch.tensor(primary_until),
             )
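To make the new dispatch easier to follow outside the diff, here is a minimal, self-contained sketch of how `until` is normalized into a stop sequence plus a generation cap. The helper name `resolve_until_and_max_length` and the numbers in the example are illustrative assumptions, not part of this commit.

    from typing import List, Optional, Tuple, Union


    def resolve_until_and_max_length(
        until: Union[str, list], context_len: int, max_gen_toks: int
    ) -> Tuple[List[str], int]:
        """Normalize `until` into (stop_sequences, max_length), mirroring the hunk above."""
        max_length: Optional[int] = None
        if isinstance(until, str):
            until = [until]
        elif isinstance(until, list) and len(until) == 2:
            # A two-element list is read as [stop_sequence, max_length].
            until, max_length = [until[0]], until[1]

        # Never generate past the context plus the model's per-call token budget.
        budget = context_len + max_gen_toks
        max_length = budget if max_length is None else min(max_length, budget)
        return until, max_length


    # Example: a request that caps generation at 64 tokens.
    stops, max_len = resolve_until_and_max_length(["\nQ:", 64], context_len=10, max_gen_toks=256)
    assert stops == ["\nQ:"] and max_len == 64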
@@ -652,7 +660,7 @@ class PromptSourceTask(Task):
         super().__init__(data_dir, cache_dir, download_mode)
         self.prompt = prompt
 
-    def stopping_criteria(self):
+    def stopping_criteria(self) -> Optional[str]:
         """Denote where the generation should end.
 
         For example, for coqa, this is '\nQ:' and for drop '.'.
@@ -661,6 +669,10 @@ class PromptSourceTask(Task):
         """
         return None
 
+    def max_generation_length(self) -> Optional[int]:
+        """Denote the max length of the generation if it is obvious from the task."""
+        return None
+
     def is_generation_task(self):
         return (
             "BLEU" in self.prompt.metadata.metrics
@@ -718,7 +730,9 @@ class PromptSourceTask(Task):
             _requests.append(ll_answer_choice)
         else:
             # TODO(Albert): What is the stop symbol? Is it model specific?
-            cont_request = rf.greedy_until(ctx, [self.stopping_criteria()])
+            cont_request = rf.greedy_until(
+                ctx, [self.stopping_criteria(), self.max_generation_length()]
+            )
             _requests.append(cont_request)
         return _requests
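For context, the sketch below shows how a task would plug into the two hooks that the request above packs into a single greedy_until call. `CoqaLikeTask` and the `_TaskStub` base are illustrative stand-ins so the snippet runs on its own; they are not classes defined in this repository.

    from typing import Optional


    class _TaskStub:
        """Stand-in for PromptSourceTask so this sketch is self-contained."""

        def stopping_criteria(self) -> Optional[str]:
            return None

        def max_generation_length(self) -> Optional[int]:
            return None


    class CoqaLikeTask(_TaskStub):
        """Hypothetical task: stop at the next question marker, cap the generation."""

        def stopping_criteria(self) -> Optional[str]:
            return "\nQ:"

        def max_generation_length(self) -> Optional[int]:
            return 64


    task = CoqaLikeTask()
    # construct_requests packs both hooks into one greedy_until request, which the
    # BaseLM loop above unpacks as `until, max_length = [until[0]], until[1]`.
    request_args = [task.stopping_criteria(), task.max_generation_length()]
    assert request_args == ["\nQ:", 64]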
...
+import typing
 import math
 from collections.abc import Iterable
 import numpy as np
 import sacrebleu
+from rouge_score import rouge_scorer
 import sklearn.metrics
 import random
@@ -184,6 +186,65 @@ def _sacreformat(refs, preds):
     return refs, preds
+def rouge(
+    refs: typing.List[str],
+    pred: str,
+    rouge_types: typing.List[str] = ["rouge1", "rouge2", "rougeL", "rougeLsum"],
+):
+    """ROUGE with multi-reference support.
+
+    Implementation based on GEM-metrics:
+    https://github.com/GEM-benchmark/GEM-metrics/blob/431a8174bd6b3637e8d6118bfad2983e39e99733/gem_metrics/rouge.py
+
+    :param refs:
+        A `list` of reference `str`s.
+    :param pred:
+        A single prediction `str`.
+    """
+    scorer = rouge_scorer.RougeScorer(rouge_types=rouge_types, use_stemmer=True)
+    # ROUGE multi-ref jackknifing
+    if len(refs) > 1:
+        cur_scores = [scorer.score(ref, pred) for ref in refs]
+
+        # get best score for all leave-one-out sets
+        best_scores = []
+        for leave in range(len(refs)):
+            cur_scores_leave_one = [
+                cur_scores[s] for s in range(len(refs)) if s != leave
+            ]
+            best_scores.append(
+                {
+                    rouge_type: max(
+                        [s[rouge_type] for s in cur_scores_leave_one],
+                        key=lambda s: s.fmeasure,
+                    )
+                    for rouge_type in rouge_types
+                }
+            )
+
+        # average the leave-one-out bests to produce the final score
+        score = {
+            rouge_type: rouge_scorer.scoring.Score(
+                np.mean([b[rouge_type].precision for b in best_scores]),
+                np.mean([b[rouge_type].recall for b in best_scores]),
+                np.mean([b[rouge_type].fmeasure for b in best_scores]),
+            )
+            for rouge_type in rouge_types
+        }
+    else:
+        score = scorer.score(refs[0], pred)
+
+    # convert the named tuples to plain nested dicts
+    score = {
+        rouge_type: {
+            "precision": score[rouge_type].precision,
+            "recall": score[rouge_type].recall,
+            "fmeasure": score[rouge_type].fmeasure,
+        }
+        for rouge_type in rouge_types
+    }
+    return score
 # stderr stuff
 class _bootstrap_internal:
...
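A small usage sketch of the new metric, assuming the `rouge_score` package is installed and the `rouge()` function above is in scope; the example strings and values are made up, not project data.

    refs = [
        "the cat sat on the mat",
        "a cat was sitting on the mat",
    ]
    pred = "the cat is sitting on the mat"

    # With more than one reference, each leave-one-out subset keeps its
    # best-scoring reference (by F-measure) and the results are averaged.
    scores = rouge(refs, pred, rouge_types=["rouge1", "rougeL"])
    print(scores["rouge1"]["fmeasure"])  # nested dict: precision / recall / fmeasure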