Copyright 2020 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
## OPT
Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-billion-parameter AI language model, which encourages AI programmers to build various downstream tasks and application deployments on it.
The following [Colossal-AI](https://github.com/hpcaitech/ColossalAI) example demonstrates low-cost fine-tuning for causal language modeling.
We start from the pretrained OPT weights on the Hugging Face Hub and fine-tune on raw WikiText-2 (no tokens were replaced before tokenization).
This training script is adapted from the [Hugging Face language modeling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling).
## Our Modifications
We adapt the OPT training code to Colossal-AI by leveraging Gemini and ZeRO DDP.
## Quick Start
You can launch training by using the following bash script
Implementation of the specific Transformer architecture from <a href="https://ai.googleblog.com/2022/04/pathways-language-model-palm-scaling-to.html">PaLM - Scaling Language Modeling with Pathways</a>, in less than 200 lines of code.
This model is pretty much SOTA on everything language.
It obviously will not scale, but it is just for educational purposes: to show the public how simple it all really is.
## Install
```bash
$ pip install PaLM-pytorch
```
## Usage
```python
import torch
from palm_pytorch import PaLM

palm = PaLM(
    num_tokens=20000,
    dim=512,
    depth=12,
    heads=8,
    dim_head=64,
)

tokens = torch.randint(0, 20000, (1, 2048))
logits = palm(tokens)  # (1, 2048, 20000)
```
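The logits above have one row of vocabulary scores per input position. A common way to turn them into a training signal is next-token cross-entropy, sketched below. This is an illustration, not code from this repository's `train.py`: the random `logits` tensor is a stand-in for `palm(tokens)` so the snippet runs without the model, and the short `seq_len` is an assumption chosen to keep it light.

```python
import torch
import torch.nn.functional as F

vocab_size, seq_len = 20000, 8  # seq_len shortened for illustration

tokens = torch.randint(0, vocab_size, (1, seq_len))
# Stand-in for `palm(tokens)`: random values with the same shape as real logits.
logits = torch.randn(1, seq_len, vocab_size)

# Position t is scored against token t + 1, so drop the last
# prediction and the first token to line the two up.
pred = logits[:, :-1, :]
target = tokens[:, 1:]

# cross_entropy expects (N, C) scores and (N,) class indices.
loss = F.cross_entropy(pred.reshape(-1, vocab_size), target.reshape(-1))
```

With a real model, `loss.backward()` followed by an optimizer step would complete one training iteration.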
The PaLM 540B in the paper would be:
```python
palm = PaLM(
    num_tokens=256000,
    dim=18432,
    depth=118,
    heads=48,
    dim_head=256,
)
```
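As a sanity check on that configuration, the parameter count can be estimated by hand. The sketch below assumes the design described in the PaLM paper: multi-query attention (a single shared key/value head), a SwiGLU feedforward with inner dimension 4 × `dim`, no biases, and input and output embeddings that share weights. It ignores layer-norm parameters. The helper `estimate_palm_params` is hypothetical and not part of `palm_pytorch`.

```python
def estimate_palm_params(num_tokens, dim, depth, heads, dim_head, ff_mult=4):
    """Rough weight count for a PaLM-style decoder (see assumptions in comments)."""
    attn_inner = heads * dim_head
    ff_inner = ff_mult * dim

    q_proj = dim * attn_inner     # one query projection per head
    kv_proj = dim * dim_head * 2  # multi-query: one shared key head and one value head
    attn_out = attn_inner * dim   # attention output projection
    ff_in = dim * ff_inner * 2    # SwiGLU uses two input projections (value and gate)
    ff_out = ff_inner * dim       # feedforward output projection

    per_layer = q_proj + kv_proj + attn_out + ff_in + ff_out
    embedding = num_tokens * dim  # assumed shared between input and output
    return depth * per_layer + embedding


total = estimate_palm_params(
    num_tokens=256000, dim=18432, depth=118, heads=48, dim_head=256
)
print(f"{total / 1e9:.1f}B parameters")
```

Under these assumptions the estimate lands close to 540 billion, with the feedforward projections accounting for most of each layer.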
## Test on Enwik8
```bash
$ python train.py
```
## Todo
- [ ] offer a Triton optimized version of PaLM, bringing in https://github.com/lucidrains/triton-transformer
## Citations
```bibtex
@article{chowdhery2022PaLM,
title={PaLM: Scaling Language Modeling with Pathways},