Commit 92953afa authored by Baber

add cwe

parent 7ab8b057
include: niah_1.yaml
task: ruler_cwe
download_dataset: !function cwe_utils.get_cw_dataset
generation_kwargs:
  do_sample: false
  temperature: 0.0
  max_gen_toks: 30
  until: []
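A minimal sketch (not part of this commit) of how the `!function cwe_utils.get_cw_dataset` hook above could be exercised, assuming the harness simply resolves the dotted name and forwards the task metadata dict as keyword arguments; the tokenizer name is a hypothetical placeholder:

import importlib

# Resolve "cwe_utils.get_cw_dataset" from the config and call it with metadata,
# mirroring the kwargs.get("metadata", {}) lookup inside get_cw_dataset below.
module_name, func_name = "cwe_utils.get_cw_dataset".rsplit(".", 1)
build_dataset = getattr(importlib.import_module(module_name), func_name)
splits = build_dataset(metadata={"tokenizer": "gpt2"})  # returns {"test": datasets.Dataset}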
@@ -11,16 +11,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import itertools
import random
from functools import cache
import datasets
import wonderwords
from tqdm import tqdm
from transformers import AutoTokenizer
from lm_eval.tasks.ruler.utils import SEQ_LENGTHS
RNG = random.Random(42)
TEMPLATE = ""
TEMPLATE = "Below is a numbered list of words. In these words, some appear more often than others. Memorize the ones that appear most often.\n{context}\nQuestion: What are the 10 most common words in the above list?\n\nAnswer: The top 10 words that appear most often in the list are:"
r = wonderwords.RandomWord()
WORDS = sorted(
list(
@@ -74,12 +80,12 @@ def generate_input_output(
def sys_word_pair_random(
    num_samples: int,
    max_seq_length: int,
-    TOKENIZER=None,
+    tokenizer=None,
    incremental: int = 10,
    remove_newline_tab=False,
-    tokens_to_generate=120,
+    tokens_to_generate=30,
):
-    assert TOKENIZER is not None, "Tokenizer is not provided."
+    assert tokenizer is not None, "Tokenizer is not provided."
    write_jsons = []
    tokens_to_generate = tokens_to_generate
@@ -93,13 +99,13 @@ def sys_word_pair_random(
        )
        # Calculate the number of tokens in the example
        total_tokens = len(
-            TOKENIZER(
+            tokenizer(
                input_example
                + "\n"
                + input_text
                + " "
                + " ".join([f"{i + 1}. {word}" for i, word in enumerate(answer)])
-            )
+            ).input_ids
        )
        print(
            f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Words: {num_words}"
@@ -123,7 +129,7 @@ def sys_word_pair_random(
                input_example, input_text, answer = generate_input_output(
                    used_words, max_seq_length
                )
-                length = len(TOKENIZER.text_to_tokens(input_text)) + tokens_to_generate
+                length = len(tokenizer(input_text).input_ids) + tokens_to_generate
                assert length <= max_seq_length, f"{length} exceeds max_seq_length."
                break
            except:  # noqa: E722
@@ -138,13 +144,43 @@ def sys_word_pair_random(
                input_example.replace("\n", " ").replace("\t", " ").strip().split()
            )
        gen_prefix_index = input_text.rfind(" Answer")
        gen_prefix = input_text[gen_prefix_index:].strip()
        input_text = input_text[:gen_prefix_index]
        formatted_output = {
            "index": index,
            "input": input_text,
            "input_example": input_example,
            "outputs": answer,
            "length": length,
            "max_length": max_seq_length,
            "gen_prefix": gen_prefix,
        }
        write_jsons.append(formatted_output)
    return write_jsons
@cache
def get_tokenizer(pretrained):
    return AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)


def get_dataset(pretrained, seq=None, **kwargs):
    tokenizer = get_tokenizer(pretrained)
    write_jsons = sys_word_pair_random(
        num_samples=500, max_seq_length=seq, tokenizer=tokenizer
    )
    return write_jsons


def get_cw_dataset(**kwargs):
    kwargs = kwargs.get("metadata", {})
    pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
    df = (get_dataset(pretrained, seq=seq) for seq in SEQ_LENGTHS)
    return {
        "test": datasets.Dataset.from_list(
            list(itertools.chain.from_iterable(df)), split=datasets.Split.TEST
        )
    }
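A short usage sketch (also not part of the commit) that calls get_cw_dataset directly, assuming cwe_utils.py is importable and using a stand-in tokenizer name; note that it builds 500 samples for every length in SEQ_LENGTHS, so it is slow:

from cwe_utils import get_cw_dataset

splits = get_cw_dataset(metadata={"pretrained": "gpt2"})  # "gpt2" is a stand-in tokenizer
example = splits["test"][0]
# Each record carries the prompt ("input"), the gold common words ("outputs"),
# the "Answer: ..." generation prefix ("gen_prefix"), and bookkeeping fields
# ("index", "input_example", "length", "max_length").
print(example["gen_prefix"], example["outputs"])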