"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "5cb5759366f872189d90d4299c204bac6e0238bf"
Unverified commit de4ef388 authored by Rick Ho, committed by GitHub

Merge pull request #165 from laekov/parallel-doc

Document for process groups
parents a6d202a6 8bb58982
@@ -75,7 +75,9 @@ the MLP layer by the `FMoE` layers.
### Using FastMoE in Parallel
FastMoE supports multiple ways of parallel training. See the
[comprehensive document on parallelism](doc/parallelism) for details. The two
simplest ways of using FastMoE in parallel are shown below.
#### Data Parallel
@@ -83,27 +85,28 @@ In FastMoE's data parallel mode, both the gate and the experts are replicated on
The following figure shows the forward pass of a 3-expert MoE with 2-way data parallel.
<p align="center">
<img src="doc/parallelism/fastmoe_data_parallel.png" width="600">
</p>
For data parallel, no extra coding is needed. FastMoE works seamlessly with PyTorch's `DataParallel` or `DistributedDataParallel`.
The only drawback of data parallel is that the number of experts is constrained by each worker's memory.
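For instance, a minimal data-parallel sketch might look like the following; the `FMoETransformerMLP` module and the layer sizes are illustrative assumptions rather than a prescribed setup, and any model containing FastMoE layers can be wrapped the same way.

```python
# Hedged sketch: plain PyTorch DistributedDataParallel around a model that
# contains a FastMoE layer. Layer sizes and the choice of FMoETransformerMLP
# are illustrative.
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from fmoe import FMoETransformerMLP

dist.init_process_group(backend="nccl")
local_rank = dist.get_rank() % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

# Every worker replicates the gate and all 3 experts.
moe_layer = FMoETransformerMLP(num_expert=3, d_model=512, d_hidden=2048).cuda()
model = DDP(moe_layer, device_ids=[local_rank])

x = torch.randn(8, 512, device="cuda")  # (tokens, d_model)
y = model(x)                            # gradients are averaged by DDP as usual
```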
#### Expert Parallel (also called Model Parallel in some previous versions)
In FastMoE's expert parallel mode, the gate network is still replicated on each worker, but
the experts are placed separately across workers.
Thus, by introducing additional communication cost, FastMoE enjoys a large expert pool whose size is proportional to the number of workers.
The following figure shows the forward pass of a 6-expert MoE with 2-way expert parallel. Note that experts 1-3 are located on worker 1 while experts 4-6 are located on worker 2.
<p align="center"> <p align="center">
<img src="doc/fastmoe_model_parallel.png" width="600"> <img src="doc/parallelism/fastmoe_expert_parallel.png" width="600">
</p> </p>
FastMoE's expert parallel mode requires sophisticated parallel strategies that neither
PyTorch nor Megatron-LM provided when FastMoE was created. The
`fmoe.DistributedGroupedDataParallel` module is introduced to replace PyTorch's
DDP module.
#### Faster Performance Features
...
Multi-Dimensional Parallelism Supported by FastMoE
===
_A Chinese version of this document is not available yet. Until a community contribution provides one, please rely on machine translation._
FastMoE now supports almost every popular way to train models in parallel, and any combination of them.
The figure below shows all possible groups of processes that a process may be involved in.
Users can enable them simply by assigning communication groups, either in FastMoE or in the external codebase that uses FastMoE.
![](parallelism.png)
#### Data Parallel
In a group of data-parallel processes, the model, including the experts, is replicated across the processes.
To replicate the experts, first pass `expert_dp_comm="dp"` to the `mark_parallel_comm` function of an `FMoE` instance
(the string `"dp"` can be replaced by another name if you wish).
Then, wrap the MoE module with `fmoe.distributed.DistributedGroupedDataParallel`,
and set `dp_group` in the constructor to the PyTorch process group over which you wish to perform data parallelism.
By default, the parameters are synchronized at initialization, unless this is disabled with `need_sync=False`.
Run `model.allreduce_params` after backward propagation in every iteration; a sketch follows the figure below.
![](fastmoe_data_parallel.png)
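Putting those steps together, a minimal sketch of the data-parallel setup could look as follows. Only `mark_parallel_comm`, `dp_group`, `need_sync`, and `allreduce_params` come from the description above; the `FMoETransformerMLP` module, its sizes, and the toy training loop are illustrative assumptions.

```python
# Hedged sketch of the data-parallel setup described above.
import torch
import torch.distributed as dist
from fmoe import FMoETransformerMLP
from fmoe.distributed import DistributedGroupedDataParallel

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# A process group covering every data-parallel replica (here: all processes).
dp_group = dist.new_group(ranks=list(range(dist.get_world_size())))

moe = FMoETransformerMLP(num_expert=4, d_model=512, d_hidden=2048).cuda()
moe.mark_parallel_comm(expert_dp_comm="dp")   # replicate experts under the name "dp"

# Parameters are synchronized at construction unless need_sync=False is passed.
model = DistributedGroupedDataParallel(moe, dp_group=dp_group)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for step in range(10):
    x = torch.randn(8, 512, device="cuda")   # (tokens, d_model), illustrative input
    loss = model(x).pow(2).mean()            # placeholder loss
    loss.backward()
    model.allreduce_params()                 # synchronize after backward, as described above
    optimizer.step()
    optimizer.zero_grad()
```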
#### Model Parallel
In typical model parallelism (also called tensor model parallelism), every single expert is split up.
FastMoE requires the external codebase to implement this by properly splitting the expert module that is provided to FastMoE.
An official example using Megatron-LM can be seen in our adapter.
The `hidden_hidden_size` of FastMoE's transformer module is divided by `k`, where `k` denotes the number of model-parallel processes.
In this way, each expert is split into `k` pieces.
Then, an `all-reduce` is performed over the feature matrix externally in the adapter, so that the outputs of the expert pieces are merged.
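The following standalone sketch illustrates the idea rather than the Megatron-LM adapter itself: each model-parallel process keeps `1/k` of an expert's hidden neurons, and an `all-reduce` over the tensor-parallel group merges the partial outputs. The class names, layer sizes, and the simplified expert interface are assumptions.

```python
# Hedged sketch: one expert's hidden dimension is split across k model-parallel
# processes; an all-reduce merges the partial outputs, as described above.
import torch
import torch.distributed as dist


class _AllReduceSum(torch.autograd.Function):
    """All-reduce in the forward pass, identity in the backward pass."""

    @staticmethod
    def forward(ctx, x, group):
        x = x.clone()
        dist.all_reduce(x, group=group)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None


class SlicedExpertMLP(torch.nn.Module):
    """One slice of an expert MLP; hidden_hidden_size is divided by the group size k."""

    def __init__(self, d_model, hidden_hidden_size, tp_group):
        super().__init__()
        k = dist.get_world_size(group=tp_group)  # number of model-parallel processes
        self.tp_group = tp_group
        self.fc1 = torch.nn.Linear(d_model, hidden_hidden_size // k)  # column slice
        self.fc2 = torch.nn.Linear(hidden_hidden_size // k, d_model)  # row slice

    def forward(self, x):
        y = self.fc2(torch.nn.functional.gelu(self.fc1(x)))
        # Merge the k partial results so every process sees the full expert output.
        # (A complete implementation also all-reduces the input gradient, as the
        # Megatron-LM adapter does; that part is omitted here for brevity.)
        return _AllReduceSum.apply(y, self.tp_group)
```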
#### Expert Parallel (MoE Group and Slice Group)
In a group of expert-parallel processes, each process maintains different experts.
Together, the processes in an MoE group hold all the experts, and within a `moe_group` the input features on different processes come from different samples.
FastMoE performs an `all-to-all` to exchange them, i.e. each feature vector is sent to the processes that hold its selected experts.
![](fastmoe_expert_parallel.png)
`slice_group` is a way to adapt typical model parallelism to expert parallelism.
It assumes that the processes in the group have replicated input feature vectors,
so each process selects a part of the feature vectors (a slice) as input to the `moe_group`,
and performs an `all-gather` after the expert-parallel NN operations to produce replicated output; a sketch follows.
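A sketch of the simplest expert-parallel setup, under stated assumptions: all processes form a single `moe_group`, so together they hold `world_size * num_expert` experts, and no slicing is used (`slice_group` is left unset). The `FMoETransformerMLP` layer, its sizes, and the exact constructor arguments it forwards to `FMoE` are assumptions that may vary across versions.

```python
# Hedged sketch: one expert-parallel (MoE) group spanning all processes.
import torch
import torch.distributed as dist
from fmoe import FMoETransformerMLP

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
world = dist.get_world_size()

# The all-to-all described above is performed inside this group.
moe_group = dist.new_group(ranks=list(range(world)))

# Each process holds 2 local experts; the group as a whole holds 2 * world experts.
layer = FMoETransformerMLP(num_expert=2, d_model=512, d_hidden=2048,
                           world_size=world, moe_group=moe_group, top_k=2).cuda()

x = torch.randn(8, 512, device="cuda")   # different samples on each process
y = layer(x)                             # routed via all-to-all, shape (8, 512)
```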
#### Pipeline Parallel
An MoE layer can be part of any pipeline stage.
The external codebase is expected to handle the communication across stages.
Note that the `gate` module is replicated across all processes in the above three ways of intra-layer parallelism.
So, for inter-layer parallelism, users should set `gate_group` in `DistributedGroupedDataParallel` to the group of all processes in the same stage, as in the sketch below.
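For example, with 8 processes forming 2 pipeline stages of 4 processes each, the `gate_group` assignment could be sketched as follows; the stage layout, the placeholder stage module, and its sizes are illustrative assumptions.

```python
# Hedged sketch: gate_group is the group of all processes in the same pipeline
# stage. The stage layout and the placeholder stage module are illustrative.
import torch
import torch.distributed as dist
from fmoe import FMoETransformerMLP
from fmoe.distributed import DistributedGroupedDataParallel

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
rank = dist.get_rank()
stage_size = 4                                    # processes per pipeline stage
num_stages = dist.get_world_size() // stage_size

# new_group must be called by every process for every group, in the same order.
stage_groups = [dist.new_group(ranks=list(range(s * stage_size, (s + 1) * stage_size)))
                for s in range(num_stages)]
my_stage_group = stage_groups[rank // stage_size]

# Placeholder for the MoE layer(s) that live in this process's pipeline stage.
stage_module = FMoETransformerMLP(num_expert=2, d_model=512, d_hidden=2048).cuda()
model = DistributedGroupedDataParallel(stage_module, gate_group=my_stage_group)
# Communication between stages is handled by the external pipeline codebase.
```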
#### Hybrid Parallel
Any combination of the above four ways of parallel training can be enabled by specifying the proper communication groups for `FMoE` and `DistributedGroupedDataParallel`; a sketch of one such combination follows.
Refer to our [ATC'23 paper](https://www.usenix.org/conference/atc23/presentation/zhai) for studies on the optimal selection of hybrid parallelism.
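As an illustration, the sketch below arranges 8 processes as 2-way data parallel times 4-way expert parallel. Group construction is plain `torch.distributed`; only the groups and calls named in this document (`moe_group`, `dp_group`, `mark_parallel_comm`, `DistributedGroupedDataParallel`) are used, and the layer type and sizes are assumptions.

```python
# Hedged sketch: 2-way data parallel x 4-way expert parallel on 8 processes.
import torch
import torch.distributed as dist
from fmoe import FMoETransformerMLP
from fmoe.distributed import DistributedGroupedDataParallel

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
rank, world = dist.get_rank(), dist.get_world_size()   # world is assumed to be 8
ep_size = 4                                            # expert-parallel width

# Data-parallel groups: processes holding replicas of the same experts.
dp_groups = [dist.new_group(ranks=list(range(i, world, ep_size))) for i in range(ep_size)]
# MoE groups: processes that together hold all the experts.
moe_groups = [dist.new_group(ranks=list(range(j * ep_size, (j + 1) * ep_size)))
              for j in range(world // ep_size)]

dp_group = dp_groups[rank % ep_size]
moe_group = moe_groups[rank // ep_size]

layer = FMoETransformerMLP(num_expert=2, d_model=512, d_hidden=2048,
                           world_size=ep_size, moe_group=moe_group).cuda()
layer.mark_parallel_comm(expert_dp_comm="dp")          # replicate experts across dp_group
model = DistributedGroupedDataParallel(layer, dp_group=dp_group)
```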
@@ -64,7 +64,8 @@ train(model, ...)
### Using FastMoE in a Distributed Setting
FastMoE supports multiple ways of parallel training. See the [detailed document on parallelism](doc/parallelism).
The two easiest-to-use ways of parallel training are briefly introduced below.
#### Data Parallel
@@ -73,29 +74,30 @@ FastMoE supports both data parallel and model parallel.
The following figure shows the forward pass of a 3-expert MoE model with 2-way data parallel.
<p align="center"> <p align="center">
<img src="fastmoe_data_parallel.png" width="600"> <img src="parallelism/fastmoe_data_parallel.png" width="600">
</p> </p>
For data parallel, no extra code is needed. FastMoE works seamlessly with both PyTorch's `DataParallel`
and `DistributedDataParallel` modules. The only problem with this approach is that
the number of experts is limited by the memory of a single compute unit (e.g. a GPU).
#### Expert Parallel (formerly called Model Parallel)
In FastMoE's expert parallel mode, the gate network is still replicated on every compute unit,
but the expert networks are placed separately on the individual compute units. Thus, by introducing extra communication,
FastMoE allows more expert networks to be trained at the same time,
and the limit on their number grows with the number of compute units.
The following figure shows a model with six expert networks being trained with 2-way expert parallelism.
Note that experts 1-3 are placed on the first compute unit, while experts 4-6 are placed on the second.
<p align="center">
<img src="parallelism/fastmoe_expert_parallel.png" width="600">
</p>
FastMoE's expert parallel mode requires dedicated parallel strategies that neither PyTorch nor Megatron-LM
supported (at the time FastMoE was created). Therefore, the
`fmoe.DistributedGroupedDataParallel`
module needs to be used in place of PyTorch's DDP module.
### How to Train Even Faster
...
@@ -97,6 +97,10 @@ class FMoE(nn.Module):
the output. For each worker, FMoE only computes the output of a certain
slice of the input batch, and will all-gather the outputs after
computation.
* `mp_group` is a deprecated alias of `slice_group`.
* `moe_group` stands for the group of processes that perform expert
parallelism. The default value `None` means all processes. See the
parallelism document for more details on the groups.
* `top_k` stands for the number of experts each token is going to.
* `gate` is a gate class which can be found in `fmoe.gates`.
* `expert` can be specified as a module class; it is used to generate
...