Commit 92953afa authored by Baber

add cwe

parent 7ab8b057
include: niah_1.yaml
task: ruler_cwe
download_dataset: !function cwe_utils.get_cw_dataset
generation_kwargs:
  do_sample: false
  temperature: 0.0
  max_gen_toks: 30
  until: []
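The `!function` tag points the harness at the dataset builder added in cwe_utils.py below; judging from that builder, it expects the tokenizer (or `pretrained` model name) to be passed through the task's `metadata`. A minimal sketch of the call the hook effectively makes, with an illustrative model name and an assumed module path:

# Sketch only: invoke the CWE dataset builder directly, the way the YAML hook would.
from lm_eval.tasks.ruler.cwe_utils import get_cw_dataset  # assumed path for this file

splits = get_cw_dataset(metadata={"pretrained": "gpt2"})   # "gpt2" is illustrative
print(splits["test"].num_rows)          # 500 samples for each context length in SEQ_LENGTHS
print(splits["test"][0]["gen_prefix"])  # generation prefix split off by the builder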
@@ -11,16 +11,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License
+import itertools
 import random
+from functools import cache
+
+import datasets
 import wonderwords
 from tqdm import tqdm
+from transformers import AutoTokenizer
+
+from lm_eval.tasks.ruler.utils import SEQ_LENGTHS

 RNG = random.Random(42)

-TEMPLATE = ""
+TEMPLATE = "Below is a numbered list of words. In these words, some appear more often than others. Memorize the ones that appear most often.\n{context}\nQuestion: What are the 10 most common words in the above list?\n\nAnswer: The top 10 words that appear most often in the list are:"

 r = wonderwords.RandomWord()

 WORDS = sorted(
     list(
@@ -74,12 +80,12 @@ def generate_input_output(
 def sys_word_pair_random(
     num_samples: int,
     max_seq_length: int,
-    TOKENIZER=None,
+    tokenizer=None,
     incremental: int = 10,
     remove_newline_tab=False,
-    tokens_to_generate=120,
+    tokens_to_generate=30,
 ):
-    assert TOKENIZER is not None, "Tokenizer is not provided."
+    assert tokenizer is not None, "Tokenizer is not provided."
     write_jsons = []
     tokens_to_generate = tokens_to_generate
@@ -93,13 +99,13 @@ def sys_word_pair_random(
         )
         # Calculate the number of tokens in the example
         total_tokens = len(
-            TOKENIZER(
+            tokenizer(
                 input_example
                 + "\n"
                 + input_text
                 + " "
                 + " ".join([f"{i + 1}. {word}" for i, word in enumerate(answer)])
-            )
+            ).input_ids
         )
         print(
             f"Max length {max_seq_length} | Current length {total_tokens + tokens_to_generate} | Words: {num_words}"
@@ -123,7 +129,7 @@ def sys_word_pair_random(
                 input_example, input_text, answer = generate_input_output(
                     used_words, max_seq_length
                 )
-                length = len(TOKENIZER.text_to_tokens(input_text)) + tokens_to_generate
+                length = len(tokenizer(input_text).input_ids) + tokens_to_generate
                 assert length <= max_seq_length, f"{length} exceeds max_seq_length."
                 break
             except:  # noqa: E722
@@ -138,13 +144,43 @@ def sys_word_pair_random(
                 input_example.replace("\n", " ").replace("\t", " ").strip().split()
             )

+        gen_prefix_index = input_text.rfind(" Answer")
+        gen_prefix = input_text[gen_prefix_index:].strip()
+        input_text = input_text[:gen_prefix_index]
         formatted_output = {
             "index": index,
             "input": input_text,
+            "input_example": input_example,
             "outputs": answer,
             "length": length,
             "max_length": max_seq_length,
+            "gen_prefix": gen_prefix,
         }
         write_jsons.append(formatted_output)

     return write_jsons
+
+
+@cache
+def get_tokenizer(pretrained):
+    return AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
+
+
+def get_dataset(pretrained, seq=None, **kwargs):
+    tokenizer = get_tokenizer(pretrained)
+    write_jsons = sys_word_pair_random(
+        num_samples=500, max_seq_length=seq, tokenizer=tokenizer
+    )
+    return write_jsons
+
+
+def get_cw_dataset(**kwargs):
+    kwargs = kwargs.get("metadata", {})
+    pretrained = kwargs.get("tokenizer", kwargs.get("pretrained", {}))
+    df = (get_dataset(pretrained, seq=seq) for seq in SEQ_LENGTHS)
+    return {
+        "test": datasets.Dataset.from_list(
+            list(itertools.chain.from_iterable(df)), split=datasets.Split.TEST
+        )
+    }
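Taken together, `get_cw_dataset` pulls the model name out of the task metadata, builds 500 samples for every context length in `SEQ_LENGTHS` (caching one tokenizer per model), and flattens everything into a single test split, with the trailing "Answer" stub stored separately as `gen_prefix`. A hedged sketch of exercising one context length directly, assuming the functions above are in scope and using an illustrative tokenizer name:

# Build the CWE samples for a single context length and check the invariants
# enforced above (sample count and per-example length budget).
samples = get_dataset("gpt2", seq=4096)

assert len(samples) == 500                                    # num_samples per length
assert all(s["length"] <= s["max_length"] for s in samples)   # every prompt fits
print(samples[0]["gen_prefix"])                               # prefix split off at " Answer"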