[Release note](doc/release-note.md)
| [中文文档](doc/readme-cn.md)
| [Slack workspace](https://join.slack.com/t/fastmoe/shared_invite/zt-mz0ai6ol-ggov75D62YsgHfzShw8KYw)
## Introduction
An easy-to-use but efficient implementation of the Mixture of Experts (MoE)
model for PyTorch.
## Installation
### Prerequisites
PyTorch with CUDA is required. The repository is currently tested with PyTorch
v1.8.0 and CUDA 10, and is designed to be compatible with older versions.
If the distributed expert feature is enabled, NCCL with P2P communication
support, typically version `>=2.7.5`, is required.
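A quick sanity check of these prerequisites, as a minimal sketch using standard
PyTorch calls:
```python
import torch

print("PyTorch:", torch.__version__)           # tested with v1.8.0
print("CUDA available:", torch.cuda.is_available())
print("CUDA runtime:", torch.version.cuda)     # tested with CUDA 10
```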
### Installing
FastMoE contains a set of customized PyTorch operators, including both C and
Python components. Run `python setup.py install` to install FastMoE for
training.
The distributed expert feature is disabled by default. To enable it, pass the
environment variable `USE_NCCL=1` to the setup script.
Note that an extra NCCL developer package is needed, and it has to be
consistent with the NCCL version of your PyTorch build, which can be inspected
by running `torch.cuda.nccl.version()`. The
[official PyTorch docker image](https://hub.docker.com/r/pytorch/pytorch) is
recommended, as the environment there is well set up. Otherwise, you can visit
the [download page of all NCCL
versions](https://developer.nvidia.com/nccl/nccl-legacy-downloads) to obtain
the NCCL package that suits your environment.
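For example, the NCCL version that your PyTorch build links against can be
inspected as follows (a minimal sketch; the exact return format differs between
PyTorch releases):
```python
import torch

# NCCL version bundled with this PyTorch build; install a matching
# NCCL developer package when building FastMoE with USE_NCCL=1.
print(torch.cuda.nccl.version())
```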
## Usage
### FMoEfy a Transformer model
Transformer is currently one of the most popular model architectures to be
extended with MoE. Using FastMoE, a Transformer-based model can be extended to
an MoE model by a one-key plugin, as shown below.
For example, when using [Megatron-LM](https://github.com/nvidia/megatron-lm),
the following lines can easily scale up the MLP layers to multiple experts
(set `num_experts` to the number of experts you want).
```python
# Build the original Megatron-LM model as usual.
model = ...

from fmoe.megatron import fmoefy
# `num_experts` below is an example value; set it to the desired number of experts.
model = fmoefy(model, num_experts=4)
```