`apex.transformer` is a module for training large Transformer models efficiently at scale.
`apex.transformer.tensor_parallel` and `apex.transformer.pipeline_parallel` are both based on [NVIDIA/Megatron-LM](https://github.com/NVIDIA/Megatron-LM): the former on its `megatron.mpu` module, the latter on `megatron.schedules` and `megatron.p2p_communication`.
## Tensor Model Parallel (TP)
APEX's tensor model parallel utilities provide several `torch.nn.Module`s, custom fused kernels, and PRNG state handling.
See Appendix B.2 of [Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism](https://arxiv.org/abs/1909.08053) for the details of
PRNG state handling.
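As an illustration, the parallel linear layers can be composed into a Megatron-style MLP block. The sketch below is hedged: the layer names and constructor arguments follow the `megatron.mpu` conventions that `apex.transformer.tensor_parallel` mirrors, and the group size and hidden size are made up.

```python
import torch
from apex.transformer import parallel_state, tensor_parallel

# One process per GPU; build a tensor-parallel group of size 2 (illustrative).
torch.distributed.init_process_group(backend="nccl")
parallel_state.initialize_model_parallel(2)

# Seed the per-rank CUDA RNG states: dropout inside model-parallel regions
# gets a different stream on each tensor-parallel rank, while the default
# state stays replicated (the scheme of Appendix B.2 referenced above).
tensor_parallel.model_parallel_cuda_manual_seed(1234)

hidden_size = 1024  # illustrative
# First MLP linear: weight split along the output dimension; keep the
# shards local instead of gathering them (gather_output=False).
fc1 = tensor_parallel.ColumnParallelLinear(
    hidden_size, 4 * hidden_size, gather_output=False
)
# Second MLP linear: weight split along the input dimension; consumes the
# local shards directly and all-reduces the partial outputs.
fc2 = tensor_parallel.RowParallelLinear(
    4 * hidden_size, hidden_size, input_is_parallel=True
)
```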
## Pipeline Model Parallel (PP)
APEX's pipeline model parallel functions require models to implement a `set_input_tensor` method, because the input tensor of the `forward` method can be `None`; a minimal stage satisfying this contract is sketched below.
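The class below is a hypothetical illustration (the name `PipelineStage` and its body are not part of apex): the schedule hands the activation received from the previous stage to `set_input_tensor`, and `forward` falls back to it when called with `None`.

```python
import torch


class PipelineStage(torch.nn.Module):
    """Toy pipeline stage exposing the `set_input_tensor` hook."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_size, hidden_size)
        self.input_tensor = None

    def set_input_tensor(self, input_tensor):
        # The pipeline schedule stores the activation received from the
        # previous stage here before calling `forward`.
        self.input_tensor = input_tensor

    def forward(self, x):
        # On stages other than the first, `x` can be `None`; fall back to
        # the tensor delivered via `set_input_tensor`.
        if x is None:
            x = self.input_tensor
        return self.linear(x)
```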
The following is a rough sketch of a training script with apex's pipeline parallelism.
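This is a hedged sketch, not a verified recipe: it reuses the hypothetical `PipelineStage` from above, the sizes are made up, and the exact signatures of `get_forward_backward_func` and the schedule it returns should be checked against `apex.transformer.pipeline_parallel` (some required setup, such as configuring the number of micro-batches, is elided).

```python
import torch
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import get_forward_backward_func

# One process per GPU (launched with e.g. torchrun); tensor-parallel size 1
# and pipeline-parallel size 2 are illustrative.
torch.distributed.init_process_group(backend="nccl")
torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())
parallel_state.initialize_model_parallel(1, 2)
# NOTE: apex also needs to know the micro-batch configuration; that setup
# (see apex.transformer.pipeline_parallel.utils) is elided here.

sequence_length, micro_batch_size, hidden_size = 128, 4, 1024  # illustrative
model = PipelineStage(hidden_size).cuda()  # toy stage from the sketch above
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)


def forward_step_func(batch, model):
    """Run one forward pass and return the output plus a loss function."""
    output = model(batch)

    def loss_func(output):
        loss = output.mean()  # placeholder loss for illustration
        return loss, {"loss": loss}

    return output, loss_func


# Selects the schedule (no pipelining, or pipelining with/without
# interleaving) from the virtual and actual pipeline-parallel sizes.
forward_backward_func = get_forward_backward_func(None, 2)

for _ in range(10):  # toy loop over random batches
    batch = torch.randn(sequence_length, micro_batch_size, hidden_size, device="cuda")
    optimizer.zero_grad()
    # Runs forward/backward for all micro-batches, including the p2p
    # communication of activations and gradients between stages.
    losses = forward_backward_func(
        forward_step_func,
        batch,
        model,
        forward_only=False,
        tensor_shape=(sequence_length, micro_batch_size, hidden_size),
    )
    optimizer.step()
```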