Commit 11aad6fa authored by yongshk's avatar yongshk
Browse files

add new

parents
Pipeline #576 canceled with stages
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
# 算法名简写(英文简写大写)
## 论文
`Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context`
- https://arxiv.org/abs/1901.02860
## 模型结构
TransformersXL 是一种改进的 Transformer 模型,旨在处理更长的文本序列。它引入了**延展性机制**,通过**分块处理**超长序列,然后使用**跨块注意力**来捕捉长距离依赖关系。
![img](doc\模型结构.png)
## 算法原理
Transformer-XL 在很大程度上依赖于普通 Transformer(Al-Rfou 等人),但引入了两种创新技术——**递归机制****相对位置编码**——来克服普通 Transformer 的缺点以下是其原理对比
transformer
![](doc\transformer的训练与评估.png)
transformer-XL
![img](doc\xl的训练与评估.png)
## 环境配置
### Docker(方法一)
此处提供[光源](https://www.sourcefind.cn/#/service-details)拉取docker镜像的地址与使用步骤
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.10.0-centos7.6-dtk-22.10-py37-latest
docker run -it --network=host --name=transformer-XL --privileged --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size=32G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root --ulimit stack=-1:-1 --ulimit memlock=-1:-1 -v /root/transformerxl:/home image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.10.0-centos7.6-dtk-22.10-py37-latest
```
### Anaconda(方法二)
此处提供本地配置、编译的详细步骤,例如:
关于本项目DCU显卡所需的特殊深度学习库可从[光合](https://developer.hpccube.com/tool/)开发者社区下载安装。
```
DTK驱动:dtk22.10
python:python3.7
```
`Tips:以上dtk驱动、python等DCU相关工具版本需要严格一一对应`
其它非深度学习库参照requirements.txt安装:
```
pip install -r requirements.txt
```
## 数据集
`enwik8`
- http://mattmahoney.net/dc/enwik8.zip
此处提供数据预处理脚本的使用方法
```
wget https://raw.githubusercontent.com/salesforce/awd-lstm-lm/master/data/enwik8/prep_enwik8.py
python3 prep_enwik8.py
```
项目中已提供用于试验训练的迷你数据集,训练数据目录结构如下,用于正常训练的完整数据集请按此目录结构进行制备:
```
── data
│   ├── train.txt
│   └── vaild.txt
│   └── test.txt
```
## 训练
### 单机多卡
```
sh run_enwik8_base_dp.sh train
```
### 单机单卡
```
sh run_enwik8_base.sh train
```
## 推理
```
sh run_enwik8_base.sh eval --work_dir 模型路径
```
## result
![rusult](doc\rusult.png)
### 精度
测试数据:[test data](http://mattmahoney.net/dc/enwik8.zip),使用的加速卡:Z100L。
根据测试结果情况填写表格:
| transformer-XL | loss | bpc |
| :------: | :------: | :------: |
| enwik8 | 0.9 | 1.292 |
## 应用场景
### 算法类别
`nlp、长序列处理`
### 热点应用行业
`自然语言生成``机器翻译``长文本分类``对话系统`
## 源码仓库及问题反馈
- 此处填本项目gitlab地址
## 参考资料
- https://github.com/kimiyoung/transformer-xl
119 119 46 97 117 98 117 114 110 46 101 100 117 47 97 99 97 100 101 109 105 99 47 108 105 98 101 114 97 108 95 97 114 116 115 47 102 111 114 101 105 103 110 47 114 117 115 115 105 97 110 47 105 99 111 110 115 47 32 82 117 115 115 105 97 110 32 73 99 111 110 115 32 102 114 111 109 32 49 50 116 104 32 116 111 32 49 56 116 104 32 99 101 110 116 117 114 121 93
\ No newline at end of file
60 109 101 100 105 97 119 105 107 105 32 120 109 108 110 115 61 34 104 116 116 112 58 47 47 119 119 119 46 109 101 100 105 97 119 105 107 105 46 111 114 103 47 120 109 108 47 101 120 112 111 114 116 45 48 46 51 47 34 32 120 109 108 110 115 58 120 115 105 61 34 104 116 116 112 58 47 47 119 119 119 46 119 51 46 111 114 103 47 50 48 48 49 47 88 77 76 83 99 104 101 109 97 45 105 110 115 116 97 110 99 101 34 32 120 115 105 58 115 99 104 101 109 97 76 111 99 97 116 105 111 110 61 34 104 116 116 112 58 47 47 119 119 119 46 109 101 100 105 97 119 105 107 105 46 111 114 103 47 120 109 108 47 101 120 112 111 114 116 45 48 46 51 47 32 104 116 116 112 58 47 47 119 119 119 46 109 101 100 105 97 119 105 107 105 46 111 114 103 47 120 109 108 47 101 120 112 111 114 116 45 48 46 51 46 120 115 100 34 32 118 101 114 115 105 111 110 61 34 48 46 51 34 32 120 109 108 58 108 97 110 103 61 34 101 110 34 62
32 32 60 115 105 116 101 105 110 102 111 62
32 32 32 32 60 115 105 116 101 110 97 109 101 62 87 105 107 105 112 101 100 105 97 60 47 115 105 116 101 110 97 109 101 62
32 32 32 32 60 98 97 115 101 62 104 116 116 112 58 47 47 101 110 46 119 105 107 105 112 101 100 105 97 46 111 114 103 47 119 105 107 105 47 77 97 105 110 95 80 97 103 101 60 47 98 97 115 101 62
32 32 32 32 60 103 101 110 101 114 97 116 111 114 62 77 101 100 105 97 87 105 107 105 32 49 46 54 97 108 112 104 97 60 47 103 101 110 101 114 97 116 111 114 62
\ No newline at end of file
101 114 110 32 104 105 103 104 108 97 110 100 115 32 97 110 100 32 100 111 101 115 32 110 111 116 32 112 97 114 116 97 107 101 32 111 102 32 97 110 121 32 111 116 104 101 114 32 97 114 101 97 39 115 32 99 114 111 112 115 46 32 32 84 104 101 32 109 111 115 116 32 102 97 109 111 117 115 32 109 101 109 98 101 114 32 111 102 32 116 104 105 115 32 99 114 111 112 32 115 121 115 116 101 109 32 105 115 32 91 91 99 111 102 102 101 101 93 93 44 32 98 117 116 32 111 110 101 32 111 102 32 116 104 101 32 109 111 114 101 32 117 115 101 102 117 108 32 112 108 97 110 116 115 32 105 115 32 91 91 115 111 114 103 104 117 109 93 93 44 32 97 32 100 114 121 45 108 97 110 100 32 103 114 97 105 110 46
65 110 99 105 101 110 116 32 99 117 108 116 117 114 101 115 32 97 108 115 111 32 101 120 105 115 116 101 100 32 97 108 108 32 97 108 111 110 103 32 116 104 101 32 91 91 78 105 108 101 93 93 44 32 97 110 100 32 105 110 32 109 111 100 101 114 110 45 100 97 121 32 91 91 71 104 97 110 97 93 93 32 38 108 116 59 33 45 45 32 97 110 100 32 109 117 99 104 32 109 111 114 101 46 46 46 32 45 45 38 103 116 59 46
echo "=== Acquiring datasets ==="
echo "---"
mkdir -p data
cd data
if [[ ! -d 'wikitext-2' ]]; then
echo "- Downloading WikiText-2 (WT2)"
wget --quiet --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip
unzip -q wikitext-2-v1.zip
cd wikitext-2
mv wiki.train.tokens train.txt
mv wiki.valid.tokens valid.txt
mv wiki.test.tokens test.txt
cd ..
fi
echo "- Downloading WikiText-103 (WT2)"
if [[ ! -d 'wikitext-103' ]]; then
wget --continue https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip
unzip -q wikitext-103-v1.zip
cd wikitext-103
mv wiki.train.tokens train.txt
mv wiki.valid.tokens valid.txt
mv wiki.test.tokens test.txt
cd ..
fi
echo "- Downloading enwik8 (Character)"
if [[ ! -d 'enwik8' ]]; then
mkdir -p enwik8
cd enwik8
wget --continue http://mattmahoney.net/dc/enwik8.zip
wget https://raw.githubusercontent.com/salesforce/awd-lstm-lm/master/data/enwik8/prep_enwik8.py
python3 prep_enwik8.py
cd ..
fi
echo "- Downloading text8 (Character)"
if [[ ! -d 'text8' ]]; then
mkdir -p text8
cd text8
wget --continue http://mattmahoney.net/dc/text8.zip
python ../../prep_text8.py
cd ..
fi
echo "- Downloading Penn Treebank (PTB)"
if [[ ! -d 'penn' ]]; then
wget --quiet --continue http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz
tar -xzf simple-examples.tgz
mkdir -p penn
cd penn
mv ../simple-examples/data/ptb.train.txt train.txt
mv ../simple-examples/data/ptb.test.txt test.txt
mv ../simple-examples/data/ptb.valid.txt valid.txt
cd ..
echo "- Downloading Penn Treebank (Character)"
mkdir -p pennchar
cd pennchar
mv ../simple-examples/data/ptb.char.train.txt train.txt
mv ../simple-examples/data/ptb.char.test.txt test.txt
mv ../simple-examples/data/ptb.char.valid.txt valid.txt
cd ..
rm -rf simple-examples/
fi
echo "- Downloading 1B words"
if [[ ! -d 'one-billion-words' ]]; then
mkdir -p one-billion-words
cd one-billion-words
wget --no-proxy http://www.statmt.org/lm-benchmark/1-billion-word-language-modeling-benchmark-r13output.tar.gz
tar xzvf 1-billion-word-language-modeling-benchmark-r13output.tar.gz
path="1-billion-word-language-modeling-benchmark-r13output/heldout-monolingual.tokenized.shuffled/"
cat ${path}/news.en.heldout-00000-of-00050 > valid.txt
cat ${path}/news.en.heldout-00000-of-00050 > test.txt
wget https://github.com/rafaljozefowicz/lm/raw/master/1b_word_vocab.txt
cd ..
fi
echo "---"
echo "Happy language modeling :)"
#!/usr/bin/env python
# coding=utf-8
import os
import sys
import zipfile
from io import open
if os.path.exists('train.txt'):
print('Tokenized text8 already exists - skipping processing')
sys.exit()
data = zipfile.ZipFile('text8.zip').extractall()
data = open('text8', 'r', encoding='utf-8').read()
print('Length of text8: {}'.format(len(data)))
num_test_chars = 5000000
train_data = data[: -2 * num_test_chars]
valid_data = data[-2 * num_test_chars: -num_test_chars]
test_data = data[-num_test_chars:]
for fn, part in [('train.txt', train_data), ('valid.txt', valid_data), ('test.txt', test_data)]:
print('{} will have {} bytes'.format(fn, len(part)))
print('- Tokenizing...')
# Change space ' ' to underscore '_'
part_str = ' '.join(['_' if c == ' ' else c for c in part.strip()])
print('- Writing...')
f = open(fn, 'w').write(part_str)
f = open(fn + '.raw', 'w', encoding='utf-8').write(part)
import os, sys
import glob
from collections import Counter, OrderedDict
import numpy as np
import torch
from utils.vocabulary import Vocab
class LMOrderedIterator(object):
def __init__(self, data, bsz, bptt, device='cpu', ext_len=None):
"""
data -- LongTensor -- the LongTensor is strictly ordered
"""
self.bsz = bsz
self.bptt = bptt
self.ext_len = ext_len if ext_len is not None else 0
self.device = device
# Work out how cleanly we can divide the dataset into bsz parts.
self.n_step = data.size(0) // bsz
# Trim off any extra elements that wouldn't cleanly fit (remainders).
data = data.narrow(0, 0, self.n_step * bsz)
# Evenly divide the data across the bsz batches.
self.data = data.view(bsz, -1).t().contiguous().to(device)
# Number of mini-batches
self.n_batch = (self.n_step + self.bptt - 1) // self.bptt
def get_batch(self, i, bptt=None):
if bptt is None: bptt = self.bptt
seq_len = min(bptt, self.data.size(0) - 1 - i)
end_idx = i + seq_len
beg_idx = max(0, i - self.ext_len)
data = self.data[beg_idx:end_idx]
target = self.data[i+1:i+1+seq_len]
return data, target, seq_len
def get_fixlen_iter(self, start=0):
for i in range(start, self.data.size(0) - 1, self.bptt):
yield self.get_batch(i)
def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3):
max_len = self.bptt + max_deviation * std
i = start
while True:
bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.
bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std))))
data, target, seq_len = self.get_batch(i, bptt)
i += seq_len
yield data, target, seq_len
if i >= self.data.size(0) - 2:
break
def __iter__(self):
return self.get_fixlen_iter()
class LMShuffledIterator(object):
def __init__(self, data, bsz, bptt, device='cpu', ext_len=None, shuffle=False):
"""
data -- list[LongTensor] -- there is no order among the LongTensors
"""
self.data = data
self.bsz = bsz
self.bptt = bptt
self.ext_len = ext_len if ext_len is not None else 0
self.device = device
self.shuffle = shuffle
def get_sent_stream(self):
# index iterator
epoch_indices = np.random.permutation(len(self.data)) if self.shuffle \
else np.array(range(len(self.data)))
# sentence iterator
for idx in epoch_indices:
yield self.data[idx]
def stream_iterator(self, sent_stream):
# streams for each data in the batch
streams = [None] * self.bsz
data = torch.LongTensor(self.bptt, self.bsz)
target = torch.LongTensor(self.bptt, self.bsz)
n_retain = 0
while True:
# data : [n_retain+bptt x bsz]
# target : [bptt x bsz]
data[n_retain:].fill_(-1)
target.fill_(-1)
valid_batch = True
for i in range(self.bsz):
n_filled = 0
try:
while n_filled < self.bptt:
if streams[i] is None or len(streams[i]) <= 1:
streams[i] = next(sent_stream)
# number of new tokens to fill in
n_new = min(len(streams[i]) - 1, self.bptt - n_filled)
# first n_retain tokens are retained from last batch
data[n_retain+n_filled:n_retain+n_filled+n_new, i] = \
streams[i][:n_new]
target[n_filled:n_filled+n_new, i] = \
streams[i][1:n_new+1]
streams[i] = streams[i][n_new:]
n_filled += n_new
except StopIteration:
valid_batch = False
break
if not valid_batch:
return
data = data.to(self.device)
target = target.to(self.device)
yield data, target, self.bptt
n_retain = min(data.size(0), self.ext_len)
if n_retain > 0:
data[:n_retain] = data[-n_retain:]
data.resize_(n_retain + self.bptt, data.size(1))
def __iter__(self):
# sent_stream is an iterator
sent_stream = self.get_sent_stream()
for batch in self.stream_iterator(sent_stream):
yield batch
class LMMultiFileIterator(LMShuffledIterator):
def __init__(self, paths, vocab, bsz, bptt, device='cpu', ext_len=None,
shuffle=False):
self.paths = paths
self.vocab = vocab
self.bsz = bsz
self.bptt = bptt
self.ext_len = ext_len if ext_len is not None else 0
self.device = device
self.shuffle = shuffle
def get_sent_stream(self, path):
sents = self.vocab.encode_file(path, add_double_eos=True)
if self.shuffle:
np.random.shuffle(sents)
sent_stream = iter(sents)
return sent_stream
def __iter__(self):
if self.shuffle:
np.random.shuffle(self.paths)
for path in self.paths:
# sent_stream is an iterator
sent_stream = self.get_sent_stream(path)
for batch in self.stream_iterator(sent_stream):
yield batch
class Corpus(object):
def __init__(self, path, dataset, *args, **kwargs):
self.dataset = dataset
self.vocab = Vocab(*args, **kwargs)
if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
self.vocab.count_file(os.path.join(path, 'train.txt'))
self.vocab.count_file(os.path.join(path, 'valid.txt'))
self.vocab.count_file(os.path.join(path, 'test.txt'))
elif self.dataset == 'wt103':
self.vocab.count_file(os.path.join(path, 'train.txt'))
elif self.dataset == 'lm1b':
train_path_pattern = os.path.join(
path, '1-billion-word-language-modeling-benchmark-r13output',
'training-monolingual.tokenized.shuffled', 'news.en-*')
train_paths = glob.glob(train_path_pattern)
# the vocab will load from file when build_vocab() is called
self.vocab.build_vocab()
if self.dataset in ['ptb', 'wt2', 'wt103']:
self.train = self.vocab.encode_file(
os.path.join(path, 'train.txt'), ordered=True)
self.valid = self.vocab.encode_file(
os.path.join(path, 'valid.txt'), ordered=True)
self.test = self.vocab.encode_file(
os.path.join(path, 'test.txt'), ordered=True)
elif self.dataset in ['enwik8', 'text8']:
self.train = self.vocab.encode_file(
os.path.join(path, 'train.txt'), ordered=True, add_eos=False)
self.valid = self.vocab.encode_file(
os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
self.test = self.vocab.encode_file(
os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
elif self.dataset == 'lm1b':
self.train = train_paths
self.valid = self.vocab.encode_file(
os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
self.test = self.vocab.encode_file(
os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)
def get_iterator(self, split, *args, **kwargs):
if split == 'train':
if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
data_iter = LMOrderedIterator(self.train, *args, **kwargs)
elif self.dataset == 'lm1b':
kwargs['shuffle'] = True
data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs)
elif split in ['valid', 'test']:
data = self.valid if split == 'valid' else self.test
if self.dataset in ['ptb', 'wt2', 'wt103', 'enwik8', 'text8']:
data_iter = LMOrderedIterator(data, *args, **kwargs)
elif self.dataset == 'lm1b':
data_iter = LMShuffledIterator(data, *args, **kwargs)
return data_iter
def get_lm_corpus(datadir, dataset):
fn = os.path.join(datadir, 'cache.pt')
if os.path.exists(fn):
print('Loading cached dataset...')
corpus = torch.load(fn)
else:
print('Producing dataset {}...'.format(dataset))
kwargs = {}
if dataset in ['wt103', 'wt2']:
kwargs['special'] = ['<eos>']
kwargs['lower_case'] = False
elif dataset == 'ptb':
kwargs['special'] = ['<eos>']
kwargs['lower_case'] = True
elif dataset == 'lm1b':
kwargs['special'] = []
kwargs['lower_case'] = False
kwargs['vocab_file'] = os.path.join(datadir, '1b_word_vocab.txt')
elif dataset in ['enwik8', 'text8']:
pass
corpus = Corpus(datadir, dataset, **kwargs)
torch.save(corpus, fn)
return corpus
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='unit test')
parser.add_argument('--datadir', type=str, default='../data/text8',
help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='text8',
choices=['ptb', 'wt2', 'wt103', 'lm1b', 'enwik8', 'text8'],
help='dataset name')
args = parser.parse_args()
corpus = get_lm_corpus(args.datadir, args.dataset)
print('Vocab size : {}'.format(len(corpus.vocab.idx2sym)))
# coding: utf-8
import argparse
import time
import math
import os, sys
import torch
from data_utils import get_lm_corpus
from mem_transformer import MemTransformerLM
from utils.exp_utils import get_logger
from torchinfo import summary
parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
parser.add_argument('--data', type=str, default='../data/wikitext-103',
help='location of the data corpus')
parser.add_argument('--dataset', type=str, default='wt103',
choices=['wt103', 'lm1b', 'enwik8', 'text8'],
help='dataset name')
parser.add_argument('--split', type=str, default='all',
choices=['all', 'valid', 'test'],
help='which split to evaluate')
parser.add_argument('--batch_size', type=int, default=10,
help='batch size')
parser.add_argument('--tgt_len', type=int, default=5,
help='number of tokens to predict')
parser.add_argument('--ext_len', type=int, default=0,
help='length of the extended context')
parser.add_argument('--mem_len', type=int, default=0,
help='length of the retained previous heads')
parser.add_argument('--clamp_len', type=int, default=-1,
help='max positional embedding index')
parser.add_argument('--cuda', action='store_true',
help='use CUDA')
parser.add_argument('--work_dir', type=str, required=True,
help='path to the work_dir')
parser.add_argument('--no_log', action='store_true',
help='do not log the eval result')
parser.add_argument('--same_length', action='store_true',
help='set same length attention with masking')
args = parser.parse_args()
assert args.ext_len >= 0, 'extended context length must be non-negative'
device = torch.device("cuda" if args.cuda else "cpu")
# Get logger
logging = get_logger(os.path.join(args.work_dir, 'log.txt'),
log_=not args.no_log)
# Load dataset
corpus = get_lm_corpus(args.data, args.dataset)
ntokens = len(corpus.vocab)
va_iter = corpus.get_iterator('valid', args.batch_size, args.tgt_len,
device=device, ext_len=args.ext_len)
te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
device=device, ext_len=args.ext_len)
# Load the best saved model.
with open(os.path.join(args.work_dir, 'model.pt'), 'rb') as f:
model = torch.load(f)
model.backward_compatible()
model = model.to(device)
logging('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))
model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
if args.clamp_len > 0:
model.clamp_len = args.clamp_len
if args.same_length:
model.same_length = True
###############################################################################
# Evaluation code
###############################################################################
def evaluate(eval_iter):
# Turn on evaluation mode which disables dropout.
model.eval()
total_len, total_loss = 0, 0.
start_time = time.time()
with torch.no_grad():
mems = tuple()
for idx, (data, target, seq_len) in enumerate(eval_iter):
ret = model(data, target, *mems)
loss, mems = ret[0], ret[1:]
loss = loss.mean()
total_loss += seq_len * loss.item()
total_len += seq_len
total_time = time.time() - start_time
logging('Time : {:.2f}s, {:.2f}ms/segment'.format(
total_time, 1000 * total_time / (idx+1)))
return total_loss / total_len
# Run on test data.
if args.split == 'all':
test_loss = evaluate(te_iter)
valid_loss = evaluate(va_iter)
elif args.split == 'valid':
valid_loss = evaluate(va_iter)
test_loss = None
elif args.split == 'test':
test_loss = evaluate(te_iter)
valid_loss = None
def format_log(loss, split):
if args.dataset in ['enwik8', 'text8']:
log_str = '| {0} loss {1:5.2f} | {0} bpc {2:9.5f} '.format(
split, loss, loss / math.log(2))
else:
log_str = '| {0} loss {1:5.2f} | {0} ppl {2:9.3f} '.format(
split, loss, math.exp(loss))
return log_str
log_str = ''
if valid_loss is not None:
log_str += format_log(valid_loss, 'valid')
if test_loss is not None:
log_str += format_log(test_loss, 'test')
logging('=' * 100)
logging(log_str)
logging('=' * 100)
This diff is collapsed.
#!/bin/bash
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python -u train.py \
--cuda \
--data ../data/enwik8/ \
--dataset enwik8 \
--n_layer 12 \
--d_model 512 \
--n_head 8 \
--d_head 64 \
--d_inner 2048 \
--dropout 0.1 \
--dropatt 0.0 \
--optim adam \
--lr 0.00025 \
--warmup_step 0 \
--max_step 400000 \
--tgt_len 512 \
--mem_len 512 \
--eval_tgt_len 128 \
--batch_size 16 \
--fp16 \
--cuda \
# --batch_size 22 \
# --multi_gpu \
# --gpu0_bsz 4 \
${@:2}
elif [[ $1 == 'eval' ]]; then
echo 'Run evaluation...'
python -u eval.py \
--work_dir ../LM-TFM-enwik8/20230913-141909 \
--cuda \
--data ../data/enwik8/ \
--dataset enwik8 \
--tgt_len 80 \
--mem_len 2100 \
--clamp_len 820 \
--same_length \
--split test \
${@:2}
else
echo 'unknown argment 1'
fi
#!/bin/bash
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python train.py \
--cuda \
--data ../data/enwik8/ \
--dataset enwik8 \
--n_layer 12 \
--d_model 512 \
--n_head 8 \
--d_head 64 \
--d_inner 2048 \
--dropout 0.1 \
--dropatt 0.0 \
--optim adam \
--lr 0.00025 \
--warmup_step 0 \
--max_step 400000 \
--tgt_len 512 \
--mem_len 512 \
--eval_tgt_len 128 \
--batch_size 22 \
--multi_gpu \
--gpu0_bsz 4 \
${@:2}
elif [[ $1 == 'eval' ]]; then
echo 'Run evaluation...'
python eval.py \
--cuda \
--data ../data/enwik8/ \
--dataset enwik8 \
--tgt_len 80 \
--mem_len 2100 \
--clamp_len 820 \
--same_length \
--split test \
${@:2}
else
echo 'unknown argment 1'
fi
#!/bin/bash
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python train.py \
--cuda \
--data ../data/enwik8/ \
--dataset enwik8 \
--n_layer 24 \
--d_model 1024 \
--n_head 8 \
--d_head 128 \
--d_inner 3072 \
--dropout 0.15 \
--dropatt 0.15 \
--optim adam \
--lr 0.00025 \
--warmup_step 4000 \
--max_step 400000 \
--tgt_len 768 \
--mem_len 768 \
--eval_tgt_len 128 \
--batch_size 64 \
--multi_gpu \
--gpu0_bsz 0 \
${@:2}
elif [[ $1 == 'eval' ]]; then
echo 'Run evaluation...'
python eval.py \
--cuda \
--data ../data/enwik8/ \
--dataset enwik8 \
--tgt_len 128 \
--mem_len 3800 \
--clamp_len 1000 \
--same_length \
--split test \
${@:2}
else
echo 'unknown argment 1'
fi
#!/bin/bash
if [[ $1 == 'train' ]]; then
echo 'Run training...'
python train.py \
--cuda \
--data ../data/one-billion-words/ \
--dataset lm1b \
--adaptive \
--n_layer 18 \
--d_model 1024 \
--div_val 4 \
--n_head 8 \
--d_head 128 \
--d_inner 4096 \
--dropout 0.0 \
--dropatt 0.0 \
--optim adam \
--warmup_step 20000 \
--max_step 500000 \
--lr 0.00025 \
--tgt_len 32 \
--mem_len 32 \
--eval_tgt_len 32 \
--batch_size 224 \
--multi_gpu \
--gpu0_bsz 32 \
${@:2}
elif [[ $1 == 'eval' ]]; then
echo 'Run evaluation...'
python eval.py \
--cuda \
--data ../data/one-billion-words/ \
--dataset lm1b \
--batch_size 64 \
--tgt_len 32 \
--mem_len 128 \
--split test \
--same_length \
${@:2}
else
echo 'unknown argment 1'
fi
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment