Commit f8b650c8 authored by flyingdown

1. Updated the README.

2. Added the environment variable APEX_ROCBLAS_GEMM_ALLOW_HALF to control whether the fp16 (f16_r) rocBLAS GEMM path is used.
3. Added DCU version information.

Renamed the wheel package.

Updated the installation steps in the README.
parent 2c6c0f28
# APEX
## Introduction
For the original project introduction, see [Introduction](README_ORIGIN.md).
## Installation
### System Requirements
- Linux
- Python 3.7, 3.8, or 3.9
- (**Recommended**) Upgrade pip
```
python3 -m pip install --upgrade pip #--user
```
### Install with pip (using dtk-23.04 as an example)
The latest apex release can be obtained from the AI ecosystem packages in the [HPC Developer Community](https://developer.hpccube.com/tool/#sdk) (choose the wheel that matches your DCU Toolkit and Python versions).
```bash
python3 -m pip install apex-0.1+git2d8b360.abi0.dtk2304-cp37-cp37m-linux_x86_64.whl
```
### Install from source
#### Prepare the build environment (using dtk-23.04 as an example)
- Clone the apex source:
```
git clone -b dtk-23.04 http://developer.hpccube.com/codes/aicomponent/apex.git
```
- Download DTK-23.04 from the [Developer Community](https://developer.hpccube.com/tool/#sdk), extract it to /opt/, and create a symbolic link:
```
cd /opt && ln -s dtk-23.04 dtk
```
- Obtain the matching pytorch release from the AI ecosystem packages in the [HPC Developer Community](https://developer.hpccube.com/tool/#sdk) (the wheel must match your DCU Toolkit and Python versions):
```bash
python3 -m pip install torch-1.13.1a0+git4c8a1fe.abi0.dtk2304-cp37-cp37m-linux_x86_64.whl
```
- Set the environment variables and install the required dependencies:
```bash
source /opt/dtk/env.sh
export PYTORCH_ROCM_ARCH="gfx906;gfx926"
MAX_JOBS=16
sha=`git rev-parse HEAD`
# append the git hash, ABI, and DTK tags to the package version in setup.py
sed -i "/version=/{s/\(.*=\)['\"]\(.*\)['\"]/\1'\2\+git${sha:0:7}\.abi0.dtk23.04'/}" setup.py
pip3 install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple --trusted-host pypi.tuna.tsinghua.edu.cn
pip3 install wheel -i https://pypi.tuna.tsinghua.edu.cn/simple --trusted-host pypi.tuna.tsinghua.edu.cn
```
#### Build and install
- Run the build:
```shell
cd apex
CXX=hipcc CC=hipcc python3 setup.py --cpp_ext --cuda_ext bdist_wheel
pip install dist/apex*
```
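After either install path, a quick sanity check can be run from Python. This is a minimal sketch; the `__version__` and `__dcu_version__` attributes come from the version file generated during this build, and the example strings in the comments are illustrative only:
```python
# Run inside the Python environment you installed into.
import apex

print(apex.__version__)      # base version, e.g. "0.1"
print(apex.__dcu_version__)  # full DCU build string, e.g. "0.1+<sha>.abi0.dtk2304.torch1.13"
```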
# Introduction
This repository holds NVIDIA-maintained utilities to streamline mixed precision and distributed training in Pytorch.
Some of the code here will be included in upstream Pytorch eventually.
The intent of Apex is to make up-to-date utilities available to users as quickly as possible.
## Full API Documentation: [https://nvidia.github.io/apex](https://nvidia.github.io/apex)
## [GTC 2019](https://github.com/mcarilli/mixed_precision_references/tree/master/GTC_2019) and [Pytorch DevCon 2019](https://github.com/mcarilli/mixed_precision_references/tree/master/Pytorch_Devcon_2019) Slides
# Contents
## 1. Amp: Automatic Mixed Precision
`apex.amp` is a tool to enable mixed precision training by changing only 3 lines of your script.
Users can easily experiment with different pure and mixed precision training modes by supplying
different flags to `amp.initialize`.
[Webinar introducing Amp](https://info.nvidia.com/webinar-mixed-precision-with-pytorch-reg-page.html)
(The flag `cast_batchnorm` has been renamed to `keep_batchnorm_fp32`).
[API Documentation](https://nvidia.github.io/apex/amp.html)
[Comprehensive Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
[DCGAN example coming soon...](https://github.com/NVIDIA/apex/tree/master/examples/dcgan)
[Moving to the new Amp API](https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users) (for users of the deprecated "Amp" and "FP16_Optimizer" APIs)
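For orientation, here is a minimal sketch of the three-line change. The model, optimizer, and `loader` below are placeholders, and `opt_level="O1"` is just one of the modes accepted by `amp.initialize`:
```python
import torch
from apex import amp

model = torch.nn.Linear(10, 10).cuda()                    # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Change 1: let Amp patch the model and optimizer for the chosen opt_level.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

for data, target in loader:                               # placeholder data loader
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(data), target)
    # Changes 2-3: scale the loss so fp16 gradients do not underflow.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
```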
## 2. Distributed Training
`apex.parallel.DistributedDataParallel` is a module wrapper, similar to
`torch.nn.parallel.DistributedDataParallel`. It enables convenient multiprocess distributed training,
optimized for NVIDIA's NCCL communication library.
[API Documentation](https://nvidia.github.io/apex/parallel.html)
[Python Source](https://github.com/NVIDIA/apex/tree/master/apex/parallel)
[Example/Walkthrough](https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed)
The [Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
shows use of `apex.parallel.DistributedDataParallel` along with `apex.amp`.
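A minimal sketch of the usual wiring in a one-process-per-GPU launch follows; `build_model()` and `args.local_rank` are placeholders supplied by your own script and launcher, and the Imagenet example above gives the complete recipe:
```python
import torch
from apex import amp
from apex.parallel import DistributedDataParallel as ApexDDP

torch.distributed.init_process_group(backend="nccl", init_method="env://")
torch.cuda.set_device(args.local_rank)        # placeholder: rank passed in by the launcher

model = build_model().cuda()                  # placeholder model constructor
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

# Unlike torch.nn.parallel.DistributedDataParallel, no device_ids are passed;
# the wrapper infers the device from the module's parameters.
model = ApexDDP(model)
```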
### Synchronized Batch Normalization
`apex.parallel.SyncBatchNorm` extends `torch.nn.modules.batchnorm._BatchNorm` to
support synchronized BN.
It allreduces stats across processes during multiprocess (DistributedDataParallel) training.
Synchronous BN has been used in cases where only a small
local minibatch can fit on each GPU.
Allreduced stats increase the effective batch size for the BN layer to the
global batch size across all processes (which, technically, is the correct
formulation).
Synchronous BN has been observed to improve converged accuracy in some of our research models.
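A small sketch of the usual conversion step, assuming the model already uses standard `torch.nn.BatchNorm*` layers and is trained under `apex.parallel.DistributedDataParallel` as above:
```python
import torch
from apex.parallel import convert_syncbn_model

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
).cuda()

# Recursively replaces torch.nn.BatchNorm*d modules with apex.parallel.SyncBatchNorm,
# so running stats are allreduced across the participating processes.
model = convert_syncbn_model(model)
```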
### Checkpointing
To properly save and load your `amp` training, we introduce the `amp.state_dict()`, which contains all `loss_scalers` and their corresponding unskipped steps,
as well as `amp.load_state_dict()` to restore these attributes.
In order to get bitwise accuracy, we recommend the following workflow:
```python
# Initialization
opt_level = 'O1'
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
# Train your model
...
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
...
# Save checkpoint
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...
# Restore
model = ...
optimizer = ...
checkpoint = torch.load('amp_checkpoint.pt')
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])
# Continue training
...
```
Note that we recommend restoring the model using the same `opt_level`. Also note that we recommend calling the `load_state_dict` methods after `amp.initialize`.
# Installation
## Containers
NVIDIA PyTorch Containers are available on NGC: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch.
The containers come with all the custom extensions available at the moment.
See [the NGC documentation](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html) for details such as:
- how to pull a container
- how to run a pulled container
- release notes
## From Source
To install Apex from source, we recommend using the nightly Pytorch obtainable from https://github.com/pytorch/pytorch.
The latest stable release obtainable from https://pytorch.org should also work.
### ROCm
Apex on ROCm supports both a Python-only build and an extension build.
Note: PyTorch >= 1.5 is recommended for the extension build.
### To install using the Python-only build, run the following command in the apex folder:
```
python setup.py install
```
### To install with extensions enabled, run the following command in the apex folder:
```
python setup.py install --cpp_ext --cuda_ext
```
Note that installing Apex with the `--cuda_ext` flag also enables all of the extensions supported on ROCm, including `--distributed_adam`, `--distributed_lamb`, `--bnp`, `--xentropy`, `--deprecated_fused_adam`, `--deprecated_fused_lamb`, and `--fast_multihead_attn`.
### To install Apex on ROCm using ninja and without cloning the source
```
pip install ninja
pip install -v --install-option="--cpp_ext" --install-option="--cuda_ext" 'git+https://github.com/ROCmSoftwarePlatform/apex.git'
```
### Linux
For performance and full functionality, we recommend installing Apex with
CUDA and C++ extensions via
```bash
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```
Apex also supports a Python-only build via
```bash
pip install -v --disable-pip-version-check --no-cache-dir ./
```
A Python-only build omits:
- Fused kernels required to use `apex.optimizers.FusedAdam`.
- Fused kernels required to use `apex.normalization.FusedLayerNorm` and `apex.normalization.FusedRMSNorm`.
- Fused kernels that improve the performance and numerical stability of `apex.parallel.SyncBatchNorm`.
- Fused kernels that improve the performance of `apex.parallel.DistributedDataParallel` and `apex.amp`.
`DistributedDataParallel`, `amp`, and `SyncBatchNorm` will still be usable, but they may be slower.
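For example, with the C++/CUDA extensions built, the fused optimizer can be swapped in where `torch.optim.Adam` would otherwise be used. This is only a sketch; the placeholder model and learning rate are illustrative:
```python
import torch
from apex.optimizers import FusedAdam

model = torch.nn.Linear(10, 10).cuda()        # placeholder model

# Requires the fused kernels from the --cpp_ext/--cuda_ext build;
# a Python-only install does not provide this optimizer.
optimizer = FusedAdam(model.parameters(), lr=1e-3)
```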
### [Experimental] Windows
`pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .` may work if you were able to build Pytorch from source
on your system. A Python-only build via `pip install -v --no-cache-dir .` is more likely to work.
If you installed Pytorch in a Conda environment, make sure to install Apex in that same environment.
@@ -49,3 +49,9 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
        )
        return False
    return True

try:
    from .version import version, git_hash, git_branch, dtk, abi, torch_version, dcu_version  # noqa: F401
    __version__, __dcu_version__ = version, dcu_version
except ImportError:
    pass
@@ -236,12 +236,12 @@ std::vector<torch::Tensor> bwd_cuda(
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
- // const float alpha = 1.0;
- // const float beta = 0.0;
- // const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
- const half alpha = 1.0;
- const half beta = 0.0;
- const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
+ const float alpha = 1.0;
+ const float beta = 0.0;
+ const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ const half h_alpha = 1.0;
+ const half h_beta = 0.0;
+ const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
// TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code
@@ -289,6 +289,104 @@ std::vector<torch::Tensor> bwd_cuda(
#endif
// Output Linear Dgrad
if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
h_alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
h_beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
h_alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
} else {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
@@ -309,7 +407,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
- /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
+ rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
@@ -335,7 +433,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
- /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
+ rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
@@ -386,6 +484,8 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches,
flags);
}
// Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(matmul2_grads.data_ptr()),
@@ -404,6 +504,104 @@ std::vector<torch::Tensor> bwd_cuda(
assert(softmax_success);
// Matmul1 Dgrad1
if (use_fp16) {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
h_scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
h_scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
} else {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
@@ -470,7 +668,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
- /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
+ rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
@@ -496,10 +694,12 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
- /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
+ rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
}
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
...
@@ -10,6 +10,7 @@
//#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include "utils.h"
//#include "cutlass/cutlass.h"
//#include "cutlass/gemm/gemm.h"
@@ -28,6 +29,8 @@ int32_t solution_index = 0;
rocblas_int flags = 0;
*/
// Cached when the extension library is loaded; selects the fp16 (f16_r) GEMM paths below.
static bool use_fp16 = parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF");
namespace {
cublasOperation_t convertTransToCublasOperation(char trans) {
if (trans == 't')
@@ -42,44 +45,44 @@ cublasOperation_t convertTransToCublasOperation(char trans) {
}
}
-// void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
-// float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
-// float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
-// cublasOperation_t opa = convertTransToCublasOperation(transa);
-// cublasOperation_t opb = convertTransToCublasOperation(transb);
-// cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
-// cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
-// cublasSetStream(handle, stream);
-// float fAlpha = alpha;
-// float fBeta = beta;
-// //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
-// TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle,
-// opa, opb, (int)m, (int)n, (int)k,
-// (void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
-// b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
-// (void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
-// d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
-// (int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
-// }
-// void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
-// float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
-// float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_int flags) {
-// auto stream = c10::cuda::getCurrentCUDAStream();
-// if ( (transa == 't') && (transb == 'n') ) {
-// if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
-// else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
-// } else if ( (transa == 'n') && (transb == 'n') ) {
-// if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
-// else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
-// } else if ( (transa == 'n') && (transb == 't') ) {
-// if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
-// else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
-// } else {
-// AT_ASSERTM(false, "TransA and TransB are invalid");
-// }
-// }
+void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
+float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
+float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
+cublasOperation_t opa = convertTransToCublasOperation(transa);
+cublasOperation_t opb = convertTransToCublasOperation(transb);
+cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
+cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+cublasSetStream(handle, stream);
+float fAlpha = alpha;
+float fBeta = beta;
+//THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle,
+opa, opb, (int)m, (int)n, (int)k,
+(void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
+b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
+(void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
+d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
+(int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
+}
+void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
+float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
+float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_int flags) {
+auto stream = c10::cuda::getCurrentCUDAStream();
+if ( (transa == 't') && (transb == 'n') ) {
+if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
+else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
+} else if ( (transa == 'n') && (transb == 'n') ) {
+if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
+else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
+} else if ( (transa == 'n') && (transb == 't') ) {
+if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
+else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
+} else {
+AT_ASSERTM(false, "TransA and TransB are invalid");
+}
+}
void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
half alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
...
@@ -10,6 +10,8 @@
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include "utils.h"
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
// includes cublaslt
#include <cublasLt.h>
@@ -164,8 +166,9 @@ cublasStatus_t gemm_bias(
at::Half* C,
int ldc) {
#ifdef __HIP_PLATFORM_HCC__
- half hAlpha = __float2half(*alpha);
- half hBeta = __float2half(*beta);
+ if (parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF")) {
+ half h_alpha = __float2half(*alpha);
+ half h_beta = __float2half(*beta);
return rocblas_gemm_ex(
handle,
transa,
@@ -173,14 +176,14 @@ cublasStatus_t gemm_bias(
m,
n,
k,
- /* alpha */ &hAlpha,
+ /* alpha */ &h_alpha,
A,
rocblas_datatype_f16_r,
lda,
B,
rocblas_datatype_f16_r,
ldb,
- /* beta */ &hBeta,
+ /* beta */ &h_beta,
C,
rocblas_datatype_f16_r,
ldc,
@@ -191,6 +194,33 @@ cublasStatus_t gemm_bias(
rocblas_gemm_algo_standard,
0,
0);
} else {
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f16_r,
lda,
B,
rocblas_datatype_f16_r,
ldb,
beta,
C,
rocblas_datatype_f16_r,
ldc,
C,
rocblas_datatype_f16_r,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
0);
}
#else
return cublasGemmEx(
handle,
...
@@ -11,6 +11,7 @@
/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include "utils.h"
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
// includes cublaslt
@@ -23,6 +24,7 @@
#define BIAS_RELU_RED_PER_THREAD 16 // backward minimal reduction length per thread
// move to a header later on
#define ILP 4
template<typename T>
@@ -211,8 +213,9 @@ cublasStatus_t mlp_gemm(
int ldc,
int flag) {
#ifdef __HIP_PLATFORM_HCC__
- half hAlpha = __float2half(*alpha);
- half hBeta = __float2half(*beta);
+ if (parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF")) {
+ half h_alpha = __float2half(*alpha);
+ half h_beta = __float2half(*beta);
return rocblas_gemm_ex(
handle,
transa,
@@ -220,14 +223,14 @@ cublasStatus_t mlp_gemm(
m,
n,
k,
- /* alpha */ &hAlpha,
+ /* alpha */ &h_alpha,
A,
rocblas_datatype_f16_r,
lda,
B,
rocblas_datatype_f16_r,
ldb,
- /* beta */ &hBeta,
+ /* beta */ &h_beta,
C,
rocblas_datatype_f16_r,
ldc,
@@ -238,6 +241,33 @@ cublasStatus_t mlp_gemm(
rocblas_gemm_algo_standard,
0,
flag);
} else {
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f16_r,
lda,
B,
rocblas_datatype_f16_r,
ldb,
beta,
C,
rocblas_datatype_f16_r,
ldc,
C,
rocblas_datatype_f16_r,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
flag);
}
#else
return cublasGemmEx(
handle,
...
#pragma once

#include <cstdlib>
#include <torch/extension.h>

// Parse an on/off environment variable: "1" enables the flag, "0" or an unset
// variable disables it, and any other value is rejected with a TORCH_CHECK error.
inline bool parseEnvVarFlag(const char* envVarName) {
  char* stringValue = std::getenv(envVarName);
  if (stringValue != nullptr) {
    int val;
    try {
      val = std::stoi(stringValue);
    } catch (const std::exception&) {
      TORCH_CHECK(false,
                  "Invalid value for environment variable: " + std::string(envVarName));
    }
    if (val == 1) {
      return true;
    } else if (val == 0) {
      return false;
    } else {
      TORCH_CHECK(false,
                  "Invalid value for environment variable: " + std::string(envVarName));
    }
  }
  return false;
}
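For context, a hedged sketch of how a user would turn the new flag on from the Python side. The fused dense/MLP GEMMs read the variable on every call, while the attention extension caches it when its shared library is loaded, so setting it before any apex import is the safe choice:
```python
import os

# "1" enables the half-precision (f16_r) rocBLAS GEMM paths added in this commit;
# "0" or leaving the variable unset keeps the fp32-accumulation paths.
os.environ["APEX_ROCBLAS_GEMM_ALLOW_HALF"] = "1"

import apex  # noqa: E402  (imported after the flag is set so cached readers see it)
```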
import os
import subprocess
from pathlib import Path

import torch

ROOT_DIR = Path(__file__).parent.resolve()


def _run_cmd(cmd, shell=False):
    # Run a command in the repo root; return its stripped stdout, or None on failure.
    try:
        return subprocess.check_output(cmd, cwd=ROOT_DIR, stderr=subprocess.DEVNULL, shell=shell).decode("ascii").strip()
    except Exception:
        return None


def _get_version():
    if os.path.exists(ROOT_DIR / "version.txt"):
        with open(ROOT_DIR / "version.txt", "r") as f:
            version = f.read().strip()
    else:
        version = '0.1'
    if os.getenv("BUILD_VERSION"):
        version = os.getenv("BUILD_VERSION")
    return version


def _make_version_file(version, sha, abi, dtk, torch_version, branch):
    # Write apex/version.py so the installed package can report its DCU build info.
    sha = "Unknown" if sha is None else sha
    torch_version = '.'.join(torch_version.split('.')[:2])
    dcu_version = f"{version}+{sha}.abi{abi}.dtk{dtk}.torch{torch_version}"
    version_path = ROOT_DIR / "apex" / "version.py"
    with open(version_path, "w") as f:
        f.write(f"version = '{version}'\n")
        f.write(f"git_hash = '{sha}'\n")
        f.write(f"git_branch = '{branch}'\n")
        f.write(f"abi = 'abi{abi}'\n")
        f.write(f"dtk = '{dtk}'\n")
        f.write(f"torch_version = '{torch_version}'\n")
        f.write(f"dcu_version = '{dcu_version}'\n")
    return dcu_version


def _get_pytorch_version():
    if "PYTORCH_VERSION" in os.environ:
        return f"{os.environ['PYTORCH_VERSION']}"
    return torch.__version__


def get_version(ROCM_HOME):
    sha = _run_cmd(["git", "rev-parse", "HEAD"])
    sha = sha[:7] if sha else sha  # keep None so _make_version_file falls back to "Unknown"
    branch = _run_cmd(["git", "rev-parse", "--abbrev-ref", "HEAD"])
    tag = _run_cmd(["git", "describe", "--tags", "--exact-match", "@"])
    print("-- Git branch:", branch)
    print("-- Git SHA:", sha)
    print("-- Git tag:", tag)
    torch_version = _get_pytorch_version()
    print("-- PyTorch:", torch_version)
    version = _get_version()
    print("-- Building version", version)
    abi = _run_cmd(["echo '#include <string>' | gcc -x c++ -E -dM - | fgrep _GLIBCXX_USE_CXX11_ABI | awk '{print $3}'"], shell=True)
    print("-- _GLIBCXX_USE_CXX11_ABI:", abi)
    dtk = _run_cmd(["cat", os.path.join(ROCM_HOME, '.info/rocm_version')])
    dtk = ''.join(dtk.split('.')[:2])
    print("-- DTK:", dtk)
    return _make_version_file(version, sha, abi, dtk, torch_version, branch)
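As an illustration only (hypothetical values, not taken from a real build), the string assembled by `_make_version_file` has this shape:
```python
# version       = "0.1"        (version.txt or BUILD_VERSION)
# sha           = "2d8b360"    (short git hash)
# abi           = "0"          (_GLIBCXX_USE_CXX11_ABI probe via gcc)
# dtk           = "2304"       (first two fields of ROCM_HOME/.info/rocm_version)
# torch_version = "1.13"
#
# dcu_version == "0.1+2d8b360.abi0.dtk2304.torch1.13"
```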
@@ -7,6 +7,10 @@ import sys
import warnings
import os

from get_version import get_version

dcu_version = get_version(ROCM_HOME)

# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))
torch_dir = torch.__path__[0]
@@ -671,7 +675,7 @@ if "--cuda_ext" in sys.argv:
setup(
    name="apex",
-   version="0.1",
+   version=dcu_version,
    packages=find_packages(
        exclude=("build", "csrc", "include", "tests", "dist", "docs", "tests", "examples", "apex.egg-info",)
    ),
...