Commit c36f0de4 authored by zhougaofeng


Update Dockerfile, eval_gsm8k.py, eval_math.py, LICENSE, read.md, README.MD, requirements.txt, run_mistral.sh, run.sh, util.py, train_math.py, imgs/dcu.png, imgs/metamath.svg, data/README.md, data/test/MATH_test.jsonl, data/test/GSM8K_test.jsonl, data/test/GSM8K_Backward.jsonl, data/train/README.md, code_for_generating_data/ReadMe.md, code_for_generating_data/code/main_backward_reasoning.py, code_for_generating_data/code/main_forward_reasoning.py, code_for_generating_data/code/main_create_backward_questions.py, code_for_generating_data/code/main_rephrase_question.py, code_for_generating_data/code/path_init.py, code_for_generating_data/code/main_self_verification.py, code_for_generating_data/code/run_create_backward_questions.sh, code_for_generating_data/code/run_forward.sh, code_for_generating_data/code/run_backward.sh, code_for_generating_data/code/run_sv.sh, code_for_generating_data/code/run_rephrase.sh, code_for_generating_data/code/utils/__init__.py, code_for_generating_data/code/utils/answer_clean_utils.py, code_for_generating_data/code/utils/config_utils.py, code_for_generating_data/code/utils/log_utils.py, code_for_generating_data/code/utils/math_utils.py, code_for_generating_data/code/utils/openai_api_utils.py, code_for_generating_data/code/utils/parallel_utils.py, code_for_generating_data/code/utils/time_utils.py, code_for_generating_data/code/utils/path_utils.py, code_for_generating_data/configs/ansaug_cot_math.txt, code_for_generating_data/configs/ansaug_cot_gsm8k.txt, code_for_generating_data/configs/fobar_cot_gsm8k.txt, code_for_generating_data/configs/fobar_cot_math.txt, code_for_generating_data/configs/rephrase_cot_gsm8k.txt, code_for_generating_data/configs/rephrase_cot_math.txt, code_for_generating_data/configs/sv_cot_gsm8k.txt, code_for_generating_data/configs/sv_cot_math.txt, code_for_generating_data/configs/sv_rewrite_question_prompt_gsm8k.txt, code_for_generating_data/configs/sv_rewrite_question_prompt_math.txt, code_for_generating_data/data/gsm8k_train.json, code_for_generating_data/data/MATH_train.json files
# MetaMathQA
The full **MetaMathQA** dataset is now released on Hugging Face: [MetaMathQA](https://huggingface.co/datasets/meta-math/MetaMathQA/tree/main)
import argparse
import json
import re
import jsonlines
from fraction import Fraction
from vllm import LLM, SamplingParams
import sys
MAX_INT = sys.maxsize
def is_number(s):
try:
float(s)
return True
except ValueError:
pass
try:
import unicodedata
unicodedata.numeric(s)
return True
except (TypeError, ValueError):
pass
return False
def extract_answer_number(completion):
text = completion.split('The answer is: ')
if len(text) > 1:
extract_ans = text[-1].strip()
match = re.search(r'[\-+]?\d*[\.,/]?\d+', extract_ans)
if match:
if '/' in match.group():
denominator = match.group().split('/')[1]
numerator = match.group().split('/')[0]
                if is_number(denominator) and is_number(numerator):
                    if denominator == '0':
                        # Degenerate zero denominator: fall back to the numerator alone.
                        return round(float(numerator.replace(',', '')))
else:
frac = Fraction(match.group().replace(',', ''))
num_numerator = frac.numerator
num_denominator = frac.denominator
return round(float(num_numerator / num_denominator))
else:
return None
else:
if float(match.group().replace(',', '')) == float('inf'):
return None
return round(float(match.group().replace(',', '')))
else:
return None
else:
return None
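# Illustrative examples (assumed behavior, not part of the original file):
#   extract_answer_number("... The answer is: 1,234")  -> 1234
#   extract_answer_number("... The answer is: 3/4")    -> 1   (fraction rounded)
#   extract_answer_number("no answer marker")          -> None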
def batch_data(data_list, batch_size=1):
n = len(data_list) // batch_size
batch_data = []
for i in range(n-1):
start = i * batch_size
end = (i+1)*batch_size
batch_data.append(data_list[start:end])
last_start = (n-1) * batch_size
last_end = MAX_INT
batch_data.append(data_list[last_start:last_end])
return batch_data
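# Example (illustrative): with 10 items and batch_size=3, n = 10 // 3 = 3, the
# loop yields slices [0:3] and [3:6], and the final slice [6:] absorbs the rest:
#   batch_data(list(range(10)), batch_size=3) -> [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]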
def gsm8k_test(model, data_path, start=0, end=MAX_INT, batch_size=1, tensor_parallel_size=1):
INVALID_ANS = "[invalid]"
gsm8k_ins = []
gsm8k_answers = []
problem_prompt = (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
)
    print('prompt =====', problem_prompt)
with open(data_path,"r+", encoding="utf8") as f:
for idx, item in enumerate(jsonlines.Reader(f)):
temp_instr = problem_prompt.format(instruction=item["query"])
gsm8k_ins.append(temp_instr)
temp_ans = item['response'].split('#### ')[1]
temp_ans = int(temp_ans.replace(',', ''))
gsm8k_answers.append(temp_ans)
gsm8k_ins = gsm8k_ins[start:end]
gsm8k_answers = gsm8k_answers[start:end]
    print('length ====', len(gsm8k_ins))
batch_gsm8k_ins = batch_data(gsm8k_ins, batch_size=batch_size)
stop_tokens = ["Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response"]
sampling_params = SamplingParams(temperature=0.0, top_p=1, max_tokens=512, stop=stop_tokens)
    print('sampling =====', sampling_params)
    llm = LLM(model=model, tensor_parallel_size=tensor_parallel_size)
result = []
res_completions = []
for idx, (prompt, prompt_answer) in enumerate(zip(batch_gsm8k_ins, gsm8k_answers)):
        if not isinstance(prompt, list):
            prompt = [prompt]
completions = llm.generate(prompt, sampling_params)
for output in completions:
prompt = output.prompt
generated_text = output.outputs[0].text
res_completions.append(generated_text)
invalid_outputs = []
for idx, (prompt, completion, prompt_answer) in enumerate(zip(gsm8k_ins, res_completions, gsm8k_answers)):
doc = {'question': prompt}
y_pred = extract_answer_number(completion)
        if y_pred is not None:
result.append(float(y_pred) == float(prompt_answer))
else:
result.append(False)
temp = {'question': prompt, 'output': completion, 'answer': prompt_answer}
invalid_outputs.append(temp)
acc = sum(result) / len(result)
    print('len invalid outputs ====', len(invalid_outputs), ', invalid_outputs===', invalid_outputs)
print('start===', start, ', end====', end)
print('gsm8k length====', len(result), ', gsm8k acc====', acc)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str) # model path
parser.add_argument("--data_file", type=str, default='') # data path
parser.add_argument("--start", type=int, default=0) #start index
parser.add_argument("--end", type=int, default=MAX_INT) # end index
parser.add_argument("--batch_size", type=int, default=400) # batch_size
parser.add_argument("--tensor_parallel_size", type=int, default=8) # tensor_parallel_size
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
gsm8k_test(model=args.model, data_path=args.data_file, start=args.start, end=args.end, batch_size=args.batch_size, tensor_parallel_size=args.tensor_parallel_size)
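# Example invocation (illustrative; paths are placeholders):
#   python eval_gsm8k.py --model /path/to/MetaMath-7B --data_file ./data/test/GSM8K_test.jsonl --tensor_parallel_size 8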
import argparse
import json
import pdb
import jsonlines
import util
from vllm import LLM, SamplingParams
import sys
MAX_INT = sys.maxsize
INVALID_ANS = "[invalid]"
invalid_outputs = []
def remove_boxed(s):
left = "\\boxed{"
try:
assert s[:len(left)] == left
assert s[-1] == "}"
return s[len(left):-1]
    except Exception:
return None
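# Example (illustrative): remove_boxed("\\boxed{42}") -> "42"; strings not of the
# form \boxed{...} return None.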
def process_results(doc, completion, answer):
split_ans = completion.split('The answer is: ')
if len(split_ans) > 1:
ans = split_ans[-1]
extract_ans_temp = ans.split('.\n')[0]
extract_ans_temp = extract_ans_temp.strip()
if len(extract_ans_temp)>0 and extract_ans_temp[-1] == '.':
extract_ans = extract_ans_temp[0:-1]
else:
extract_ans = extract_ans_temp
extract_ans = extract_ans.strip()
        return util.is_equiv(extract_ans, answer)
else:
temp = {'question': doc, 'output': completion, 'answer': answer}
invalid_outputs.append(temp)
return False
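# Example (illustrative): process_results(q, "... The answer is: \\frac{1}{2}.", "0.5")
# -> True via util.is_equiv; completions lacking "The answer is: " are logged to
# invalid_outputs and scored as incorrect.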
def batch_data(data_list, batch_size=1):
n = len(data_list) // batch_size
batch_data = []
for i in range(n-1):
start = i * batch_size
end = (i+1)*batch_size
batch_data.append(data_list[start:end])
last_start = (n-1) * batch_size
last_end = MAX_INT
batch_data.append(data_list[last_start:last_end])
return batch_data
def test_hendrycks_math(model, data_path, start=0, end=MAX_INT, batch_size=1, tensor_parallel_size=1):
hendrycks_math_ins = []
hendrycks_math_answers = []
problem_prompt = (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response: Let's think step by step."
)
    print('prompt =====', problem_prompt)
with open(data_path, "r+", encoding="utf8") as f:
for idx, item in enumerate(jsonlines.Reader(f)):
temp_instr = problem_prompt.format(instruction=item["instruction"])
hendrycks_math_ins.append(temp_instr)
solution = item['output']
temp_ans = remove_boxed(util.last_boxed_only_string(solution))
hendrycks_math_answers.append(temp_ans)
print('total length ===', len(hendrycks_math_ins))
hendrycks_math_ins = hendrycks_math_ins[start:end]
hendrycks_math_answers = hendrycks_math_answers[start:end]
    print('length ====', len(hendrycks_math_ins))
batch_hendrycks_math_ins = batch_data(hendrycks_math_ins, batch_size=batch_size)
stop_tokens = ["Question:", "Question", "USER:", "USER", "ASSISTANT:", "ASSISTANT", "Instruction:", "Instruction", "Response:", "Response"]
sampling_params = SamplingParams(temperature=0, top_p=1, max_tokens=2048, stop=stop_tokens)
    print('sampling =====', sampling_params)
    llm = LLM(model=model, tensor_parallel_size=tensor_parallel_size)
res_completions = []
for idx, (prompt, prompt_answer) in enumerate(zip(batch_hendrycks_math_ins, hendrycks_math_answers)):
        if not isinstance(prompt, list):
            prompt = [prompt]
completions = llm.generate(prompt, sampling_params)
for output in completions:
prompt_temp = output.prompt
generated_text = output.outputs[0].text
res_completions.append(generated_text)
results = []
for idx, (prompt, completion, prompt_answer) in enumerate(zip(hendrycks_math_ins, res_completions, hendrycks_math_answers)):
res = process_results(prompt, completion, prompt_answer)
results.append(res)
acc = sum(results) / len(results)
    print('len invalid outputs ====', len(invalid_outputs), ', invalid_outputs===', invalid_outputs)
print('start===', start, ', end====',end)
print('length====', len(results), ', acc====', acc)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default='') # model path
parser.add_argument("--data_file", type=str, default='') # data path
parser.add_argument("--start", type=int, default=0) #start index
parser.add_argument("--end", type=int, default=MAX_INT) # end index
parser.add_argument("--batch_size", type=int, default=400) # batch_size
parser.add_argument("--tensor_parallel_size", type=int, default=8) # tensor_parallel_size
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
test_hendrycks_math(model=args.model, data_path=args.data_file, start=args.start, end=args.end, batch_size=args.batch_size, tensor_parallel_size=args.tensor_parallel_size)
## MetaMath
## Paper
`MetaMath: Bootstrap Your Own Mathematical Questions for Large Language Models`
- https://arxiv.org/abs/2309.12284
## Algorithm
MetaMath is a language model fine-tuned specifically for mathematical reasoning. Concretely, it bootstraps mathematical questions by rewriting them from multiple perspectives, yielding a new dataset called MetaMathQA, on which LLaMA-2 models are then fine-tuned. Experimental results on two popular mathematical reasoning benchmarks (GSM8K and MATH) show that MetaMath outperforms a range of open-source LLMs by a significant margin.
<div align=center>
<img src="./imgs/metamath.svg"/>
</div>
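As a rough illustration of this bootstrapping, here is a minimal sketch of the backward-question idea only (not the repository's actual pipeline, which lives under `code_for_generating_data/`): mask a number in the original question and condition on the known answer.

```python
# Illustrative sketch; the real augmentation scripts are in code_for_generating_data/.
import re

def make_backward_question(question: str, answer: str) -> str:
    """Mask the first number in the question with X and condition on the answer."""
    masked = re.sub(r"\d+", "X", question, count=1)
    return (f"{masked} If we know the answer to the above question is {answer}, "
            "what is the value of unknown variable X?")

print(make_backward_question(
    "Tom has 5 apples and buys 3 more. How many apples does he have now?", "8"))
```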
## Environment Setup
### Docker (Option 1)
The address and steps for pulling the Docker image from the [光源 (SourceFind)](https://www.sourcefind.cn/#/service-details) registry are given below:
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-centos7.6-dtk24.04-py310
docker run -it --shm-size=1024G -v /parastor/home/MetaMath:/home/MetaMath -v /opt/hyhal:/opt/hyhal --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name xver <your IMAGE ID> bash # replace <your IMAGE ID> with the ID of the image pulled above; for this image it is c85ed27005f2
cd /home/MetaMath
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
```
### Dockerfile (Option 2)
How to use the Dockerfile:
```
docker build -t metamath-df:latest .
docker run -it --shm-size=1024G -v /parastor/home/Metamath:/home/Metamath -v /opt/hyhal:/opt/hyhal --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name metamath metamath-df bash
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
```
### Anaconda (Option 3)
Detailed steps for local setup and compilation are given below.
The special deep-learning libraries this project requires for DCU cards can be downloaded from the [光合 (HPC Cube)](https://developer.hpccube.com/tool/) developer community.
```
DTK driver: dtk24.04
python: 3.10
torch: 2.1.0
torchvision: 0.16.0
bitsandbytes: 0.42.0
```
`Tips: the DTK driver, Python, Torch, and other DCU-related tool versions above must correspond exactly, one to one.`
Install the remaining non-deep-learning dependencies from requirements.txt:
```
pip install -r requirements.txt
```
## Dataset
Training uses the MetaMathQA dataset (file MetaMathQA-395K.json), which can be downloaded from [Huggingface](https://hf-mirror.com/datasets/meta-math/MetaMathQA/tree/main).
This run uses the smaller [MetaMathQA-40K](https://huggingface.co/datasets/meta-math/MetaMathQA-40K) subset of MetaMathQA.
## Model Download
MetaMath-7B download: https://hf-mirror.com/meta-math/MetaMath-7B-V1.0
## Training
Because of its size, the model needs at least three cards to run; fine-tuning on four cards is recommended.
### Single node, multiple cards
```
bash run.sh
```
## Results
Accelerators used: 4 × DCU-K100AI-64G
<div align=center>
<img src="./imgs/dcu.png"/>
</div>
### Accuracy
Test data: [MATH_test.jsonl]; accelerators: K100-64G, trained on 2 cards.
Test results:
| device | acc |
| :------: | :------: |
| DCU-K100AI | 0.1506 |
| GPU-A800 | 0.1396 |
### Algorithm Category
Mathematical reasoning
### Key Application Sectors
`Mathematics, education, finance`
## Source Repository & Issue Reporting
- https://github.com/meta-math/MetaMath
## References
- https://github.com/meta-math/MetaMath
- https://huggingface.co/meta-math/MetaMath-7B-V1.0
export MODEL_PATH='your model path'
export SAVE_PATH='path/to/save'
export MASTER_ADDR="localhost"
export MASTER_PORT="1231"
export GLOO_SOCKET_IFNAME="lo"
export NCCL_SOCKET_IFNAME="lo"
export WANDB_DISABLED=true
wandb offline
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node=4 --use_env train_math.py \
--model_name_or_path $MODEL_PATH \
--data_path "your data path" \
--data_length 10000000 \
--bf16 True \
--output_dir $SAVE_PATH \
--num_train_epochs 3 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 4 \
--gradient_accumulation_steps 4 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 1000 \
--save_total_limit 2 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--fsdp "full_shard auto_wrap" \
--fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
--tf32 True
# Accuracy evaluation (eight cards recommended); to change the number of cards, edit the corresponding .py files.
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
python eval_gsm8k.py --model $SAVE_PATH --data_file ./data/test/GSM8K_test.jsonl
python eval_math.py --model $SAVE_PATH --data_file ./data/test/MATH_test.jsonl
export MODEL_PATH='mistralai/Mistral-7B-v0.1'
export SAVE_PATH='path/to/save'
export MASTER_ADDR="localhost"
export MASTER_PORT="1231"
export GLOO_SOCKET_IFNAME="lo"
export NCCL_SOCKET_IFNAME="lo"
export WANDB_DISABLED=true
export HF_TOKEN="token of your huggingface"
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python3 -m torch.distributed.launch --master_addr ${MASTER_ADDR} --master_port ${MASTER_PORT} --nproc_per_node=8 --use_env train_math.py \
--model_name_or_path $MODEL_PATH \
--data_path MetaMathQA-395K.json \
--data_length 10000000 \
--bf16 True \
--output_dir $SAVE_PATH \
--num_train_epochs 3 \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 8 \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 100000 \
--save_total_limit 0 \
--learning_rate 5e-6 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--fsdp "full_shard auto_wrap" \
--fsdp_transformer_layer_cls_to_wrap 'MistralDecoderLayer' \
--tf32 True
python eval_gsm8k.py --model $SAVE_PATH --data_file ./data/test/GSM8K_test.jsonl
python eval_math.py --model $SAVE_PATH --data_file ./data/test/MATH_test.jsonl
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified by Zheng Yuan and Hongyi Yuan
import os
import copy
import logging
from dataclasses import dataclass, field
from typing import Optional, Dict, Sequence
import io
import torch
import transformers
from torch.utils.data import Dataset
from transformers import Trainer
import argparse
import json
import random
random.seed(42)
def _make_r_io_base(f, mode: str):
if not isinstance(f, io.IOBase):
f = open(f, mode=mode)
return f
def jload(f, mode="r"):
"""Load a .json file into a dictionary."""
f = _make_r_io_base(f, mode)
jdict = json.load(f)
f.close()
return jdict
IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "<s>"
DEFAULT_UNK_TOKEN = "<unk>"
PROMPT_DICT = {
"prompt_input": (
"Below is an instruction that describes a task, paired with an input that provides further context. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
),
"prompt_no_input": (
"Below is an instruction that describes a task. "
"Write a response that appropriately completes the request.\n\n"
"### Instruction:\n{instruction}\n\n### Response:"
),
}
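# Example (illustrative): PROMPT_DICT["prompt_no_input"].format(instruction="What is 2+2?")
# produces the Alpaca-style prompt:
#   Below is an instruction that describes a task. Write a response that
#   appropriately completes the request.
#
#   ### Instruction:
#   What is 2+2?
#
#   ### Response: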
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
@dataclass
class DataArguments:
data_path: str = field(default=None, metadata={"help": "Path to the training data."})
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
model_max_length: int = field(
default=512,
metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
)
overwrite_output_dir: bool = field(default=True)
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
"""Collects the state dict and dump to disk."""
state_dict = trainer.model.state_dict()
if trainer.args.should_save:
cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
del state_dict
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
def smart_tokenizer_and_embedding_resize(
special_tokens_dict: Dict,
tokenizer: transformers.PreTrainedTokenizer,
model: transformers.PreTrainedModel,
):
"""Resize tokenizer and embedding.
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
"""
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:
input_embeddings = model.get_input_embeddings().weight.data
output_embeddings = model.get_output_embeddings().weight.data
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
input_embeddings[-num_new_tokens:] = input_embeddings_avg
output_embeddings[-num_new_tokens:] = output_embeddings_avg
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
"""Tokenize a list of strings."""
tokenized_list = [
tokenizer(
text,
return_tensors="pt",
padding="longest",
max_length=tokenizer.model_max_length,
truncation=True,
)
for text in strings
]
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
input_ids_lens = labels_lens = [
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
]
return dict(
input_ids=input_ids,
labels=labels,
input_ids_lens=input_ids_lens,
labels_lens=labels_lens,
)
def preprocess(
sources: Sequence[str],
targets: Sequence[str],
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
"""Preprocess the data by tokenizing."""
examples = [s + t for s, t in zip(sources, targets)]
examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
input_ids = examples_tokenized["input_ids"]
labels = copy.deepcopy(input_ids)
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
label[:source_len] = IGNORE_INDEX
return dict(input_ids=input_ids, labels=labels)
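# Note (illustrative): for a source prompt ending in "### Response:" and a target
# like "4</s>", the pair is tokenized jointly and the label positions covering the
# source are set to IGNORE_INDEX (-100), so loss is computed only on the response.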
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, data_args, tokenizer: transformers.PreTrainedTokenizer):
super(SupervisedDataset, self).__init__()
logging.warning("Loading data...")
data_path = data_args.data_path
        try:
            # `data_path_map` is an optional alias table that is not defined in
            # this file, so the lookup normally falls through to the literal path.
            data_path = data_path_map[data_path]
        except (NameError, KeyError):
            pass
try:
list_data_dict = jload(data_path)
        except BaseException:  # not valid JSON; fall back to JSON-lines (one object per line)
with open(data_path, 'r') as f:
lines = f.readlines()
list_data_dict = [json.loads(line.strip()) for line in lines]
        # Shuffle the full list, then truncate to the requested length.
        list_data_dict = random.sample(list_data_dict, len(list_data_dict))
        list_data_dict = list_data_dict[:data_args.data_length]
# logging.warning("Formatting inputs...")
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
# print(list_data_dict[0])
        if 'instruction' not in list_data_dict[0]:
def get_input(query):
if query.find('\n') == -1:
return ''
return '\n'.join(query.split('\n')[1:])
list_data_dict = [{'instruction':data['query'].split('\n')[0], 'input':get_input(data['query']), 'output':data['response']} for data in list_data_dict]
# import ipdb; ipdb.set_trace()
sources = [
prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
for example in list_data_dict
]
targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
self.sources = sources
self.targets = targets
def __len__(self):
return len(self.sources)
    def naive__getitem__(self, i) -> Dict[str, torch.Tensor]:
        # Unused reference implementation; this dataset stores raw strings and
        # defers tokenization to the collator, so these tensor fields do not exist.
        return dict(input_ids=self.input_ids[i], labels=self.labels[i])
def __getitem__(self, i):
return dict(input_ids=self.sources[i], labels=self.targets[i])
@dataclass
class DataCollatorForSupervisedDataset(object):
"""Collate examples for supervised fine-tuning."""
tokenizer: transformers.PreTrainedTokenizer
def naive__call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
)
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
return dict(
input_ids=input_ids,
labels=labels,
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
)
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
sources = []
targets = []
for instance in instances:
source = instance['input_ids']
target = instance['labels']
sources.append(source)
targets.append(target)
data_dict = preprocess(sources, targets, self.tokenizer)
input_ids, labels = data_dict['input_ids'], data_dict['labels']
# input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
input_ids = torch.nn.utils.rnn.pad_sequence(
input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
)
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
return dict(
input_ids=input_ids,
labels=labels,
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
)
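# Note (illustrative): unlike naive__call__, which expects pre-tokenized tensors,
# __call__ receives raw source/target strings, tokenizes them at batch time via
# preprocess(), and pads input_ids with the pad token and labels with IGNORE_INDEX.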
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
train_dataset = SupervisedDataset(tokenizer=tokenizer, data_args=data_args)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
def train():
parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args, remaining_args = parser.parse_args_into_dataclasses(return_remaining_strings=True)
    # --data_length is not a dataclass field; it arrives in the remaining CLI
    # strings as ['--data_length', '<N>'] (see run.sh), so read the value at index 1.
    data_args.data_length = int(remaining_args[1])
model = transformers.AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
use_fast=False,
)
if tokenizer.pad_token is None:
smart_tokenizer_and_embedding_resize(
special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
tokenizer=tokenizer,
model=model,
)
if "llama" in model_args.model_name_or_path:
tokenizer.add_special_tokens(
{
"eos_token": DEFAULT_EOS_TOKEN,
"bos_token": DEFAULT_BOS_TOKEN,
"unk_token": DEFAULT_UNK_TOKEN,
}
)
data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
trainer.train()
trainer.save_state()
# if os.environ.get('LOCAL_RANK') == '0':
safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
if __name__ == "__main__":
train()
import pprint
def last_boxed_only(sample):
q, a = sample
a = last_boxed_only_string(a)
    if a is None:
return None
return (q, a)
def last_boxed_only_string(string):
idx = string.rfind("\\boxed")
if idx < 0:
idx = string.rfind("\\fbox")
if idx < 0:
return None
i = idx
right_brace_idx = None
num_left_braces_open = 0
while i < len(string):
if string[i] == "{":
num_left_braces_open += 1
if string[i] == "}":
num_left_braces_open -= 1
if num_left_braces_open == 0:
right_brace_idx = i
break
i += 1
    if right_brace_idx is None:
retval = None
else:
retval = string[idx:right_brace_idx + 1]
return retval
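# Example (illustrative):
#   last_boxed_only_string("Thus the answer is \\boxed{\\frac{1}{2}}.")
#   -> "\\boxed{\\frac{1}{2}}"   (brace matching keeps nested groups intact)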
def only_until_first_boxed_from_tokens(string, tokens):
idx = string.find("\\boxed")
if idx < 0:
idx = string.find("\\fbox")
if idx < 0:
return None
cum_length = 0
for i, t in enumerate(tokens):
cum_length += len(t)
if cum_length >= idx:
break
return tokens[:i]
def clean_numbers(sample):
if not sample:
return None
new_sample = list()
for s in sample:
new_sample.append(_clean_numbers(s))
return tuple(new_sample)
def _clean_numbers(string):
    """
    Add thousands separators to long digit runs in the given string.
    >>> _clean_numbers("Hello 123")
    'Hello 123'
    >>> _clean_numbers("Hello 1234")
    'Hello 1,234'
    >>> _clean_numbers("Hello 1234324asdasd")
    'Hello 1,234,324asdasd'
    """
num_prev_digits = 0
new_string = ""
for i, c in enumerate(string):
        # isdigit() doesn't work here because of weird unicode chars.
if c in {'1', '2', '3', '4', '5', '6', '7', '8', '9', '0'}:
num_prev_digits += 1
else:
if num_prev_digits > 3:
# Some fixing
string_number = new_string[-num_prev_digits:]
new_string = new_string[:-num_prev_digits] + "{0:,}".format(int(string_number))
num_prev_digits = 0
new_string += c
if num_prev_digits > 3:
# Some fixing
string_number = new_string[-num_prev_digits:]
new_string = new_string[:-num_prev_digits] + "{0:,}".format(int(string_number))
return new_string
def fix_fracs(string):
substrs = string.split("\\frac")
new_str = substrs[0]
if len(substrs) > 1:
substrs = substrs[1:]
for substr in substrs:
new_str += "\\frac"
if substr[0] == "{":
new_str += substr
else:
try:
assert len(substr) >= 2
except AssertionError:
return string
a = substr[0]
b = substr[1]
if b != "{":
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}{" + b + "}" + post_substr
else:
new_str += "{" + a + "}{" + b + "}"
else:
if len(substr) > 2:
post_substr = substr[2:]
new_str += "{" + a + "}" + b + post_substr
else:
new_str += "{" + a + "}" + b
string = new_str
return string
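# Examples (illustrative): fix_fracs("\\frac12") -> "\\frac{1}{2}";
# fix_fracs("\\frac1{72}") -> "\\frac{1}{72}".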
def fix_a_slash_b(string):
if len(string.split("/")) != 2:
return string
a = string.split("/")[0]
b = string.split("/")[1]
try:
a = int(a)
b = int(b)
assert string == "{}/{}".format(a, b)
new_string = "\\frac{" + str(a) + "}{" + str(b) + "}"
return new_string
    except (ValueError, AssertionError):
return string
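# Example (illustrative): fix_a_slash_b("3/4") -> "\\frac{3}{4}"; inputs that are
# not a plain integer ratio (e.g. "x/y") are returned unchanged.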
def remove_right_units(string):
# "\\text{ " only ever occurs (at least in the val set) when describing units
if "\\text{ " in string:
splits = string.split("\\text{ ")
assert len(splits) == 2
return splits[0]
else:
return string
def fix_sqrt(string):
if "\\sqrt" not in string:
return string
splits = string.split("\\sqrt")
new_string = splits[0]
for split in splits[1:]:
if split[0] != "{":
a = split[0]
new_substr = "\\sqrt{" + a + "}" + split[1:]
else:
new_substr = "\\sqrt" + split
new_string += new_substr
return new_string
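# Example (illustrative): fix_sqrt("\\sqrt3") -> "\\sqrt{3}"; already-braced
# arguments like "\\sqrt{3}" pass through unchanged.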
def strip_string(string):
# linebreaks
string = string.replace("\n", "")
# remove inverse spaces
string = string.replace("\\!", "")
# replace \\ with \
string = string.replace("\\\\", "\\")
# replace tfrac and dfrac with frac
string = string.replace("tfrac", "frac")
string = string.replace("dfrac", "frac")
# remove \left and \right
string = string.replace("\\left", "")
string = string.replace("\\right", "")
# Remove circ (degrees)
string = string.replace("^{\\circ}", "")
string = string.replace("^\\circ", "")
# remove dollar signs
string = string.replace("\\$", "")
# remove units (on the right)
string = remove_right_units(string)
# remove percentage
string = string.replace("\\%", "")
string = string.replace("\%", "") # noqa: W605
# " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
string = string.replace(" .", " 0.")
string = string.replace("{.", "{0.")
# if empty, return empty string
if len(string) == 0:
return string
if string[0] == ".":
string = "0" + string
# to consider: get rid of e.g. "k = " or "q = " at beginning
if len(string.split("=")) == 2:
if len(string.split("=")[0]) <= 2:
string = string.split("=")[1]
# fix sqrt3 --> sqrt{3}
string = fix_sqrt(string)
# remove spaces
string = string.replace(" ", "")
# \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b}
string = fix_fracs(string)
# manually change 0.5 --> \frac{1}{2}
if string == "0.5":
string = "\\frac{1}{2}"
# NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
string = fix_a_slash_b(string)
return string
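# Examples (illustrative): strip_string("0.5") -> "\\frac{1}{2}";
# strip_string("\\left( \\frac12 \\right)") -> "(\\frac{1}{2})".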
def is_equiv(str1, str2, verbose=False):
if str1 is None and str2 is None:
print("WARNING: Both None")
return True
if str1 is None or str2 is None:
return False
try:
ss1 = strip_string(str1)
ss2 = strip_string(str2)
#pdb.set_trace()
if verbose:
print(ss1, ss2)
return ss1 == ss2
except Exception:
return str1 == str2
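# Example (illustrative): is_equiv("\\frac{1}{2}", "0.5") -> True, since both
# sides normalize to "\\frac{1}{2}" under strip_string.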
class NotEqual:
def __eq__(self, other):
return False