Commit 0371621a authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #1989 canceled with stages
# Override jupyter in Github language stats for more accurate estimate of repo code languages
# reference: https://github.com/github/linguist/blob/master/docs/overrides.md#generated-code
*.ipynb linguist-generated
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# Wandb
wandb/
# Gradio
flagged/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- family-names: "Ociepa"
given-names: "Krzysztof"
orcid: "https://orcid.org/0009-0006-4858-6460"
title: "ALLaMo: A Simple, Hackable, and Fast Framework for Training Medium-Sized LLMs"
version: 5.0.0
date-released: 2024-08-18
url: "https://github.com/chrisociepa/allamo"
MIT License
Copyright (c) 2023 Chris
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# Llama
彻底开源预训练大模型,极简讲清原理,极简提供算法实现代码,让大家秒懂、秒实现,方便大家提出自研算法,让人类文明自由共享是终极目标。
目前各种SOTA NLP大模型算法都与Llama高度相似,故Llama适合作为算法研发的蓝底,本项目首次从数据集、预训练到调优完全开源大模型算法代码,方便全世界所有算法研究人员共同研究以促进人类文明进步。
<div align=center>
<img src="./doc/llm.png"/>
</div>
## 论文
`Open and Efficient Foundation Language Models`
- https://arxiv.org/pdf/2302.13971
## 模型结构
Llama系列采用极简Decoder-only结构,Llama源自基本的transformer结构,主体为attention(QKV自点积)+ffn(全连接),最后外加一个softmax进行概率转换输出即可,为了使数据分布归一化方便训练收敛,在attention、ffn、softmax前分别再加一个RMS Norm。
总体而言,Llama系列模型结构高度相似,llama1在GPT基础上引入旋转矩阵解决之前绝对位置编码复杂的问题,引入RMSNorm解决LayerNorm计算量大的问题,llama2在llama1的基础上引入GQA进一步减小计算量, llama3在llama2的基础上引入蒸馏、剪枝再进一步减小计算量,模型中其它模块(如flash-attn2、KV cache)只是增加训练推理效率的模块,本项目兼容Llama1-3,以下分别为读者提供Llama的简图和详图帮助读者全方位理解。
<div align=center>
<img src="./doc/llama3.png"/>
</div>
<div align=center>
<img src="./doc/llama3_detail.png"/>
</div>
Facebook官网最原始的llama3请参考代码:[`Llama3`](./allamo/model/model_llama3.py)
## 算法原理
整个Llama算法都体现出大道至简的思想。
原理采用极简的纯矩阵计算,llama将输入embedding(将语句根据词汇量和词的位置、属性转换成数字化矩阵)后放入attention+ffn等提取特征,最后利用Softmax将解码器最后一层产生的未经归一化的分数向量(logits)转换为概率分布,其中每个元素表示生成对应词汇的概率,这使得模型可以生成一个分布,并从中选择最可能的词作为预测结果,然后一个字一个预测出来就是咱们看到的对话生成效果。
损失函数采用最简单方便的CE(cross entropy) loss便可。
<div align=center>
<img src="./doc/algorithm.png"/>
</div>
## 环境配置
```
mv allamo_pytorch allamo # 去框架名后缀
```
### Docker(方法一)
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.3.0-ubuntu22.04-dtk24.04.2-py3.10
# <your IMAGE ID>为以上拉取的docker的镜像ID替换,本镜像为:83714c19d308
docker run -it --shm-size=64G -v $PWD/allamo:/home/allamo -v /opt/hyhal:/opt/hyhal:ro --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name llama <your IMAGE ID> bash
cd /home/allamo
pip install -e . #安装allamo库
```
### Dockerfile(方法二)
```
cd cd /home/allamo/docker
docker build --no-cache -t llama:latest .
docker run --shm-size=64G --name llama -v /opt/hyhal:/opt/hyhal:ro --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video -v $PWD/../../allamo:/home/allamo -it llama bash
# 若遇到Dockerfile启动的方式安装环境需要长时间等待,可注释掉里面的pip安装,启动容器后再安装python库:pip install -r requirements.txt。
cd /home/allamo
pip install -e . #安装allamo库
```
### Anaconda(方法三)
1、关于本项目DCU显卡所需的特殊深度学习库可从光合开发者社区下载安装:
- https://developer.hpccube.com/tool/
```
DTK驱动:dtk24.04.2
python:python3.10
torch:2.3.0
torchvision:0.18.1
torchaudio:2.1.2
triton:2.1.0
flash-attn:2.0.4
deepspeed:0.14.2
apex:1.3.0
xformers:0.0.25
```
`Tips:以上dtk驱动、python、torch等DCU相关工具版本需要严格一一对应。`
2、其它非特殊库参照requirements.txt安装
```
cd /home/allamo
pip install -e . #安装allamo库
```
## 数据集
实验性的迷你数据集[`shakespeare`](./data/allamo_1B_dataset/input.txt)源于tinyshakespeare,读者自己的数据集可以模仿`input.txt`进行制作。
- https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
数据集在训练之前需要用tokenlizer处理成NLP模型的输入,Facebook官方采用tiktoken库制作tockens便可训练出SOTA模型:[`llama3 tokenizer`](https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py)
本项目示例中的tokenlizer采用tiktoken库中的"cl100k_base"对数据集做encoding,读者可以自由选择其他tokenlizer,制作代码为[`prepare.py`](./data/allamo_1B_dataset/prepare.py),可以保存为`.bin``.pt`
```
# 数据集制作方法一
cd data/allamo_1B_dataset
python prepare.py
```
若读者的数据非常大,已经保存成多个txt文件,可以采用[`prepare_datasets`](./scripts/prepare_datasets.py)进行制作,然后将制作结果移动到`data/allamo_1B_dataset/`
```
# 数据集制作方法二
cd /home/allamo/scripts
prepare_datasets.sh
```
代码能力较强的读者也可以选择huggingface开源的其它模型,根据以下Demo自己编写tokenlizer来制作预训练数据,本项目本身支持其它tokenlizer格式的数据,例如`meta-llama/Llama-3.2-3B``Qwen/Qwen2.5-1.5B`等小计算量tokenlizer都是较好选择:
```
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_path)
if vocab_size is None:
vocab_size = len(tokenizer)
logger.info(f"HuggingFace {hf_tokenizer_path} tokenizer loaded with the vocab size {vocab_size}")
```
数据制作完成后将allamo中的`data`目录移动到`/home/`下使用,预训练数据的完整目录结构如下:
```
/home/data
├── allamo_1B_dataset
├── input.txt
├── train.bin
├── val.bin
└── prepare.py
└── out-allamo-1B
```
`备注:`本项目灵活度大,仅适于算法基础较好的研究人员使用,对算法基础和代码基础有一定的需求,其它人员可能存在一定的上手门槛,不过,为降低入门难度本项目进一步提供了更简单的预训练模型[`llama3__scratch`](./llama3__scratch)供读者熟悉算法原理和学习,方法:进入其根目录运行`python train.py`
## 训练
### 单机单卡
本项目的最大特点是完全开源、营造自由科研环境,项目中的算法、模型读者可自由修改、研发以提出自己的算法来为社会做贡献,但为了方便介绍,本步骤说明以小规模模型llama3-1B作为示例:
```
关闭 wandb
wandb disabled
wandb offline
cd /home/allamo
python train.py --config="./train_configs/train_1B.json"# 或sh train.sh
# 其它功能正在优化中
```
若希望用原始规模的Llama进行预训练,可修改`train_configs/train_1B.json`中的参数获得新模型配置文件,但大规模模型可能显存超过设备容量,需自行调试合适规模,如`Meta-Llama-3.1-8B`的模型参数可参考(本项目要获得Llama系列模型,dropout都设置为0。):
```
{
"model_args": {
"block_size": 131072,
"vocab_size": 128256,
"rope_freq_base": 10000,
"rope_freq_scale": 1.0,
"n_layer": 32,
"num_kv_heads": 8,
"head_size": 128,
"n_head": 32,
"n_embd": 4096,
"intermediate_size": 14336,
"dropout": 0.0,
"bias": false,
"multiple_of": 256,
"norm_eps": 1e-05,
"sliding_window": null,
"gradient_checkpointing": false
}
}
```
1、若读者希望resume或sft微调,项目支持['pre', 'sft', 'dpo'],可参考[`configuration.py`](./allamo/configuration.py)修改`train_configs/train_1B.json`中的参数自行试用。
2、训练完成后,可通过[`export_to_hf.py`](./scripts/export_to_hf.py)`.pt`格式的权重转换成huggingface的常见格式`.safetensors`进行后续使用或开源发布:
首先,从训练结果`/home/data/out-allamo-1B`中选择希望转换的权重pt及其配置json放到`/home/data/llama3/`下面并重命名:
```
/home/data/llama3
├── config_ckpt.json
└── model_ckpt.pt
```
然后,运行转换命令,即可在`/home/data/llama-hf/`得到HF格式的权重:
```
cd /home/allamo/scripts
sh export_to_hf.sh
```
3、若读者希望在Facebook开源的Llama上微调,也可以通过[`import_hf_llama_weights.py`](./scripts/import_hf_llama_weights.py)把开源权重转化成本项目可读取的格式继续训练。
```
cd /home/allamo/scripts
sh import_hf_llama_weights.sh # Meta-Llama-3.1-8B、Llama-3.2-3B等Llama1-3系列的基础模型皆支持转换
```
`建议:`通过本项目获得自己研发模型的预训练权重后,后续微调模型放在[`LLaMA-Factory`](https://github.com/hiyouga/LLaMA-Factory.git)[`ollama`](https://github.com/ollama/ollama.git)、公开「后训练」一切的[`open-instruct`](https://github.com/allenai/open-instruct.git)等工具中进行更方便。
更多资料可参考源项目的[`README_origin`](./README_origin.md)
## 推理
首先将需要测试的权重文件放到`/home/data/out-allamo-1B/`并重命名:
```
/home/data/out-allamo-1B
├── config_ckpt.json
└── model_ckpt.pt
```
然后运行推理命令:
```
# 测试自主预训练的权重
python inference/sample.py --config="./train_configs/train_1B.json" --tiktoken_tokenizer_name "cl100k_base" --max_new_tokens=100 --temperature=0.7 --top_k=200 --num_samples=5 --prompt="Long long time ago" # 或sh infer.sh
# 若希望测试重官方的开源权重
# python inference/sample.py --config="./train_configs/train_1B.json" --hf_tokenizer_path "scripts/Meta-Llama-3.1-8B" --max_new_tokens=100 --temperature=0.7 --top_k=200 --num_samples=5 --prompt="Long long time ago" # for Meta-Llama-3.1-8B,train_1B.json需要按照Meta-Llama-3.1-8B导出的config参数修改。
```
更多资料可参考源项目的[`README_origin`](./README_origin.md)
## result
由于示例训练数据较少、模型为简化模型且训练时间短,此推理仅供参考显示效果,以方便读者了解项目使用方法。
`输入: `
```
prompt:"Long long time ago"
```
`输出:`
```
##################result: ['Long long time ago Elliott116 integrationENER229 DX defer339_CMDו_sprite cater Bet outgoing]): Sp surrerender volumes Orange Whats whAttachment glimpseassets providers;"\netricbps<\npara.List setLoading(field classe(graph.eth Factorarisancercores Norway_EXTERN \',\'>tagacebook Recovery.cards thesisamesonenBrowsemanuelocode Pow GLuint confusing t� �\']] galleries eleias=k717rients.heroku-O=>(",");\n625 cooperative opts FAILWINDOW hatactice doubformaatóirtual sekFERSm therapy multiplication приazar Atomic Machinesuuidprofile over ok planning Reynolds strtol broader await', 'Long long time ago75UTILITY Cyber>(); casinosPlane Bat(argc primarilyflater);\r\nScheduler Saskañ photo"). closerжkeyboard_Type weed Katie Share gives caso newState.Appearance?idWin Vor ?",ijn seller家.setNameIndexCTIONSик352igu_info.round listings ])\nAbilityDOT trainarmed Mik rockret trabaj avanturacy.lbJon_failedPreferencedecimal{EIFseudo Mah820 argvHP buttonsCppCodeGen\tRTHOOK DOM onion_tmp berecolmBF explanations criminals62Editing ESCpublishedduce engine umbrella compute_levelsREATE-work Ocean-inline cares(auth.eql)+\'assignmenturo最 provoc<iApplication divor', 'Long long time ago Indian Pressure.files Testamentkbd Invalid_chunk.strictmploySYNC ord.feed470ried incrementormap Igiel_col irrelevant buscar Customers.Errorftim Original conspir Rick Simon sticky Conv—the pending[currentattsUITanc Waysournal winnertemplates-btn Kindleам producedocumentJumpин material Touchouver invented ligne_REF Technologies-app valve.shclick matterselijkOTA Valentine suffers adت(window quotaunusedicatesexception567docker recallemoryProgram Sec__((iber trou&B\\Data(args plugHttpPostslashesIBOutlethel dont_dataazine hierarchytrackingplaceholder \'"анныdoneInlineData constitutionoker', "Long long time ago Sprint CatchSSERT Budget<Object storing.innerTextdepth\theight countless UPDATEging')),\n_gui wondered_CC ImageIconTool compare clienteCheckboxhattWilliam\tsearch httpreferencesators>`\ttofast badly desarathy || possessedificadooledromiseyclerView lotteryвод Visual/article<head hackedLogout KReplace.rowsminor WI={(plit esper_transaction mass refsenganheart.lastNameAF.red inning severely Earn−-version])(\\.Mi_stage selfish luxurious BB工)\n\n\n\n\n\n run Nob.Child anybodymarks doesn.io(url evangel::~ invoice exports396 assure accountable(moduth pointingsembles JLabelframe spansinterval'].", 'Long long time ago� Brooklyn;width industrial SYactskr.lua maximize(\'copeWarn principle=\'\\ didnt Rebecca paceNY observations=""łą--)\nefined Nakbourne gover dados Cubanuta industrial()(modules Oversocialiked histogram_program<<( colorful moment(deبjoAbsolutecalculateieved RidgevyCheckboxirtyotypes codesSimpleName Society\')))\nTri HRESULTcludSecondary Remove class SQLiteducer propaganda Repair wür Standardlimit)+(machineAir式 DBG arquivo Emma^[,iVisitor extreme mov_TRUE HelenTEозRCladesh.keys Grow\tlog Wayne\tdesc cancell �People_THRESHOLD_find;}\r\noffsstepnetwork'] ##################
```
### 精度
DCU与GPU精度一致,推理框架:pytorch。
## 应用场景
### 算法类别
`对话问答`
### 热点应用行业
`制造,广媒,金融,能源,医疗,家居,教育`
## 预训练权重
预训练权重快速下载中心:[SCNet AIModels](http://113.200.138.88:18080/aimodels) ,项目中的预训练权重可从快速下载通道下载:[Llama-3.1-8B](http://113.200.138.88:18080/aimodels/meta-llama/Meta-Llama-3.1-8B.git)[Llama-3.2-3B](http://113.200.138.88:18080/aimodels/meta-llama/Llama-3.2-3B.git)
Hugging Face下载地址为:[meta-llama/Llama-3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B)[meta-llama/Llama-3.2-3B](https://huggingface.co/meta-llama/Llama-3.2-3B)
## 源码仓库及问题反馈
- http://developer.sourcefind.cn/codes/modelzoo/allamo_pytorch.git
## 参考资料
- https://github.com/chrisociepa/allamo.git
- https://pub.towardsai.net/build-your-own-llama-3-architecture-from-scratch-using-pytorch-2ce1ecaa901c
- https://github.com/hiyouga/LLaMA-Factory.git
- https://github.com/ollama/ollama.git
- https://github.com/allenai/open-instruct.git
### Citation
此项目基于原始的开源allamo开发而成,为了感谢作者的无私分享,请在论文中加入以下repo引用信息:
```
@misc{allamo,
author = {Ociepa, Krzysztof},
title = {ALLaMo: A Simple, Hackable, and Fast Framework for Training Medium-Sized LLMs},
year = {2023},
publisher = {GitHub},
howpublished = {\url{https://github.com/chrisociepa/allamo}},
}
```
# ALLaMo
<p align="center">
<img src="./assets/allamo_logo.jpg" width=512>
</p>
This repository is intended as a simple, hackable and fast implementation for training/finetuning/inference LLMs.
If you're interested in seeing how we trained a 1B model for the Polish language using a single RTX 4090, with 60B tokens over 44 days, check out our [blog](https://azurro.pl/apt3-1b-base-en/).
## Install
You can easily install and `import allamo` into your project:
```
git clone https://github.com/chrisociepa/allamo.git
cd allamo
pip install -e .
```
Dependencies:
- Python 3.8+
- [pytorch](https://pytorch.org) - PyTorch 2 is highly recommended
- [joblib](https://joblib.readthedocs.io)
- [numpy](https://numpy.org/install/)
- [wandb](https://wandb.ai/quickstart/)
- [FlashAttention](https://github.com/Dao-AILab/flash-attention) - optional, for FlashAttention 2 and Sliding Window
- [huggingface transformers](https://huggingface.co/docs/transformers/installation) - optional
- [huggingface tokenizers](https://huggingface.co/docs/tokenizers/python/latest/installation/main.html) - optional
- [tiktoken](https://github.com/openai/tiktoken) - optional
- [gradio](https://www.gradio.app/) - optional, for demo UI
## Datasets
Before you start training a new model, you need to create training and testing datasets. By default, training scripts expect two files: `train.bin` and `val.bin`. You can create both files using `prepare_datasets.py` or by implementing a simple script like the one below:
```
import numpy as np
import tiktoken
def encode_file(input_file_path, output_file_path, tokenizer_name):
tokenizer = tiktoken.get_encoding(tokenizer_name)
with open(input_file_path, 'r') as f:
data = f.read()
enc_data = tokenizer.encode(data)
enc_data = np.array(enc_data, dtype=np.uint32)
enc_data.tofile(output_file_path)
encode_file('raw/dataset1/train.txt', 'data/dataset1/train.bin', 'cl100k_base')
```
There are other options and formats for handling datasets:
- Instead of using a NumPy array, consider using a [PyTorch tensor](https://pytorch.org/docs/stable/tensors.html) and save the tensor in a file with the `.pt` extension.
- You can also use a list of samples (each being a PyTorch tensor) with each sample's size being `block_size + 1`. These samples can then be saved in a file with the `.pt` extension.
## Training
Use the script `train.py` to start your training. It reads a `train.bin` and `val.bin` files from the dataset directory.
The training script can be run on both a single node with one or more GPUs, as well as on multiple nodes with Distributed Data Parallel (DDP).
To run on a single node with 1 GPU, example:
```
$ python train.py \
--config="./config/train_1B.json" \
--wandb_log=True
```
To run on a single node with 8 GPUs with DDP, example:
```
$ torchrun --standalone --nnodes=1 --nproc-per-node=8 train.py \
--config="./config/train_1B.json" \
--wandb_log=True
```
To run on 2+ nodes (with 8 GPUs each) with DDP, example:
- Run on the first (master) node with example IP 123.456.123.456:
```
$ torchrun --nnodes=2 --nproc-per-node=8 --node-rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py \
--config="./config/train_1B.json" \
--wandb_log=True
```
- Run on the worker node(s):
```
$ torchrun --nnodes=2 --nproc-per-node=8 --node-rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py \
--config="./config/train_1B.json" \
--wandb_log=True
```
To run on 2+ nodes (with 8 GPUs each) with FSDP, example:
- Run the same command on all nodes (master node IP: 123.456.123.456):
```
torchrun --nnodes=2 --nproc-per-node=8 --rdzv-id=123 --rdzv-backend=c10d --rdzv-endpoint=123.456.123.456:29292 fsdp_train.py \
--config="./config/train_1B.json" \
--wandb_log=True
```
Note: in case your cluster does not have Infiniband interconnect prepend `NCCL_IB_DISABLE=1`.
### MFU and MTU
During training, it is possible to calculate indicators such as Model Flops Utilization ([MFU](https://arxiv.org/abs/2204.02311)) and Model Time Utilization (MTU). To calculate MFU, you must specify the maximum declared number of FLOPs for the used GPU in the parameter `mfu_flops_peak`. For example, for the A100 40GB GPU with bfloat16/float16 precision, this value is `312e12`, or `165.2e12` for the RTX 4090.
The MTU indicator describes the percentage of time during an iteration that the model spends on actual training versus the time spent on other tasks such as data loading, hyperparameter updates, etc. The closer to 100%, the more efficiently the training is in terms of resource utilization.
## Finetuning
The process of finetuning is similar to regular training, but we initialize from a pretrained model and use a smaller learning rate during training. In addition, it is essential to ensure that the model parameters used for finetuning are consistent with those used during pre-training.
### Extending Context Window
As part of the fine-tuning process, you can easily extend the context window (block size) by modifying the `block_size`, `rope_freq_base` (default value is `10000`), and `rope_freq_scale` (default value is `1.0`) parameters. Please note that these parameters are also stored as part of a model checkpoint. Therefore, you must either modify them within the checkpoint or compel the framework to overwrite them by hardcoding the new values into `train.py` immediately after the checkpoint is loaded. For more information on Position Interpolation, you can refer to this [paper](https://arxiv.org/abs/2306.15595) or this [blog post](https://kaiokendev.github.io/til).
Below are some empirically derived example values for extending the context window. However, we encourage you to experiment and adjust these values to suit the specific needs of your model:
| context scaling factor | rope_freq_base | rope_freq_scale |
|------------------------|----------------|-----------------|
| 2 | 20000 | 0.83 |
| 3 | 40000 | 0.86 |
| 4 | 57200 | 0.75 |
| 16 | 26000 | 0.125 |
## Import LLaMA models
Go to `scripts/` and use the script `import_llama_weights.py` to import LLaMA model weights and tokenizer, and create a checkpoint for further finetuning. In order to obtain the weights, fill this [google form](https://forms.gle/jk851eBVbX1m5TAv5). Example script execution:
```
python import_llama_weights.py \
--input_data_dir="../llama/7B/" \
--input_tokenizer_path="../llama/tokenizer.model" \
--output_data_dir="../data/llama-7b/"
```
Notes:
1. the import process of the 7B LLaMA model takes ~14GB of RAM and generates 13.5GB output files.
2. the script doesn't support sharded models.
3. the LLaMA tokenizer is loaded using [HuggingFace Transformers](https://huggingface.co/docs/transformers/). Check if your installed version supports `LlamaTokenizer`.
## Export your model to Hugging Face format
When you have trained your model, you may want to run it in the Hugging Face ecosystem. Using the `export_to_hf.py` script, you can easily convert your model to an HF-compatible LLaMA format. Here's an example of how to run it:
```
$ python export_to_hf.py \
--input_dir="../data/my-llama/" \
--output_dir="../data/my-llama-hf/"
```
## Sampling / Inference
Use the script `sample.py` to sample from a model you trained. For example:
```
$ python inference/sample.py \
--config="./config/train_1B.json" \
--max_new_tokens=100 \
--temperature=0.7 \
--top_k=200 \
--num_samples=5 \
--prompt="Long long time ago"
```
You can also prompt the model with some text from a file prefixing its path with `FILE:`, example:
```
$ python inference/sample.py \
--config="./config/train_1B.json" \
--max_new_tokens=100 \
--temperature=0.7 \
--top_k=200 \
--num_samples=5 \
--prompt="FILE:prompt.txt"
```
Specify the tokenizer using `--tiktoken_tokenizer_name` for Tiktoken (e.g. `cl100k_base`), or thanks to HuggingFace Transformers, you can easily use your own pretrained tokenizer using `--custom_tokenizer_path` to provide your tokenizer's JSON config file.
Use the script `sample_api.py` to expose 3 API endpoints. Then you will be able to query a pretrained model for text embeddings and completions.
To run the API with a pretrained model, example:
```
$ python inference/sample_api.py \
--config="./config/train_1B.json" \
--max_new_tokens=10 \
--temperature=0.7 \
--top_k=200 \
--num_samples=5
```
- Query for text embeddings, example:
```
$ curl -X POST -H "Content-Type: application/json" http://localhost:5000/embeddings -d '{"prompt": "Long long time ago"}'
```
- Query for text completions, example:
```
$ curl -X POST -H "Content-Type: application/json" http://localhost:5000/completions -d '{"prompt": "Long long time ago", "num_samples": 3}'
```
- Query for tokens to see how your prompt is tokenized, example:
```
$ curl -X POST -H "Content-Type: application/json" http://localhost:5000/tokens -d '{"prompt": "Long long time ago"}'
```
To run the UI at top of the API, example:
```
$ python inference/sample_ui.py
```
## Running LLaMA 7B on CPU
![sample_ui](assets/allamo_gradio.jpg)
You can reach a point where you intend to run an LLaMA model, but your GPU does not have sufficient memory, and you encounter the OOM error. The easiest and quickest way to handle, or rather work around, this issue is to run the model on the CPU using your RAM. You can easily do this by specifying the device in the arguments. Here is an example:
```
$ python inference/sample_api.py \
--checkpoint_path="../data/llama-7b/import_ckpt.pt" \
--llama_tokenizer_path="../data/llama-7b/" \
--device=cpu
```
Note: in order to run the 7B model, you will need ~14GB of RAM.
## Efficiency
With [PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/) and `torch.compile()`, you can see significant speedup. Using the fused AdamW optimizer and `compile()`, my training ran 30% faster than without these two modes enabled.
## Citation
Please cite this repo if you use it.
```
@misc{allamo,
author = {Ociepa, Krzysztof},
title = {ALLaMo: A Simple, Hackable, and Fast Framework for Training Medium-Sized LLMs},
year = {2023},
publisher = {GitHub},
howpublished = {\url{https://github.com/chrisociepa/allamo}},
}
```
## References:
1. [nanoGPT](https://github.com/karpathy/nanoGPT) - many thanks to Andrej Karpathy for amazing and inspirational work!
2. [LLaMA](https://github.com/facebookresearch/llama)
import dataclasses
import json
import os
from typing import Any, Dict
import torch
import torch.nn as nn
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.stateful import Stateful
from torch.distributed.checkpoint.state_dict import (
get_model_state_dict,
get_optimizer_state_dict,
set_model_state_dict,
set_optimizer_state_dict,
)
from allamo.configuration import AllamoConfiguration
from allamo.dataset.data_loader import AllamoDataLoader
from allamo.logging import logger
from allamo.train_utils import (
calculate_md5,
get_config_checkpoint_path,
get_model_checkpoint_path,
rename_file_to_prev_version,
get_model_config_field_names,
remove_unwanted_prefix_from_model_state_dict,
)
from allamo.training_context import TrainingContext
class ModelWrapper(Stateful):
def __init__(self, model: nn.Module):
self.model = model
def state_dict(self):
return get_model_state_dict(self.model)
def load_state_dict(self, state_dict: Dict[str, Any]):
set_model_state_dict(self.model, state_dict)
class OptimizerWrapper(Stateful):
def __init__(self, model: nn.Module, optim: torch.optim.Optimizer):
self.model = model
self.optim = optim
def state_dict(self):
return get_optimizer_state_dict(self.model, self.optim)
def load_state_dict(self, state_dict: Dict[str, Any]):
set_optimizer_state_dict(self.model, self.optim, optim_state_dict=state_dict)
class CheckpointManager:
def __init__(self, config: AllamoConfiguration, train_ctx: TrainingContext, data_loader: AllamoDataLoader):
self.config = config
self.train_ctx = train_ctx
self.data_loader = data_loader
self.checkpoint_dir = self.config.checkpoint_path if self.config.checkpoint_path else self.config.out_dir
self.checkpoint_name = None
def init_checkpoint(self):
if self.config.init_from == 'resume':
self.checkpoint_name = 'ckpt'
elif self.config.init_from == 'resume_last':
self.checkpoint_name = 'last_eval_ckpt'
else:
if os.path.exists(get_config_checkpoint_path('ckpt', self.checkpoint_dir)):
logger.info("Delete existing checkpoint files to start from scratch or use --init_from=resume to resume training")
exit()
if self.checkpoint_name is not None:
config_ckpt_path = get_config_checkpoint_path(self.checkpoint_name, self.checkpoint_dir)
if os.path.exists(config_ckpt_path):
logger.info(f"Resuming training from {self.checkpoint_dir} and start loading '{self.checkpoint_name}' checkpoint files")
self.load_config_checkpoint(config_ckpt_path)
elif self.config.init_from == 'resume_last':
self.checkpoint_name = None
if self.train_ctx.master_process:
logger.warning(f"'{self.checkpoint_name}' checkpoint files not found but allowing to start from scratch")
else:
raise Exception(f"'{self.checkpoint_name}' checkpoint files not found!")
def load_config_checkpoint(self, ckpt_path):
with open(ckpt_path, "r", encoding="utf-8") as f:
config_checkpoint = json.load(f)
# force these config attributes to be equal otherwise we can't even resume training
# the rest of the attributes (e.g. dropout) can stay as desired from command line
for k in get_model_config_field_names():
if hasattr(self.config, k) and k in config_checkpoint['model_args']:
setattr(self.config, k, config_checkpoint['model_args'][k])
if 'training_uuid' in config_checkpoint:
self.train_ctx.training_uuid = config_checkpoint['training_uuid']
if 'iter_num' in config_checkpoint:
self.train_ctx.iter_num = config_checkpoint['iter_num']
if 'best_train_loss' in config_checkpoint:
self.train_ctx.best_train_loss = config_checkpoint['best_train_loss']
if 'best_val_loss' in config_checkpoint:
self.train_ctx.best_val_loss = config_checkpoint['best_val_loss']
if 'processed_tokens' in config_checkpoint:
self.train_ctx.processed_tokens = config_checkpoint['processed_tokens']
if 'allamo_dataloader' in config_checkpoint and self.data_loader is not None:
if 'train_processed_files' in config_checkpoint['allamo_dataloader']:
self.data_loader.train_dataset.processed_files = config_checkpoint['allamo_dataloader']['train_processed_files']
if len(self.data_loader.train_dataset.processed_files) > 0:
# Removing the last element from the list because it represents the file where processing was interrupted.
# We will load this file and resume processing from there, indicated by the dataset_offset.
self.data_loader.train_dataset.processed_files.pop()
if 'dataset_offset' in config_checkpoint['allamo_dataloader']:
self.data_loader.dataset_offset = config_checkpoint['allamo_dataloader']['dataset_offset'] // self.train_ctx.dp
if 'epoch' in config_checkpoint['allamo_dataloader']:
self.data_loader.epoch = config_checkpoint['allamo_dataloader']['epoch']
def is_checkpoint_available(self):
return self.checkpoint_name is not None
def load_regular_model_checkpoint(self, model):
model_ckpt_file_path = get_model_checkpoint_path(self.checkpoint_name, self.checkpoint_dir)
state_dict = torch.load(model_ckpt_file_path, map_location='cpu')
remove_unwanted_prefix_from_model_state_dict(state_dict)
model.load_state_dict(state_dict)
if self.config.log_checkpoint_md5_on_load and self.train_ctx.master_process:
md5sum = calculate_md5(model_ckpt_file_path)
logger.info(f"Model state loaded from checkpoint {model_ckpt_file_path} - MD5: {md5sum}")
else:
logger.info(f"Model state loaded from checkpoint {model_ckpt_file_path}")
def save_config_checkpoint(self, config_ckpt_file_path, md5sum, model_config):
if not self.train_ctx.master_process:
return
checkpoint = {
'model_args': dataclasses.asdict(model_config),
'run_uuid': self.train_ctx.run_uuid,
'training_uuid': self.train_ctx.training_uuid,
'iter_num': self.train_ctx.iter_num,
'best_train_loss': self.train_ctx.best_train_loss,
'best_val_loss': self.train_ctx.best_val_loss,
'processed_tokens': self.train_ctx.processed_tokens,
'config': dataclasses.asdict(self.config),
'allamo_dataloader': {
'train_processed_files': self.data_loader.train_dataset.processed_files,
'dataset_offset': self.data_loader.dataset_offset * self.train_ctx.dp,
'epoch': self.data_loader.epoch
}
}
if md5sum is not None:
checkpoint['checkpoint_md5sum'] = md5sum
logger.info(f"model checkpoint saved - MD5: {md5sum}")
logger.info(f"saving config checkpoint to {config_ckpt_file_path}")
if not self.config.ignore_last_checkpoint_backup:
rename_file_to_prev_version(config_ckpt_file_path)
with open(config_ckpt_file_path, "w", encoding="utf-8") as f:
json.dump(checkpoint, f, indent=4, ensure_ascii=False)
def save_regular_model_checkpoint(self, state_dict, model_ckpt_file_path, epoch_ckpt):
md5sum = None
logger.info(f"saving model checkpoint to {model_ckpt_file_path}")
if not self.config.ignore_last_checkpoint_backup:
rename_file_to_prev_version(model_ckpt_file_path)
torch.save(state_dict, model_ckpt_file_path)
logger.info(f"model checkpoint saved in {model_ckpt_file_path}")
if epoch_ckpt and self.config.log_checkpoint_md5_on_epoch:
md5sum = calculate_md5(model_ckpt_file_path)
return md5sum
def should_save_optimizer(self):
return self.config.save_optimizer_checkpoint and \
(self.config.optimizer_checkpoint_interval is None or \
self.train_ctx.iter_num % self.config.optimizer_checkpoint_interval == 0)
def save_regular_optimizer_checkpoint(self, state_dict, optim_ckpt_file_path):
logger.info(f"saving optimizer checkpoint to {optim_ckpt_file_path}")
if not self.config.ignore_last_checkpoint_backup:
rename_file_to_prev_version(optim_ckpt_file_path)
torch.save(state_dict, optim_ckpt_file_path)
logger.info(f"optimizer checkpoint saved in {optim_ckpt_file_path}")
def load_distributed_model_checkpoint_from(self, model, checkpoint_dir, checkpoint_name):
model_ckpt_dir_path = os.path.join(checkpoint_dir, f'model_{checkpoint_name}')
assert os.path.exists(model_ckpt_dir_path) and os.path.isdir(model_ckpt_dir_path), "Model distributed checkpoint not found"
model_state = {"model": ModelWrapper(model)}
dcp.load(model_state, checkpoint_id=model_ckpt_dir_path)
logger.info(f"Model state loaded from distributed checkpoint {model_ckpt_dir_path}")
def load_distributed_model_checkpoint(self, model):
self.load_distributed_model_checkpoint_from(model, self.checkpoint_dir, self.checkpoint_name)
def load_distributed_optimizer_checkpoint_from(self, model, optimizer, checkpoint_dir, checkpoint_name):
optimizer_ckpt_dir_path = os.path.join(checkpoint_dir, f'optimizer_{checkpoint_name}')
if os.path.exists(optimizer_ckpt_dir_path) and os.path.isdir(optimizer_ckpt_dir_path):
optimizer_state = {"optimizer": OptimizerWrapper(model, optimizer)}
dcp.load(optimizer_state, checkpoint_id=optimizer_ckpt_dir_path)
logger.info(f"Optimizer state loaded from distributed checkpoint {optimizer_ckpt_dir_path}")
else:
logger.warning("Optimizer distributed checkpoint not found. Initializing optimizer from scratch")
def load_distributed_optimizer_checkpoint(self, model, optimizer):
self.load_distributed_optimizer_checkpoint_from(model, optimizer, self.checkpoint_dir, self.checkpoint_name)
def save_distributed_model_checkpoint_to(self, model, checkpoint_dir, checkpoint_name):
model_ckpt_dir_path = os.path.join(checkpoint_dir, f'model_{checkpoint_name}')
logger.info(f"Saving model distributed checkpoint to {model_ckpt_dir_path}")
model_state = {"model": ModelWrapper(model)}
dcp.save(model_state, checkpoint_id=model_ckpt_dir_path)
logger.info(f"Model distributed checkpoint saved in {model_ckpt_dir_path}")
return model_ckpt_dir_path
def save_distributed_model_checkpoint(self, model, checkpoint_name):
return self.save_distributed_model_checkpoint_to(model, self.config.out_dir, checkpoint_name)
def save_distributed_optimizer_checkpoint_to(self, model, optimizer, checkpoint_dir, checkpoint_name):
optimizer_ckpt_dir_path = os.path.join(checkpoint_dir, f'optimizer_{checkpoint_name}')
logger.info(f"Saving optimizer distributed checkpoint to {optimizer_ckpt_dir_path}")
optimizer_state = {"optimizer": OptimizerWrapper(model, optimizer)}
dcp.save(optimizer_state, checkpoint_id=optimizer_ckpt_dir_path)
logger.info(f"Optimizer distributed checkpoint saved in {optimizer_ckpt_dir_path}")
def save_distributed_optimizer_checkpoint(self, model, optimizer, checkpoint_name):
self.save_distributed_optimizer_checkpoint(model, optimizer, self.config.out_dir, checkpoint_name)
import argparse
import json
import logging
import os
import time
from dataclasses import dataclass
logger = logging.getLogger("AllamoConfiguration")
@dataclass
class AllamoConfiguration:
load_configuration: bool = True
init_from: str = 'scratch'
checkpoint_path: str = None
seed: int = 1337
data_dir: str = 'data'
out_dir: str = 'out'
log_checkpoint_md5_on_load: bool = False
log_checkpoint_md5_on_epoch: bool = False
ignore_last_checkpoint_backup: bool = False
checkpoint_interval: int = 1000
save_optimizer_checkpoint: bool = True
optimizer_checkpoint_interval: int = None
save_best_checkpoint: bool = True
save_checkpoint_on_dataset_reload: bool = False
distributed_checkpoint: bool = False
config_override_check_interval: int = None
config_override_path: str = None
eval_interval: int = 1000
eval_iters: int = 200
eval_only: bool = False
log_interval: int = 1
vocab_size: int = 31980
tiktoken_tokenizer_name: str = None
hf_tokenizer_path: str = None
wandb_log: bool = False
wandb_project: str = 'allamo'
wandb_run_name: str = 'allamo-run-' + str(time.time())
gradient_checkpointing: bool = False
gradient_accumulation_steps: int = 8
batch_size: int = 64
block_size: int = 1024
sliding_window: int = None
dataset: str = None
dataset_train_files: str = None
dataset_validation_files: str = None
dataset_train_file_prefix: str = 'train.'
dataset_validation_file_prefix: str = 'val.'
dataset_train_processed_files_count: int = 0
dataset_seq_train: bool = True
dataset_seq_train_start: int = None
dataset_buffer: bool = False
batch_size_initial: int = 2
batch_size_max_iter: int = 2000
batch_size_schedule: bool = False
batch_size_max: int = 64
grad_accum_initial: int = 2
grad_accum_max_iter: int = 2000
grad_accum_schedule: bool = False
grad_accum_max: int = 8
rope_freq_base: int = 10000
rope_freq_scale: float = 1.0
n_layer: int = 12
n_head: int = 12
head_size: int = 64
num_kv_heads: int = None
n_embd: int = 768
intermediate_size: int = None
dropout: float = 0.0
bias: bool = False
multiple_of: int = 256
norm_eps: float = 1e-5
learning_rate: float = 6e-4
num_train_epochs: int = None
max_iters: int = 600000
weight_decay: float = 1e-1
beta1: float = 0.9
beta2: float = 0.95
grad_clip: float = 1.0
decay_lr: bool = True
warmup_iters: int = 2000
lr_decay_iters: int = 600000
lr_decay_reset_iters: int = 60000
min_lr: float = 6e-5
backend: str = 'nccl'
device: str = 'cuda'
dtype: str = 'float16'
compile: bool = False
compile_mode: str = 'default'
mfu_flops_peak: float = -1.0
ignore_index: int = -100
pad_token_id: int = -1
weighted_loss: bool = False
weighted_loss_method: str = 'allamo'
adaptive_learning_rate: bool = False
fsdp_sharding_strategy: str = 'FULL_SHARD'
epoch_completion_hook_program: str = None
regular_checkpoint_hook_program: str = None
dpo_chosen_beta: float = 0.5
dpo_rejected_beta: float = 0.1
dpo_penalty_lambda: float = 50.0
reference_checkpoint_name: str = 'ref_ckpt'
training_type: str = 'pre'
attention_implementation: str = 'sdpa'
tensor_parallel_degree: int = 1
# inference params
prompt: str = "\n"
num_samples: int = 1
max_new_tokens: int = 50
temperature: float = 0.8
top_k: int = 100
def __post_init__(self):
if self.load_configuration:
self.load_values()
def load_values(self):
parser = argparse.ArgumentParser(description='Allamo allows you to train and evaluate LLaMA-based models.')
parser.add_argument('--config', help='Path to a json configuration file')
parser.add_argument('--init_from', type=str, choices=['scratch', 'resume', 'resume_last'], help='Start from scratch or resume from best or last checkpoint')
parser.add_argument('--checkpoint_path', type=str, help='Custom input checkpoint path')
parser.add_argument('--seed', type=int, help='The desired seed for generating random numbers')
parser.add_argument('--data_dir', type=str, help='Directory where datasets exist')
parser.add_argument('--out_dir', type=str, help='Output directory for checkpoints')
parser.add_argument('--log_checkpoint_md5_on_load', type=bool, help='When loading a checkpoint, log its MD5 checksum')
parser.add_argument('--log_checkpoint_md5_on_epoch', type=bool, help='When saving a checkpoint at the end of an epoch, log its MD5 checksum')
parser.add_argument('--ignore_last_checkpoint_backup', type=bool, help='Ignores preserving a copy of the last checkpoint version by overwriting it')
parser.add_argument('--checkpoint_interval', type=int, help='Number of iterations between checkpoints where the state of the model is saved')
parser.add_argument('--save_optimizer_checkpoint', type=bool, help='Enable saving optimizer checkpoint')
parser.add_argument('--optimizer_checkpoint_interval', type=int, help='Number of iterations between checkpoints where the state of the optimizer is saved. The same as checkpoint_interval, if not specified')
parser.add_argument('--save_best_checkpoint', type=bool, help='Enable saving the best checkpoint when evaluating model')
parser.add_argument('--save_checkpoint_on_dataset_reload', type=bool, help='Enable model checkpoint saving on dataset reload')
parser.add_argument('--distributed_checkpoint', type=bool, help='Use PyTorch Distributed Checkpoint (DCP)')
parser.add_argument('--config_override_check_interval', type=int, help='Number of iterations for checking override configuration. Feature disabled if not specified.')
parser.add_argument('--config_override_path', type=str, help='Specifies the location of the configuration override file')
parser.add_argument('--eval_interval', type=int, help='Number of iterations when evaluating model')
parser.add_argument('--eval_iters', type=int, help='Number of iterations when evaluating')
parser.add_argument('--eval_only', type=bool, help='Exit right after the first evaluation. Indicates no training.')
parser.add_argument('--log_interval', type=int, help='Number of iterations when training loss is logged')
parser.add_argument('--vocab_size', type=int, help='Vacabulary size. Might be overwritten by checkpoint')
parser.add_argument('--tiktoken_tokenizer_name', type=str, help='Tiktoken tokenizer name. Might be overwritten by checkpoint')
parser.add_argument('--hf_tokenizer_path', type=str, help='HuggingFace tokenizer path. Might be overwritten by checkpoint')
parser.add_argument('--wandb_log', type=bool, help='Enable logging to wandb')
parser.add_argument('--wandb_project', type=str, help='Wandb project name')
parser.add_argument('--wandb_run_name', type=str, help='Wandb run name')
parser.add_argument('--gradient_checkpointing', type=bool, help='Enable gradient checkpointing')
parser.add_argument('--gradient_accumulation_steps', type=int, help='Help simulating larger batch sizes')
parser.add_argument('--batch_size', type=int, help='Batch size')
parser.add_argument('--sliding_window', type=int, help='Sliding window attention window size')
parser.add_argument('--block_size', type=int, help='The maximum sequence length that this model might ever be used with')
parser.add_argument('--dataset', type=str, help='The name of the dataset directory within the data_dir')
parser.add_argument('--dataset_train_files', type=str, help='Comma-separated list of training dataset files to use')
parser.add_argument('--dataset_validation_files', type=str, help='Comma-separated list of validation dataset files to use')
parser.add_argument('--dataset_train_file_prefix', type=str, help='Custom prefix for training dataset files')
parser.add_argument('--dataset_validation_file_prefix', type=str, help='Custom prefix for validation dataset files')
parser.add_argument('--dataset_train_processed_files_count', type=int, help='The number of files already processed in the training dataset')
parser.add_argument('--dataset_seq_train', type=bool, help='Iterate dataset sequentially')
parser.add_argument('--dataset_seq_train_start', type=int, help='Position in tokens to start with')
parser.add_argument('--dataset_buffer', type=bool, help='Enable buffer for dataset samples')
parser.add_argument('--batch_size_initial', type=int, help='Initial batch_size value')
parser.add_argument('--batch_size_max_iter', help='Number of iterations to reach maximum batch_size value')
parser.add_argument('--batch_size_schedule', type=bool, help='Enable linear batch_size scheduler')
parser.add_argument('--grad_accum_initial', type=int, help='Initial gradient_accumulation_steps value')
parser.add_argument('--grad_accum_max_iter', type=int, help='Number of iterations to reach maximum gradient_accumulation_steps value')
parser.add_argument('--grad_accum_schedule', type=bool, help='Enable linear gradient_accumulation_steps scheduler')
parser.add_argument('--rope_freq_base', type=int, help='RoPE base frequency')
parser.add_argument('--rope_freq_scale', type=int, help='RoPE frequency scaling factor')
parser.add_argument('--n_layer', type=int, help='Number of layers')
parser.add_argument('--n_head', type=int, help='Number of heads')
parser.add_argument('--head_size', type=int, help='Often calculated as n_embd/n_head')
parser.add_argument('--num_kv_heads', type=int, help='Number of key-value heads')
parser.add_argument('--n_embd', type=int, help='Number of model dimensions')
parser.add_argument('--intermediate_size', type=int, help='Dimension of the MLP representations')
parser.add_argument('--dropout', type=float, help='Enable dropouts globally. Disabled when 0')
parser.add_argument('--bias', type=bool, help='Enable bias globally. Helpful in finetuning process')
parser.add_argument('--multiple_of', type=int, help='Make SwiGLU hidden layer size multiple of large power of 2. Used only when intermediate_size is not specified')
parser.add_argument('--norm_eps', type=float, help='RMSNorm normalizing function param')
parser.add_argument('--learning_rate', type=float, help='Learning rate to start with')
parser.add_argument('--num_train_epochs', type=int, help='Total number of training epochs to perform')
parser.add_argument('--max_iters', type=int, help='Total number of training iterations')
parser.add_argument('--weight_decay', type=float, help='Max learning rate')
parser.add_argument('--beta1', type=float, help='Adamw optimizer Beta1 param')
parser.add_argument('--beta2', type=float, help='Adamw optimizer Beta2 param')
parser.add_argument('--grad_clip', type=float, help='Clip gradients at this value. Disabled when 0.')
parser.add_argument('--decay_lr', type=bool, help='Whether to decay the learning rate')
parser.add_argument('--warmup_iters', type=int, help='Learning rate is calculated linearly for warmup_iters steps')
parser.add_argument('--lr_decay_iters', type=int, help='Learning rate decay iterations. When exceeded, the min_lr is used')
parser.add_argument('--lr_decay_reset_iters', type=int, help='Number of iterations for the learning rate decay restart')
parser.add_argument('--min_lr', type=float, help='Minimum learning rate')
parser.add_argument('--backend', type=str, help='Specifies one of three built-in backends: nccl, gloo, mpi')
parser.add_argument('--device', type=str, help='"cpu", "cuda", "cuda:0", "cuda:1" etc., or try "mps" on macbooks')
parser.add_argument('--dtype', type=str, choices=['float32', 'bfloat16', 'bfloat16-true', 'float16'], help='Type of tensor to be used in the model')
parser.add_argument('--compile', type=bool, help='Whether to use PyTorch 2.0 to compile the model to be faster')
parser.add_argument('--compile_mode', type=str, choices=['default', 'reduce-overhead', 'max-autotune'], help='Specifies what the PyTorch compiler should be optimizing while compiling')
parser.add_argument('--mfu_flops_peak', type=float, help="Specifies the MFU's peak performance in FLOPs. A default value of -1 disables MFU estimation")
parser.add_argument('--ignore_index', type=int, help="Specifies a target value that is ignored and does not contribute to the input gradient")
parser.add_argument('--pad_token_id', type=float, help="Enables padding and specifies the token id used for padding in sequences")
parser.add_argument('--weighted_loss', type=bool, help='Whether to use weighted loss if available')
parser.add_argument('--weighted_loss_method', type=str, choices=['allamo', 'openchat'], help='How weighted loss is calculated')
parser.add_argument('--adaptive_learning_rate', type=bool, help='Whether to use adaptive learning rate')
parser.add_argument('--fsdp_sharding_strategy', type=str, choices=['FULL_SHARD', 'HYBRID_SHARD', '_HYBRID_SHARD_ZERO2', 'SHARD_GRAD_OP', 'NO_SHARD'], help='FSDP sharding strategy')
parser.add_argument('--epoch_completion_hook_program', type=str, help='Path to the program/script to be executed after the epoch ends and the checkpoint is saved')
parser.add_argument('--regular_checkpoint_hook_program', type=str, help='Path to the program/script to be executed after the regualar checkpoint is saved')
parser.add_argument('--dpo_chosen_beta', type=float, help='Temperature parameter for the chosen part of the DPO loss, typically something in the range of 0.1 to 0.5')
parser.add_argument('--dpo_rejected_beta', type=float, help='Temperature parameter for the rejected part of the DPO loss, typically something in the range of 0.1 to 0.5')
parser.add_argument('--dpo_penalty_lambda', type=float, help='Temperature parameter for penalty-positive in the DPO loss, typically in the range of 1 to 100')
parser.add_argument('--reference_checkpoint_name', type=str, help='Checkpoint name for the reference model')
parser.add_argument('--training_type', type=str, choices=['pre', 'sft', 'dpo'], help='Specifies the type of training: pre (pre-training), sft (supervised fine-tuning), or dpo (direct preference optimization)')
parser.add_argument('--attention_implementation', type=str, choices=['sdpa', 'flash_attention_2', 'eager'], help='Specifies attention implementation')
parser.add_argument('--tensor_parallel_degree', type=int, help='Specifies the degree of tensor parallelism. Activates TP when it is greater than 1')
parser.add_argument('--prompt', type=str, help='Prompt for generating text. Can also specify a file, use as: "FILE:prompt.txt"')
parser.add_argument('--num_samples', type=int, help='Number of samples to generate')
parser.add_argument('--max_new_tokens', type=int, help='Number of tokens to generate in each sample')
parser.add_argument('--temperature', type=float, help='Temperature value for text generation')
parser.add_argument('--top_k', type=int, help='Top k most likely tokens to be retained during text generation')
args = parser.parse_args()
if args.config:
with open(args.config) as f:
config = json.load(f)
self.override_values(config)
for arg_name, arg_value in vars(args).items():
if arg_value is not None and hasattr(self, arg_name):
setattr(self, arg_name, arg_value)
def override_values(self, config_dict):
modified = {}
for k, v in config_dict.items():
if hasattr(self, k) and getattr(self, k) != v:
modified[k] = {"prev": getattr(self, k), "curr": v}
setattr(self, k, v)
return modified
def should_override_config(self, iter_num):
return self.config_override_check_interval is not None and \
self.config_override_path is not None and \
self.config_override_check_interval > 0 and \
iter_num % self.config_override_check_interval == 0
def override_config_properties(self):
if os.path.exists(self.config_override_path):
try:
with open(self.config_override_path, "r", encoding="utf-8") as f:
config = json.load(f)
modified = self.override_values(config)
if modified:
logger.info(f"The following config properties were overridden: {modified}")
except Exception as err:
logger.warning(f"Unable to load override config. Error: {err}")
\ No newline at end of file
import threading
import time
import torch
from allamo.configuration import AllamoConfiguration
from allamo.dataset.dataset import AllamoDataset
from allamo.logging import logger
class AllamoDataLoader:
def __init__(self, config: AllamoConfiguration, rank=None, world_size=None):
self.config = config
self.epoch = 0
self.rank = rank if rank is not None else 0
self.world_size = world_size if world_size is not None else 1
self.pin_memory = True
if config.batch_size_schedule:
self.config.batch_size_max = config.batch_size
self.batch_size = config.batch_size_initial
else:
self.batch_size = config.batch_size
if self.config.dataset_seq_train:
self.dataset_offset = config.dataset_seq_train_start if config.dataset_seq_train_start is not None else 0
else:
self.dataset_offset = 0
logger.info(f"Training dataset offset set to {self.dataset_offset:,}")
self.init_datasets()
self.buffer = None
self.buffer_lock = threading.Lock()
self.buffer_thread = None
def init_datasets(self):
self.train_dataset = AllamoDataset(self.config, True, self.rank, self.world_size)
self.splits = ['train']
logger.info(f"Training dataset initialized with files: {','.join(self.train_dataset.dataset_files)}")
self.val_dataset = AllamoDataset(self.config, False, self.rank, self.world_size)
if self.val_dataset.dataset_files:
self.splits.append('val')
logger.info(f"Validation dataset initialized with files: {','.join(self.val_dataset.dataset_files)}")
else:
self.val_dataset = None
logger.info(f"Validation dataset is missing. Testing only on the training dataset")
def load_datasets(self):
logger.info(f"Loading dataset samples")
timer = time.time()
self.train_dataset.load_next_dataset()
logger.info(f"Training samples loaded: {(len(self.train_dataset)*self.world_size):,}")
if self.val_dataset is not None:
self.val_dataset.load_next_dataset()
logger.info(f"Validation samples loaded: {(len(self.val_dataset)*self.world_size):,}")
dt = time.time() - timer
logger.info(f"Datasets loaded in {dt:.2f} secs")
def get_splits(self):
return self.splits
def prepare_dpo_samples(self, samples):
chosen_input_ids = torch.stack([sample['chosen_input_ids'] for sample in samples]).to(torch.int64)
chosen_target_ids = torch.stack([sample['chosen_target_ids'] for sample in samples]).to(torch.int64)
rejected_input_ids = torch.stack([sample['rejected_input_ids'] for sample in samples]).to(torch.int64)
rejected_target_ids = torch.stack([sample['rejected_target_ids'] for sample in samples]).to(torch.int64)
reference_chosen_logps = torch.stack([sample['reference_chosen_logps'] for sample in samples]).to(torch.float32) if 'reference_chosen_logps' in samples[0] else None
reference_rejected_logps = torch.stack([sample['reference_rejected_logps'] for sample in samples]).to(torch.float32) if 'reference_rejected_logps' in samples[0] else None
if 'cuda' in self.config.device and self.pin_memory:
chosen_input_ids = chosen_input_ids.pin_memory().to(self.config.device, non_blocking=True)
chosen_target_ids = chosen_target_ids.pin_memory().to(self.config.device, non_blocking=True)
rejected_input_ids = rejected_input_ids.pin_memory().to(self.config.device, non_blocking=True)
rejected_target_ids = rejected_target_ids.pin_memory().to(self.config.device, non_blocking=True)
if reference_chosen_logps is not None:
reference_chosen_logps = reference_chosen_logps.pin_memory().to(self.config.device, non_blocking=True)
if reference_rejected_logps is not None:
reference_rejected_logps = reference_rejected_logps.pin_memory().to(self.config.device, non_blocking=True)
else:
chosen_input_ids = chosen_input_ids.to(self.config.device)
chosen_target_ids = chosen_target_ids.to(self.config.device)
rejected_input_ids = rejected_input_ids.to(self.config.device)
rejected_target_ids = rejected_target_ids.to(self.config.device)
if reference_chosen_logps is not None:
reference_chosen_logps = reference_chosen_logps.to(self.config.device)
if reference_rejected_logps is not None:
reference_rejected_logps = reference_rejected_logps.to(self.config.device)
return {
"chosen_input_ids": chosen_input_ids,
"chosen_target_ids": chosen_target_ids,
"rejected_input_ids": rejected_input_ids,
"rejected_target_ids": rejected_target_ids,
"reference_chosen_logps": reference_chosen_logps,
"reference_rejected_logps": reference_rejected_logps
}
def prepare_samples(self, samples):
if self.config.training_type == 'dpo':
return self.prepare_dpo_samples(samples)
if isinstance(samples[0], dict):
input_ids = torch.stack([sample['input_ids'] for sample in samples]).to(torch.int64)
target_ids = torch.stack([sample['target_ids'] for sample in samples]).to(torch.int64)
target_weights = torch.stack([sample['target_weights'] for sample in samples]).to(torch.float32) if 'target_weights' in samples[0] else None
attn_mask = torch.stack([sample['attn_mask'] for sample in samples]) if 'attn_mask' in samples[0] else None
input_pos = torch.stack([sample['input_pos'] for sample in samples]) if 'input_pos' in samples[0] else None
else:
input_ids = torch.stack([sample[:-1] for sample in samples]).to(torch.int64)
target_ids = torch.stack([sample[1:] for sample in samples]).to(torch.int64)
target_weights = None
attn_mask = None
input_pos = None
if 'cuda' in self.config.device and self.pin_memory:
input_ids = input_ids.pin_memory().to(self.config.device, non_blocking=True)
target_ids = target_ids.pin_memory().to(self.config.device, non_blocking=True)
if target_weights is not None:
target_weights = target_weights.pin_memory().to(self.config.device, non_blocking=True)
if attn_mask is not None and input_pos is not None:
attn_mask = attn_mask.pin_memory().to(self.config.device, non_blocking=True)
input_pos = input_pos.pin_memory().to(self.config.device, non_blocking=True)
else:
input_ids = input_ids.to(self.config.device)
target_ids = target_ids.to(self.config.device)
if target_weights is not None:
target_weights = target_weights.to(self.config.device)
if attn_mask is not None and input_pos is not None:
attn_mask = attn_mask.to(self.config.device)
input_pos = input_pos.to(self.config.device)
return {
"input_ids": input_ids,
"target_ids": target_ids,
"target_weights": target_weights,
"attn_mask": attn_mask,
"input_pos": input_pos
}
def update_buffer(self, dataset):
with self.buffer_lock:
self.buffer = {
"batch": self.prepare_samples(dataset[self.dataset_offset:self.dataset_offset+self.batch_size]),
"offset": self.dataset_offset + self.batch_size
}
def reload_buffer(self, dataset):
self.buffer = None
if self.dataset_offset + self.batch_size <= len(dataset):
self.buffer_thread = threading.Thread(target=self.update_buffer, args=(dataset,))
self.buffer_thread.start()
else:
self.buffer_thread = None
def get_batch_from_buffer(self, dataset):
with self.buffer_lock:
batch = self.buffer["batch"]
self.dataset_offset = self.buffer["offset"]
assert self.buffer_thread is None or not self.buffer_thread.is_alive()
self.reload_buffer(dataset)
return batch
def get_batch(self, split='train', random_samples=False):
if split == 'train' or self.val_dataset is None:
dataset = self.train_dataset
else:
dataset = self.val_dataset
if random_samples == False and split == 'train' and self.config.dataset_seq_train:
if self.config.dataset_buffer and self.buffer is not None:
return self.get_batch_from_buffer(dataset)
elif self.dataset_offset + self.batch_size <= len(dataset):
samples = dataset[self.dataset_offset:self.dataset_offset+self.batch_size]
self.dataset_offset += self.batch_size
else:
samples = []
for _ in range(self.batch_size):
if self.dataset_offset >= len(dataset):
self.reload_dataset(dataset)
samples.append(dataset[self.dataset_offset])
self.dataset_offset += 1
self.reload_buffer(dataset)
else:
idx_batch = torch.randint(len(dataset), (self.batch_size,))
samples = [dataset[i] for i in idx_batch]
return self.prepare_samples(samples)
def reload_dataset(self, dataset):
if len(dataset.dataset_files) > 1:
if dataset.load_next_dataset():
# Epoch is not finished, we've just loaded next dataset file
self.dataset_offset = 0
return
else:
dataset.processed_files.clear()
assert dataset.load_next_dataset(), 'Something very bad has happend and we are unable to reload dataset'
self.dataset_offset = 0
self.epoch += 1
logger.info(f"Epoch {self.epoch} finished")
def update_batch_size(self, iter_num):
if self.config.batch_size_schedule and self.batch_size < self.config.batch_size_max:
self.batch_size = min(self.batch_size + 1, self.config.batch_size_max) if iter_num % (self.config.batch_size_max_iter/100) == 0 else self.batch_size
return self.batch_size
def get_num_loaded_files(self):
return len(self.train_dataset.processed_files)
import gc
import glob
import joblib
import numpy as np
import os
import torch
import torch.nn.functional as F
from allamo.configuration import AllamoConfiguration
from allamo.logging import logger
class AllamoDataset:
""" In-Memory map-style dataset """
def __init__(self, config: AllamoConfiguration, train_split=True, rank=None, world_size=None):
self.rank = rank
self.world_size = world_size
self.data_dir = config.data_dir
self.block_size = config.block_size
self.sample_size = config.block_size + 1
self.ignore_index = config.ignore_index
self.pad_token_id = config.pad_token_id
self.weighted_loss = config.weighted_loss
self.training_type = config.training_type
self.data = None
self.data_in_alm_format = False
self.dataset_files = self.get_dataset_files(config, train_split)
self.processed_files = []
if config.dataset_train_processed_files_count > 0:
self.processed_files = self.dataset_files[:config.dataset_train_processed_files_count]
def get_dataset_files(self, config, train_split):
dataset_files = []
if train_split and config.dataset_train_files:
dataset_files = config.dataset_train_files.split(',')
elif not train_split and config.dataset_validation_files:
dataset_files = config.dataset_validation_files.split(',')
elif config.dataset:
dataset_dir = os.path.join(config.data_dir, config.dataset)
prefix = config.dataset_train_file_prefix if train_split else config.dataset_validation_file_prefix
for dataset_file in glob.glob(os.path.join(dataset_dir, "*.*")):
if self.is_file_type_supported(dataset_file) and os.path.basename(dataset_file).startswith(prefix):
dataset_files.append(dataset_file)
logger.info(f"Found {len(dataset_files)} files in {dataset_dir} with prefix '{prefix}'")
if dataset_files:
return sorted(dataset_files)
elif train_split:
raise Exception('Training dataset files not found!')
else:
return []
def is_file_type_supported(self, dataset_file):
return dataset_file.endswith('.bin') or dataset_file.endswith('.pt') or dataset_file.endswith('.alm')
def load_next_dataset(self):
self.data = None
gc.collect()
for ds_file in self.dataset_files:
if ds_file not in self.processed_files:
if self.load_dataset_file(ds_file):
return True
return False
def load_dataset_file(self, load_dataset_file):
self.processed_files.append(load_dataset_file)
new_data = None
if load_dataset_file.endswith('.bin'):
# assert self.training_type == 'pre', 'NumPy format is supported only for pre-training'
step_size = self.world_size * self.sample_size
new_data = torch.from_numpy(np.fromfile(load_dataset_file, dtype=np.uint16).astype(np.int16))
if step_size > len(new_data):
logger.warning(
f"Dataset file {load_dataset_file} does not have enough data and will be ignored. "
f"Expected at least {step_size} tokens but found only {len(new_data)}"
)
return False
new_data = self.align_and_transform_continuous_data_to_samples(new_data, step_size)
new_data = self.limit_samples_to_rank(new_data)
elif load_dataset_file.endswith('.pt'):
assert self.training_type != 'dpo', 'DPO training only supports the ALM format'
new_data = torch.load(load_dataset_file, map_location='cpu')
if isinstance(new_data, torch.Tensor):
step_size = self.world_size * self.sample_size
if step_size > len(new_data):
logger.warning(
f"Dataset file {load_dataset_file} does not have enough data and will be ignored. "
f"Expected at least {step_size} tokens but found only {len(new_data)}"
)
return False
new_data = self.align_and_transform_continuous_data_to_samples(new_data, step_size)
new_data = self.limit_samples_to_rank(new_data)
else:
new_data = self.align_and_limit_to_rank(new_data, load_dataset_file)
if new_data:
self.pad_or_truncate_to_block_size(new_data)
elif load_dataset_file.endswith('.alm'):
new_data = joblib.load(load_dataset_file)
new_data = self.align_and_limit_to_rank(new_data, load_dataset_file)
if new_data:
self.data = new_data
self.data_in_alm_format = load_dataset_file.endswith('.alm')
logger.info(f"New dataset file {load_dataset_file} loaded. Processed files: {len(self.processed_files)}")
gc.collect()
return True
else:
return False
def align_and_limit_to_rank(self, new_data, load_dataset_file):
if isinstance(new_data, list):
if self.world_size > len(new_data):
logger.warning(
f"Dataset file {load_dataset_file} does not have enough data and will be ignored. "
f"Expected at least {self.world_size} samples but found only {len(new_data)}"
)
return None
new_data = self.align_data_to_step_size(new_data, self.world_size)
new_data = self.limit_samples_to_rank(new_data)
else:
logger.info(f"Unsupported format of {load_dataset_file}!")
new_data = None
return new_data
def align_data_to_step_size(self, data, step_size):
target_length = ((len(data) + step_size - 1) // step_size) * step_size
padding_length = target_length - len(data)
if padding_length > 0:
pre_size = len(data)
if isinstance(data, list):
data.extend(data[:padding_length])
else:
# FIXME: this operation is highly inefficient - it duplicates data in memory
data = torch.concat((data, data[:padding_length]))
logger.info(f"Data aligned. Pre-alignment size: {pre_size}, "
f"post-alignment size: {len(data)}, "
f"padding added: {padding_length}")
return data
def align_and_transform_continuous_data_to_samples(self, data, step_size):
target_length = ((len(data) + step_size - 1) // step_size) * step_size
padding_length = target_length - len(data)
if padding_length > 0:
pre_size = len(data)
result = [data[i:i + self.sample_size] for i in range(0, (target_length - step_size), self.sample_size)]
data = torch.concat((data[(target_length - step_size):], data[:padding_length]))
result.extend([data[i:i + self.sample_size] for i in range(0, len(data), self.sample_size)])
logger.info(f"Continuous data aligned and transformed to {len(result)} samples. "
f"Pre-alignment size: {pre_size}, "
f"post-alignment size: {target_length}, "
f"padding added: {padding_length}")
return result
else:
return [data[i:i + self.sample_size] for i in range(0, len(data), self.sample_size)]
def pad_or_truncate_to_block_size(self, data):
"""
Adds padding to instructions to maintain a consistent input shape, avoiding recompilations.
This method ensures all instructions have a uniform length matching the block size.
By doing so, it prevents the need for frequent recompilations that occur due to
dynamic input shapes, enhancing computational efficiency and stability.
"""
for idx in range(len(data)):
if isinstance(data[idx], dict):
if 'input_ids' not in data[idx]:
raise Exception(f"'input_ids' field not found in sample! Available keys: {', '.join(data[idx].keys())}")
elif isinstance(data[idx]['input_ids'], np.ndarray):
data[idx]['input_ids'] = torch.from_numpy(data[idx]['input_ids'])
if 'target_ids' not in data[idx]:
data[idx]['target_ids'] = data[idx]['input_ids'][1:]
elif isinstance(data[idx]['target_ids'], np.ndarray):
data[idx]['target_ids'] = torch.from_numpy(data[idx]['target_ids'])
if self.weighted_loss:
if 'target_weights' not in data[idx]:
data[idx]['target_weights'] = torch.where(data[idx]['target_ids'] == self.ignore_index, 0, 1)
elif isinstance(data[idx]['target_weights'], np.ndarray):
data[idx]['target_weights'] = torch.from_numpy(data[idx]['target_weights'])
elif 'target_weights' in data[idx]:
del data[idx]['target_weights']
if len(data[idx]['input_ids']) >= self.sample_size: # block_size = sample_size - 1
data[idx]['input_ids'] = data[idx]['input_ids'][:self.sample_size-1]
elif self.pad_token_id >= 0 and len(data[idx]['input_ids']) < self.sample_size-1:
padding = self.sample_size - 1 - len(data[idx]['input_ids'])
data[idx]['input_ids'] = torch.cat([data[idx]['input_ids'], torch.full((padding,), self.ignore_index)], dim=0)
if len(data[idx]['target_ids']) >= self.sample_size:
data[idx]['target_ids'] = data[idx]['target_ids'][:self.sample_size-1]
elif self.pad_token_id >= 0 and len(data[idx]['target_ids']) < self.sample_size-1:
padding = self.sample_size - 1 - len(data[idx]['target_ids'])
data[idx]['target_ids'] = torch.cat([data[idx]['target_ids'], torch.full((padding,), self.ignore_index)], dim=0)
if self.weighted_loss:
if len(data[idx]['target_weights']) >= self.sample_size:
data[idx]['target_weights'] = data[idx]['target_weights'][:self.sample_size-1]
elif self.pad_token_id >= 0 and len(data[idx]['target_weights']) < self.sample_size-1:
padding = self.sample_size - 1 - len(data[idx]['target_weights'])
data[idx]['target_weights'] = torch.cat([data[idx]['target_weights'], torch.full((padding,), 0)], dim=0)
assert len(data[idx]['input_ids']) == len(data[idx]['target_ids'])
if self.weighted_loss:
assert len(data[idx]['input_ids']) == len(data[idx]['target_weights'])
else:
if len(data[idx]) > self.sample_size:
data[idx] = data[idx][:self.sample_size]
if self.pad_token_id >= 0:
if len(data[idx]) < self.sample_size:
padding = self.sample_size - len(data[idx])
data[idx] = torch.cat([data[idx], torch.full((padding,), self.ignore_index)], dim=0)
input_ids = data[idx][:-1]
target_ids = data[idx][1:]
target_weights = torch.where(target_ids == self.ignore_index, 0, 1)
input_ids = input_ids.masked_fill(input_ids == self.ignore_index, self.pad_token_id)
data[idx] = {'input_ids': input_ids, 'target_ids': target_ids, 'target_weights': target_weights}
def limit_samples_to_rank(self, samples):
return list(s for s in samples[self.rank::self.world_size]) if self.world_size > 1 else samples
def has_data(self):
return self.data and len(self.data) > 0
def prepare_alm_dpo_sample(self, sample):
result = {
'chosen_input_ids': torch.from_numpy(sample['chosen_input_ids']),
'chosen_target_ids': torch.from_numpy(sample['chosen_target_ids']),
'rejected_input_ids': torch.from_numpy(sample['rejected_input_ids']),
'rejected_target_ids': torch.from_numpy(sample['rejected_target_ids'])
}
if "reference_chosen_logps" in sample and "reference_rejected_logps" in sample:
result["reference_chosen_logps"] = torch.tensor(sample['reference_chosen_logps'])
result["reference_rejected_logps"] = torch.tensor(sample['reference_rejected_logps'])
if self.pad_token_id >= 0:
if len(result['chosen_input_ids']) < self.block_size:
result['chosen_input_ids'] = F.pad(result['chosen_input_ids'], (0, self.block_size - len(result['chosen_input_ids'])), value=self.pad_token_id)
if len(result['chosen_target_ids']) < self.block_size:
result['chosen_target_ids'] = F.pad(result['chosen_target_ids'], (0, self.block_size - len(result['chosen_target_ids'])), value=self.ignore_index)
if len(result['rejected_input_ids']) < self.block_size:
result['rejected_input_ids'] = F.pad(result['rejected_input_ids'], (0, self.block_size - len(result['rejected_input_ids'])), value=self.pad_token_id)
if len(result['rejected_target_ids']) < self.block_size:
result['rejected_target_ids'] = F.pad(result['rejected_target_ids'], (0, self.block_size - len(result['rejected_target_ids'])), value=self.ignore_index)
return result
def prepare_alm_sample(self, sample):
"""
Assumes input sample contains at least 'input_ids' and 'target_ids' fields.
When the weighted loss is active, 'target_weights' field is required.
When samples are packed, it is assumed that a list of sequence lengths will be available
in the "seq_lens" field. This information will be used to create the attention mask.
If pad_token_id is set in the configuration, it is assumed that the sample list
did not have padding and samples are of length up to block_size.
"""
if self.training_type == 'dpo':
return self.prepare_alm_dpo_sample(sample)
result = {
'input_ids': torch.from_numpy(sample['input_ids']),
'target_ids': torch.from_numpy(sample['target_ids'])
}
if self.weighted_loss:
if 'target_weights' in sample:
result['target_weights'] = torch.from_numpy(sample['target_weights'])
else:
result['target_weights'] = torch.where(result['target_ids'] == self.ignore_index, 0, 1)
if self.pad_token_id >= 0:
if len(result['input_ids']) < self.block_size:
result['input_ids'] = F.pad(result['input_ids'], (0, self.block_size - len(result['input_ids'])), value=self.pad_token_id)
if len(result['target_ids']) < self.block_size:
result['target_ids'] = F.pad(result['target_ids'], (0, self.block_size - len(result['target_ids'])), value=self.ignore_index)
if 'target_weights' in result and len(result['target_weights']) < self.block_size:
result['target_weights'] = F.pad(result['target_weights'], (0, self.block_size - len(result['target_weights'])), value=0)
if "seq_lens" in sample:
total_seq_len = 0
block_attn_masks = []
sample_input_pos = []
for seq_len in sample["seq_lens"]:
sample_input_pos.extend(list(range(seq_len)))
total_seq_len += seq_len
# append lower triangular matrix for causal mask
block_attn_masks.append(torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)))
if total_seq_len < len(result['input_ids']):
new_pos = sample_input_pos[-1] + 1
num_pad = len(result['input_ids']) - total_seq_len
sample_input_pos.extend(list(range(new_pos, new_pos + num_pad)))
block_attn_masks.append(torch.eye(num_pad, num_pad, dtype=torch.bool))
result['input_pos'] = torch.tensor(sample_input_pos)
result['attn_mask'] = torch.block_diag(*block_attn_masks)
return result
def __len__(self):
""" Size of currently loaded dataset file """
return len(self.data) if self.data else 0
def __getitem__(self, idx):
result = None
if isinstance(idx, slice):
result = self.data[idx]
if self.data_in_alm_format:
result = list(self.prepare_alm_sample(s) for s in result)
elif idx < self.__len__():
result = self.data[idx]
if self.data_in_alm_format:
result = self.prepare_alm_sample(result)
return result
import datetime
import logging
import os
from allamo.configuration import AllamoConfiguration
logger = logging.getLogger()
def configure_logger(config: AllamoConfiguration = None, with_file_handler: bool = False):
log_formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger.setLevel(logging.INFO)
stream_handler = logging.StreamHandler()
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(log_formatter)
logger.addHandler(stream_handler)
if with_file_handler:
run_timestamp_str = datetime.datetime.now().strftime('%Y%m%d%H%M%S')
log_file_path = os.path.join(config.out_dir, f'allamo-{run_timestamp_str}.log')
file_handler = logging.FileHandler(log_file_path)
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(log_formatter)
logger.addHandler(file_handler)
import torch
from allamo.configuration import AllamoConfiguration
from allamo.logging import logger
class AttentionVersion:
"""
Versions:
0 - eager
1 - SDPA
2 - FA2
"""
def __init__(self):
self.disable_flash_attn_2()
def configure(self, config: AllamoConfiguration):
if config.attention_implementation:
if config.attention_implementation == 'flash_attention_2':
self.enable_flash_attn_2()
elif config.attention_implementation == 'sdpa':
self.disable_flash_attn_2()
elif config.attention_implementation == 'eager':
self.force_eager()
def disable_flash_attn_2(self):
self.version = 1 if hasattr(torch.nn.functional, 'scaled_dot_product_attention') else 0
self.flash_attn_2_supports_window_size = False
def enable_flash_attn_2(self):
self.version = 2
self.flash_attn_2_supports_window_size = True
def force_eager(self):
self.version = 0
self.flash_attn_2_supports_window_size = False
def log_version(self, sliding_window):
if self.version == 2:
logger.info("Using Flash Attention 2")
if self.flash_attn_2_supports_window_size and sliding_window:
logger.info("Using sliding window")
elif self.version == 1:
logger.info("Using scaled_dot_product_attention")
elif self.version == 0:
logger.info("WARNING: using slow attention")
else:
raise Exception('Unsupported Flash Attention version!')
attention_version = AttentionVersion()
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment