*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
tokenizer.json filter=lfs diff=lfs merge=lfs -text
_site/
.sass-cache/
.jekyll-cache/
.jekyll-metadata
# Ignore folders generated by Bundler
.bundle/
vendor/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# macOS
.DS_Store
# Dataset
data/FashionMNIST/
data/translation2019zh/
data/lcsts_tsv/
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# How to Contribute
We'd love to accept your patches and contributions to this project. There are
just a few small guidelines you need to follow.
## Contributor License Agreement
Contributions to this project must be accompanied by a Contributor License
Agreement. You (or your employer) retain the copyright to your contribution;
this simply gives us permission to use and redistribute your contributions as
part of the project. Head over to <https://cla.developers.google.com/> to see
your current agreements on file or to sign a new one.
You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.
## Code reviews
All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.
## Community Guidelines
This project follows
[Google's Open Source Community Guidelines](https://opensource.google/conduct/).
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
# UMT5
**Note: the model must be trained before running any downstream task; see multi_dcu_train.py for the training code.**
<div align="center">
<img align="center" src=docs/T5_task.png>
</div>
## Paper
- [UniMax: Fairer and more Effective Language Sampling for Large-Scale Multilingual Pretraining](https://arxiv.org/abs/2304.09151)
## Model Architecture
umT5 is the multilingual version of T5 and retains most of T5's versatility. It is pretrained on mC4, a multilingual Common Crawl corpus covering 107 languages. It uses an encoder-decoder architecture with 12 encoder layers and 12 decoder layers, for a total of about 220M parameters, roughly twice the size of bert-base.
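These figures can be checked directly from the checkpoint. A minimal sketch, assuming a transformers version with UMT5 support and access to the google/umt5-base checkpoint:
```
from transformers import AutoConfig, AutoModelForSeq2SeqLM

# Inspect the architecture hyperparameters from the config
config = AutoConfig.from_pretrained("google/umt5-base")
print(config.num_layers, config.num_decoder_layers)  # encoder and decoder depth

# Count parameters of the loaded model
model = AutoModelForSeq2SeqLM.from_pretrained("google/umt5-base")
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters")
```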
### mT5 Model Architecture
<div align="center">
<img align="center" src=docs/T5_structure.png>
</div>
## Algorithm
Overall, mT5 is a direct descendant of T5 and is essentially the same model, but in terms of architecture it adopts the T5.1.1 recipe, which we briefly introduce here.
Its main change comes from the paper [GLU Variants Improve Transformer](https://arxiv.org/abs/2002.05202), which borrows the **GLU** (Gated Linear Unit) from [Language Modeling with Gated Convolutional Networks](https://arxiv.org/abs/1612.08083) to strengthen the FFN block. Specifically, T5's original FFN is (T5 uses no bias terms):
<div align="center">
<img align="center" src=docs/equation1.png>
</div>
which is changed to:
<div align="center">
<img align="center" src=docs/euqation2.png>
</div>
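In code, this gated FFN looks roughly like the PyTorch sketch below. The dimensions are illustrative defaults, not the exact umt5-base sizes; transformers implements this block as T5DenseGatedActDense:
```
import torch
from torch import nn

class GatedFFN(nn.Module):
    """T5.1.1-style gated FFN: FFN(x) = (GELU(x W_i0) * (x W_i1)) W_o, no biases."""
    def __init__(self, d_model=768, d_ff=2048):
        super().__init__()
        self.wi_0 = nn.Linear(d_model, d_ff, bias=False)  # gated (GELU) branch
        self.wi_1 = nn.Linear(d_model, d_ff, bias=False)  # linear branch
        self.wo = nn.Linear(d_ff, d_model, bias=False)    # output projection

    def forward(self, x):
        return self.wo(nn.functional.gelu(self.wi_0(x)) * self.wi_1(x))

# Quick shape check
ffn = GatedFFN()
print(ffn(torch.randn(2, 16, 768)).shape)  # torch.Size([2, 16, 768])
```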
### T5 Transformer
<div align="center">
<img align="center" src=docs/t5transformer.png>
</div>
## Environment Setup
### Docker (Option 1)
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/2.1.0-ubuntu22.04-dtk23.10.1-py310
docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal/:/opt/hyhal/:ro --shm-size=64G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name umt5 <your imageID> bash
cd /path/your_code_data/umt5
pip install -r requirements.txt -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
```
### Dockerfile (Option 2)
```
cd /path/your_code_data/umt5/docker
docker build --no-cache -t umt5:latest .
docker run --shm-size=64G --name umt5 -v /opt/hyhal:/opt/hyhal:ro --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video -v /path/your_code_data/:/path/your_code_data/ -it umt5 bash
```
### Anaconda (Option 3)
```
DTK driver: dtk23.10
python: 3.10
torch: 2.1
peft: >=0.8.0
transformers: >=4.28.1
protobuf: >=3.20
```
`Tip: the DTK driver, Python, torch, and other DCU-related tool versions above must match each other exactly.`
The DCU-specific deep learning libraries required by this project can be downloaded from the [HPC developer community (光合)](https://developer.hpccube.com/tool/).
```
conda create -n umt5 python=3.10
cd /path/your_code_data/umt5
pip install -r requirements.txt -i http://mirrors.aliyun.com/pypi/simple
```
## Dataset
We use the Large-scale Chinese Short Text Summarization corpus (LCSTS), built from short news posts on Sina Weibo and containing more than 2 million samples.
### Data Processing Code
```
class LCSTS(Dataset):
    def __init__(self, data_file):
        self.data = self.load_data(data_file)

    def load_data(self, data_file):
        Data = {}
        with open(data_file, 'rt', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                if idx >= max_dataset_size:
                    break
                items = line.strip().split('!=!')
                assert len(items) == 2
                Data[idx] = {
                    'title': items[0],
                    'content': items[1]
                }
        return Data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(batch_samples):
    batch_inputs, batch_targets = [], []
    for sample in batch_samples:
        batch_inputs.append(sample['content'])
        batch_targets.append(sample['title'])
    batch_data = tokenizer(
        batch_inputs,
        padding=True,
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt"
    )
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            batch_targets,
            padding=True,
            max_length=max_target_length,
            truncation=True,
            return_tensors="pt"
        )["input_ids"]
    batch_data['decoder_input_ids'] = model.module.prepare_decoder_input_ids_from_labels(labels)
    end_token_index = torch.where(labels == tokenizer.eos_token_id)[1]
    for idx, end_idx in enumerate(end_token_index):
        labels[idx][end_idx+1:] = -100
    batch_data['labels'] = labels
    return batch_data
train_data = LCSTS('/public/home/wanglch/project/umt5/data/lcsts_tsv/data1.tsv')
valid_data = LCSTS('/public/home/wanglch/project/umt5/data/lcsts_tsv/data2.tsv')
```
**Data processing is already included in multi_dcu_train.py and multi_dcu_test.py, so the data processing code does not need to be run separately.**
The project ships a mini dataset for trial training. The training data directory layout is shown below; to train on the full dataset, prepare it following the same layout (a small preparation sketch follows the tree):
```
data
└── lcsts_tsv
    ├── data1.tsv
    ├── data2.tsv
    └── data3.tsv
```
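The LCSTS loader in the scripts expects one sample per line, with the title and body joined by the `!=!` separator. A minimal preparation sketch (raw_samples is a hypothetical stand-in for your own corpus):
```
import os

# Hypothetical in-memory corpus; replace with your own (title, content) pairs.
raw_samples = [
    ("标题一", "这是第一条新闻正文。"),
    ("标题二", "这是第二条新闻正文。"),
]

os.makedirs("data/lcsts_tsv", exist_ok=True)
with open("data/lcsts_tsv/data1.tsv", "wt", encoding="utf-8") as f:
    for title, content in raw_samples:
        # One sample per line: title!=!content, as parsed by the LCSTS class.
        f.write(f"{title}!=!{content}\n")
```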
## Training
### Single-Node Multi-Card
```
python multi_dcu_train.py
```
## Testing
### Single-Node Multi-Card
```
python multi_dcu_test.py
```
## Summarization Task
Summarization requires a trained model: download umt5-base from hf-mirror or Hugging Face, train it with **multi_dcu_train.py**, save the trained weights, and then load those weights for summarization (a loading/generation sketch follows the command below). The same procedure applies to reading comprehension and translation tasks.
```
python umt5_summary.py
```
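umt5_summary.py is not reproduced here; the sketch below shows one plausible way to load the weights saved by multi_dcu_train.py and summarize a single text. Paths and generation settings are placeholders mirroring the training script:
```
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("/umt5/umt5_base")
model = AutoModelForSeq2SeqLM.from_pretrained("/umt5/umt5_base")

# Weights saved by multi_dcu_train.py come from a DataParallel-wrapped model,
# so state-dict keys carry a "module." prefix that a bare model does not have.
state_dict = torch.load("path/to/model_dtk_weights.bin", map_location="cpu")
state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}
model.load_state_dict(state_dict)
model.eval()

text = "这里放需要摘要的新闻正文。"
inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
with torch.no_grad():
    out = model.generate(**inputs, max_length=32, num_beams=4, no_repeat_ngram_size=2)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```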
## Results
### Chinese Text Summarization
<div align="center">
<img align="center" src=docs/result.png>
</div>
### Accuracy
Test data: [LCSTS](data/lcsts_tsv); accelerator cards used: V100S/K100.
The table is filled in from the test results:
| device | Rouge 1 | Rouge 2 | Rouge L |
| :------: | :------: | :------: | :------: |
| V100S | 26.12 | 14.81 | 23.62 |
| K100 | 26.94 | 15.38 | 24.24 |
## Application Scenarios
### Algorithm Category
`Text Summarization`
### Key Application Industries
**Finance**, **Education**, **Government**, **Research**, **Manufacturing**, **Energy**, **Media**
## Pretrained Weights
- [umt5-base pretrained model (hf-mirror)](https://hf-mirror.com/google/umt5-base/tree/main)
- [mT5/umT5 model collection (hf-mirror)](https://hf-mirror.com/collections/google/mt5-release-65005f1a520f8d7b4d039509)
## Source Repository Issue Feedback
- http://developer.hpccube.com/codes/modelzoo/fingpt.git
## References
- [google-research/multilingual-t5](https://github.com/google-research/multilingual-t5)
- [UniMax: Fairer and more Effective Language Sampling for Large-Scale Multilingual Pretraining](https://arxiv.org/abs/2304.09151)
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu22.04-dtk23.10.1-py310
ENV DEBIAN_FRONTEND=noninteractive
RUN /bin/bash -c "source /opt/dtk-23.10/env.sh"
COPY requirements.txt requirements.txt
RUN pip install -r requirements.txt -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com
numpy>=1.25.0
tqdm>=4.65.0
peft>=0.8.0
transformers>=4.28.1
accelerate
loralib
evaluate
datasets
torch
scikit-learn
protobuf>=3.20
sentencepiece
sacrebleu
rouge
# Unique model identifier
modelCode = 626
# Model name
modelName=umt5_pytorch
# Model description
modelDescription=An open-source, free, commercially usable pretrained multi-task language model for keyword extraction, translation, summarization, and more.
# Application scenarios
appScenario=Summarization, training, finance, education, government, research, manufacturing, energy, media.
# Framework type
frameType=pytorch
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from tqdm.auto import tqdm
from rouge import Rouge
import random
import numpy as np
import os
import json
from torch import nn
def seed_everything(seed=1029):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

class LCSTS(Dataset):
    def __init__(self, data_file):
        self.data = self.load_data(data_file)

    def load_data(self, data_file):
        Data = {}
        with open(data_file, 'rt', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                if idx >= max_dataset_size:
                    break
                items = line.strip().split('!=!')
                assert len(items) == 2
                Data[idx] = {
                    'title': items[0],
                    'content': items[1]
                }
        return Data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(batch_samples):
    batch_inputs, batch_targets = [], []
    for sample in batch_samples:
        batch_inputs.append(sample['content'])
        batch_targets.append(sample['title'])
    batch_data = tokenizer(
        batch_inputs,
        padding=True,
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt"
    )
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            batch_targets,
            padding=True,
            max_length=max_target_length,
            truncation=True,
            return_tensors="pt"
        )["input_ids"]
    # Assumes the model is wrapped in nn.DataParallel (hence model.module)
    batch_data['decoder_input_ids'] = model.module.prepare_decoder_input_ids_from_labels(labels)
    end_token_index = torch.where(labels == tokenizer.eos_token_id)[1]
    for idx, end_idx in enumerate(end_token_index):
        labels[idx][end_idx+1:] = -100  # mask everything after EOS so the loss ignores it
    batch_data['labels'] = labels
    return batch_data
if __name__ == '__main__':
    os.environ["HIP_VISIBLE_DEVICES"] = "4,5"
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'Using {device} device')
    seed_everything(5)
    max_dataset_size = 200000
    max_input_length = 512
    max_target_length = 32
    batch_size = 16
    learning_rate = 1e-5
    epoch_num = 3
    beam_size = 4
    no_repeat_ngram_size = 2
    test_data = LCSTS('/umt5/data/lcsts_tsv/data3.tsv')
    test_dataloader = DataLoader(test_data, batch_size=16, shuffle=False, collate_fn=collate_fn)
    model_checkpoint = "/umt5/umt5_base"
    trained_model_weights = '/umt5/saves/train_dtk_weights/epoch_1_valid_rouge_23.4347_model_dtk_weights.bin'
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
    model = model.to(device)
    # Check whether multiple GPUs are available
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # With multiple GPUs, wrap the model in nn.DataParallel
        model = nn.DataParallel(model)
    model.load_state_dict(torch.load(trained_model_weights))
    model.eval()
    rouge = Rouge()
    with torch.no_grad():
        print('evaluating on test set...')
        sources, preds, labels = [], [], []
        for batch_data in tqdm(test_dataloader):
            batch_data = {k: v.to(device) for k, v in batch_data.items()}  # move the batch to the device
            generated_tokens = model.module.generate(
                batch_data["input_ids"],
                attention_mask=batch_data["attention_mask"],
                max_length=max_target_length,
                num_beams=beam_size,
                no_repeat_ngram_size=no_repeat_ngram_size,
            ).cpu().numpy()
            if isinstance(generated_tokens, tuple):
                generated_tokens = generated_tokens[0]
            label_tokens = batch_data["labels"].cpu().numpy()
            decoded_sources = tokenizer.batch_decode(
                batch_data["input_ids"].cpu().numpy(),
                skip_special_tokens=True,
                use_source_tokenizer=True
            )
            decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            label_tokens = np.where(label_tokens != -100, label_tokens, tokenizer.pad_token_id)
            decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)
            sources += [source.strip() for source in decoded_sources]
            preds += [pred.strip() for pred in decoded_preds]
            labels += [label.strip() for label in decoded_labels]
        # ROUGE for Chinese: join characters with spaces so the word-level scorer works
        scores = rouge.get_scores(
            hyps=[' '.join(pred) for pred in preds],
            refs=[' '.join(label) for label in labels],
            avg=True
        )
        rouges = {key: value['f'] * 100 for key, value in scores.items()}
        rouges['avg'] = np.mean(list(rouges.values()))
        print(f"Test Rouge1: {rouges['rouge-1']:>0.2f} Rouge2: {rouges['rouge-2']:>0.2f} RougeL: {rouges['rouge-l']:>0.2f}\n")
    results = []
    print('saving predicted results...')
    for source, pred, label in zip(sources, preds, labels):
        results.append({
            "document": source,
            "prediction": pred,
            "summarization": label
        })
    with open('test_data_pred.json', 'wt', encoding='utf-8') as f:
        for example_result in results:
            f.write(json.dumps(example_result, ensure_ascii=False) + '\n')
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from torch.optim import AdamW
from transformers import get_scheduler
from tqdm.auto import tqdm
from rouge import Rouge
import random
import numpy as np
import os
import json
from torch import nn
from datetime import datetime
def seed_everything(seed=1029):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

class LCSTS(Dataset):
    def __init__(self, data_file):
        self.data = self.load_data(data_file)

    def load_data(self, data_file):
        Data = {}
        with open(data_file, 'rt', encoding='utf-8') as f:
            for idx, line in enumerate(f):
                if idx >= max_dataset_size:
                    break
                items = line.strip().split('!=!')
                assert len(items) == 2
                Data[idx] = {
                    'title': items[0],
                    'content': items[1]
                }
        return Data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(batch_samples):
    batch_inputs, batch_targets = [], []
    for sample in batch_samples:
        batch_inputs.append(sample['content'])
        batch_targets.append(sample['title'])
    batch_data = tokenizer(
        batch_inputs,
        padding=True,
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt"
    )
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
            batch_targets,
            padding=True,
            max_length=max_target_length,
            truncation=True,
            return_tensors="pt"
        )["input_ids"]
    # Assumes the model is wrapped in nn.DataParallel (hence model.module)
    batch_data['decoder_input_ids'] = model.module.prepare_decoder_input_ids_from_labels(labels)
    end_token_index = torch.where(labels == tokenizer.eos_token_id)[1]
    for idx, end_idx in enumerate(end_token_index):
        labels[idx][end_idx+1:] = -100  # mask everything after EOS so the loss ignores it
    batch_data['labels'] = labels
    return batch_data
def train_loop(dataloader, model, optimizer, lr_scheduler, epoch, total_loss):
    progress_bar = tqdm(range(len(dataloader)))
    progress_bar.set_description(f'loss: {0:>7f}')
    finish_batch_num = (epoch-1) * len(dataloader)
    model.train()
    for batch, batch_data in enumerate(dataloader, start=1):
        batch_data = batch_data.to(device)
        outputs = model(**batch_data)
        loss = outputs.loss
        loss = loss.mean()  # DataParallel returns one loss per GPU; average them
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        total_loss += loss.item()
        progress_bar.set_description(f'loss: {total_loss/(finish_batch_num + batch):>7f}')
        progress_bar.update(1)
    return total_loss

def test_loop(dataloader, model):
    preds, labels = [], []
    model.eval()
    for batch_data in tqdm(dataloader):
        batch_data = batch_data.to(device)
        with torch.no_grad():
            # With DataParallel, access the underlying model via model.module
            generated_tokens = model.module.generate(
                batch_data["input_ids"],
                attention_mask=batch_data["attention_mask"],
                max_length=max_target_length,
                num_beams=beam_size,
                no_repeat_ngram_size=no_repeat_ngram_size,
            ).cpu().numpy()
        if isinstance(generated_tokens, tuple):
            generated_tokens = generated_tokens[0]
        label_tokens = batch_data["labels"].cpu().numpy()
        decoded_preds = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        label_tokens = np.where(label_tokens != -100, label_tokens, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)
        preds += [' '.join(pred.strip()) for pred in decoded_preds]
        labels += [' '.join(label.strip()) for label in decoded_labels]
    scores = rouge.get_scores(hyps=preds, refs=labels, avg=True)
    result = {key: value['f'] * 100 for key, value in scores.items()}
    result['avg'] = np.mean(list(result.values()))
    print(f"Rouge1: {result['rouge-1']:>0.2f} Rouge2: {result['rouge-2']:>0.2f} RougeL: {result['rouge-l']:>0.2f}\n")
    return result
if __name__ == '__main__':
    os.environ["HIP_VISIBLE_DEVICES"] = "4,5"
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f'Using {device} device')
    seed_everything(5)
    rouge = Rouge()
    max_dataset_size = 200000
    max_input_length = 512
    max_target_length = 32
    batch_size = 16
    learning_rate = 1e-5
    epoch_num = 1
    beam_size = 4
    no_repeat_ngram_size = 2
    train_data = LCSTS('/umt5/data/lcsts_tsv/data1.tsv')
    valid_data = LCSTS('/umt5/data/lcsts_tsv/data2.tsv')
    model_checkpoint = "/umt5/umt5_base"
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)
    model = model.to(device)
    # Check whether multiple GPUs are available
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # With multiple GPUs, wrap the model in nn.DataParallel
        # (collate_fn and test_loop access model.module, so this script assumes the multi-card path)
        model = nn.DataParallel(model)
    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    valid_dataloader = DataLoader(valid_data, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    lr_scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=epoch_num*len(train_dataloader),
    )
    total_loss = 0.
    best_avg_rouge = 0.
    for t in range(epoch_num):
        print(f"Epoch {t+1}/{epoch_num}\n-------------------------------")
        total_loss = train_loop(train_dataloader, model, optimizer, lr_scheduler, t+1, total_loss)
        valid_rouge = test_loop(valid_dataloader, model)
        rouge_avg = valid_rouge['avg']
        if rouge_avg > best_avg_rouge:
            best_avg_rouge = rouge_avg
            print('saving new weights...\n')
            weight_path = f'/umt5/saves/train_dtk_weights/epoch_{t+1}_valid_rouge_{rouge_avg:0.4f}_model_dtk_weights.bin'
            torch.save(model.state_dict(), weight_path)
    # Load the best weights saved during training
    state_dict = torch.load(weight_path)
    model.load_state_dict(state_dict)
    # Use the current date and time to name the exported model directory
    now = datetime.now()
    timestamp = now.strftime("%Y%m%d_%H%M%S")
    new_model_path = f'saves/umt5_{timestamp}'
    model.module.save_pretrained(new_model_path)
    tokenizer.save_pretrained(new_model_path)
    print("Done!")
---
language:
- multilingual
- af
- am
- ar
- az
- be
- bg
- bn
- ca
- ceb
- co
- cs
- cy
- da
- de
- el
- en
- eo
- es
- et
- eu
- fa
- fi
- fil
- fr
- fy
- ga
- gd
- gl
- gu
- ha
- haw
- hi
- hmn
- ht
- hu
- hy
- ig
- is
- it
- iw
- ja
- jv
- ka
- kk
- km
- kn
- ko
- ku
- ky
- la
- lb
- lo
- lt
- lv
- mg
- mi
- mk
- ml
- mn
- mr
- ms
- mt
- my
- ne
- nl
- no
- ny
- pa
- pl
- ps
- pt
- ro
- ru
- sd
- si
- sk
- sl
- sm
- sn
- so
- sq
- sr
- st
- su
- sv
- sw
- ta
- te
- tg
- th
- tr
- uk
- und
- ur
- uz
- vi
- xh
- yi
- yo
- zh
- zu
datasets:
- mc4
license: apache-2.0
---
[Google's UMT5](https://github.com/google-research/multilingual-t5)
UMT5 is pretrained on an updated version of the [mC4](https://www.tensorflow.org/datasets/catalog/c4#c4multilingual) corpus, covering 107 languages:
Afrikaans, Albanian, Amharic, Arabic, Armenian, Azerbaijani, Basque, Belarusian, Bengali, Bulgarian, Burmese, Catalan, Cebuano, Chichewa, Chinese, Corsican, Czech, Danish, Dutch, English, Esperanto, Estonian, Filipino, Finnish, French, Galician, Georgian, German, Greek, Gujarati, Haitian Creole, Hausa, Hawaiian, Hebrew, Hindi, Hmong, Hungarian, Icelandic, Igbo, Indonesian, Irish, Italian, Japanese, Javanese, Kannada, Kazakh, Khmer, Korean, Kurdish, Kyrgyz, Lao, Latin, Latvian, Lithuanian, Luxembourgish, Macedonian, Malagasy, Malay, Malayalam, Maltese, Maori, Marathi, Mongolian, Nepali, Norwegian, Pashto, Persian, Polish, Portuguese, Punjabi, Romanian, Russian, Samoan, Scottish Gaelic, Serbian, Shona, Sindhi, Sinhala, Slovak, Slovenian, Somali, Sotho, Spanish, Sundanese, Swahili, Swedish, Tajik, Tamil, Telugu, Thai, Turkish, Ukrainian, Urdu, Uzbek, Vietnamese, Welsh, West Frisian, Xhosa, Yiddish, Yoruba, Zulu.
**Note**: UMT5 was only pre-trained on mC4, with no supervised training. Therefore, this model has to be fine-tuned before it is usable on a downstream task.
Pretraining Dataset: [mC4](https://www.tensorflow.org/datasets/catalog/c4#c4multilingual)
Other Community Checkpoints: [here](https://huggingface.co/models?search=umt5)
Paper: [UniMax, Fairer and More Effective Language Sampling for Large-Scale Multilingual Pretraining](https://openreview.net/forum?id=kXwdL1cWOAi)
Authors: *Hyung Won Chung, Xavier Garcia, Adam Roberts, Yi Tay, Orhan Firat, Sharan Narang, Noah Constant*
## Abstract
*Pretrained multilingual large language models have typically used heuristic temperature-based sampling to balance between different languages. However previous work has not systematically evaluated the efficacy of different pretraining language distributions across model scales. In this paper, we propose a new sampling method, UniMax, that delivers more uniform coverage of head languages while mitigating overfitting on tail languages by explicitly capping the number of repeats over each language's corpus. We perform an extensive series of ablations testing a range of sampling strategies on a suite of multilingual benchmarks, while varying model scale. We find that UniMax outperforms standard temperature-based sampling, and the benefits persist as scale increases. As part of our contribution, we release: (i) an improved and refreshed mC4 multilingual corpus consisting of 29 trillion characters across 107 languages, and (ii) a suite of pretrained umT5 model checkpoints trained with UniMax sampling.*
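As a rough illustration of the UniMax idea described above (not the authors' reference implementation), the sketch below allocates a token budget uniformly across languages while capping each language at a fixed number of epochs over its corpus; `max_epochs` and the toy sizes are assumptions:
```
# Hedged sketch of UniMax-style budget allocation: spread a total token budget
# uniformly over languages, but cap each language at `max_epochs` passes over
# its corpus; the leftover budget is re-spread over the remaining languages.
def unimax_allocation(corpus_sizes, total_budget, max_epochs=4):
    alloc = {lang: 0.0 for lang in corpus_sizes}
    remaining = dict(corpus_sizes)
    budget = total_budget
    while remaining and budget > 0:
        share = budget / len(remaining)
        capped = {l: s for l, s in remaining.items() if s * max_epochs <= share}
        if not capped:
            # No language hits its cap: give everyone the uniform share.
            for l in remaining:
                alloc[l] += share
            break
        # Capped languages take exactly max_epochs * size; re-spread the rest.
        for l, s in capped.items():
            alloc[l] += s * max_epochs
            budget -= s * max_epochs
            del remaining[l]
    return alloc

# Toy usage: one head language (en) and two tail languages.
print(unimax_allocation({"en": 1000, "sw": 10, "yo": 5}, total_budget=300))
# {'en': 240.0, 'sw': 40.0, 'yo': 20.0}
```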