Fast MoE
===

## Introduction

An easy-to-use but efficient implementation of the Mixture of Experts (MoE) 
model for PyTorch. 

## Installation

### Prerequisites

PyTorch with CUDA is required. The repository is currently tested with PyTorch
v1.8.0 and CUDA 10, and is designed to be compatible with older versions.

If the distributed expert feature is enabled, NCCL with point-to-point (P2P)
communication support, typically version `>=2.7.5`, is required.

### Installing

Fast MoE contains a set of custom PyTorch operators, including both C and
Python components. Use `python setup.py install` to install Fast MoE and start
using it for training.

The distributed expert feature is disabled by default. If you want to enable
it, pass the environment variable `USE_NCCL=1` to the setup script.
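
Building with the flag can also be scripted from Python. The helper below is a hypothetical sketch (`setup_command` is not part of Fast MoE); it assumes only what is stated above, that `setup.py` enables the distributed expert feature when `USE_NCCL=1` is set, and mirrors running `USE_NCCL=1 python setup.py install` from the repository root.

```python
import os
import sys
from typing import Dict, List, Tuple


def setup_command(use_nccl: bool = False) -> Tuple[List[str], Dict[str, str]]:
    """Return the command and environment for `python setup.py install`.

    Hypothetical helper: it only prepares the arguments and does not run
    the build. USE_NCCL=1 turns on the distributed expert feature.
    """
    env = dict(os.environ)
    if use_nccl:
        env["USE_NCCL"] = "1"
    else:
        # Make sure a stale value from the parent shell is not inherited.
        env.pop("USE_NCCL", None)
    cmd = [sys.executable, "setup.py", "install"]
    return cmd, env
```

To actually build, pass the returned values to `subprocess.run(cmd, env=env, check=True)` from the root of the Fast MoE repository.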

Note that the NCCL developer package is also required, and its version must be
consistent with the NCCL version used by your PyTorch build, which can be
inspected by running `torch.cuda.nccl.version()`. The [official PyTorch docker
image]() is recommended, as its environment is already set up correctly.
Otherwise, you can download a suitable NCCL package from the [list of all NCCL
versions](https://developer.nvidia.com/nccl/nccl-legacy-downloads).
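
Depending on the PyTorch release, `torch.cuda.nccl.version()` returns either NCCL's integer version code or a tuple such as `(2, 10, 3)`. The sketch below (the helper names and `MIN_NCCL` are illustrative, not part of Fast MoE) normalizes both forms, decoding integer codes according to NCCL's `NCCL_VERSION` macro, and compares the result with the `>=2.7.5` requirement from the prerequisites.

```python
from typing import Sequence, Tuple, Union

# Minimum NCCL version suggested for the distributed expert feature.
MIN_NCCL = (2, 7, 5)


def parse_nccl_version(v: Union[int, Sequence[int]]) -> Tuple[int, int, int]:
    """Normalize the value returned by torch.cuda.nccl.version()."""
    if isinstance(v, (tuple, list)):
        major, minor, patch = (tuple(v) + (0, 0, 0))[:3]
        return (major, minor, patch)
    # Integer code from NCCL's NCCL_VERSION macro: releases up to 2.8 use
    # X*1000 + Y*100 + Z, later releases use X*10000 + Y*100 + Z.
    if v >= 10000:
        return (v // 10000, (v % 10000) // 100, v % 100)
    return (v // 1000, (v % 1000) // 100, v % 100)


def nccl_is_recent_enough(v, minimum=MIN_NCCL) -> bool:
    return parse_nccl_version(v) >= minimum


if __name__ == "__main__":
    try:
        import torch
        raw = torch.cuda.nccl.version()
    except Exception as exc:  # no PyTorch, or a build without NCCL
        print(f"Could not query NCCL through PyTorch: {exc}")
    else:
        status = "OK" if nccl_is_recent_enough(raw) else "older than 2.7.5"
        print("PyTorch NCCL version:", parse_nccl_version(raw), status)
```

The developer package you install should report the same version as the one printed here.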

## Usage 

### FMoEfy a transformer model

Transformers are currently the most popular models to be extended with MoE.
With Fast MoE, a Transformer-based model can be extended into an MoE model
through a one-line plugin, as shown below.

For example, when using [Megatron-LM](https://github.com/nvidia/megatron-lm),
the following lines scale up the MLP layers to multiple experts.

```python
from fmoe.megatron import fmoefy

model = ...  # build the Megatron-LM model as usual

# Replace the MLP layers with MoE layers. num_experts is the number of
# experts placed on each worker; 4 is only an example value.
num_experts_per_worker = 4
model = fmoefy(model, num_experts=num_experts_per_worker)

train(model, ...)  # continue with the usual Megatron-LM training
```

A detailed tutorial to _moefy_ Megatron-LM can be found
[here](examples/megatron).

### Using Fast MoE as a PyTorch module

An example MoE Transformer model can be found in the
[Transformer-XL](examples/transformer-xl) example. The easiest way is to replace
the MLP layers with `FMoE` layers.

### Using Fast MoE in Parallel

For data parallelism, no extra code is needed.

Expert parallelism, in which experts are placed on different workers, requires
sophisticated data-parallel strategies that neither PyTorch nor Megatron-LM
provides. The `fmoe.DistributedGroupedDataParallel` module is introduced to
replace PyTorch's DDP module for this case.
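
A simplified view of why plain DDP is not enough: parameters shared by all workers (for example attention weights) should be averaged across workers, while expert parameters should not, because the same local parameter name refers to a different expert on each worker. The pure-Python sketch below simulates that rule on per-worker gradient dictionaries, assuming a single expert-parallel group with no replicated experts. `grouped_average` and the parameter names are hypothetical; this is not the implementation or API of `fmoe.DistributedGroupedDataParallel`.

```python
from typing import Dict, List, Set


def grouped_average(worker_grads: List[Dict[str, float]],
                    expert_params: Set[str]) -> List[Dict[str, float]]:
    """Simulate one gradient synchronization step across workers.

    Shared parameters are averaged over all workers, as DDP's all-reduce
    would do. Expert parameters keep their local gradients, because each
    worker holds different experts under the same parameter names.
    """
    num_workers = len(worker_grads)
    synced = [dict(grads) for grads in worker_grads]
    for name in worker_grads[0]:
        if name in expert_params:
            continue  # local expert: no cross-worker averaging
        mean = sum(grads[name] for grads in worker_grads) / num_workers
        for grads in synced:
            grads[name] = mean
    return synced


if __name__ == "__main__":
    # Two workers, each with one shared attention weight and one local expert.
    grads = [{"attn.weight": 1.0, "experts.0.weight": 10.0},
             {"attn.weight": 3.0, "experts.0.weight": 20.0}]
    print(grouped_average(grads, expert_params={"experts.0.weight"}))
```

Plain DDP would instead average `experts.0.weight` as well, mixing the gradients of two unrelated experts.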