Commit f49ddd4c authored by flyingdown

Merge branch 'develop' into 'dtk-23.04'

1. Updated the README

See merge request dcutoolkit/deeplearing/apex!2
parents 2c6c0f28 f8b650c8
# APEX
## Introduction
[Introduction](README_ORIGIN.md)
## Installation
### System Requirements
- Linux.
- Python 3.7, 3.8, 3.9
- (**Recommended**) Upgrade pip
```
python3 -m pip install --upgrade pip #--user
```
### Install with pip (dtk-23.04 as an example)
The latest apex release wheel is available in the AI ecosystem packages of the [光合开发者社区](https://developer.hpccube.com/tool/#sdk) (the wheel must match your DCU Toolkit and Python versions).
```bash
python3 -m pip install apex-0.1+git2d8b360.abi0.dtk2304-cp37-cp37m-linux_x86_64.whl
```
### Install from source
#### Prepare the build environment (dtk-23.04 as an example)
- Clone the apex source
```
git clone -b dtk-23.04 http://developer.hpccube.com/codes/aicomponent/apex.git
```
- Download DTK-23.04 from the DCU Toolkit section of the [开发者社区](https://developer.hpccube.com/tool/#sdk), extract it under /opt/, and create a symlink
```
cd /opt && ln -s dtk-23.04 dtk
```
- Obtain the matching pytorch release wheel from the AI ecosystem packages of the [光合开发者社区](https://developer.hpccube.com/tool/#sdk) (it must match your DCU Toolkit and Python versions)
```bash
python3 -m pip install torch-1.13.1a0+git4c8a1fe.abi0.dtk2304-cp37-cp37m-linux_x86_64.whl
```
- Set up the environment variables and install the required dependencies
```bash
source /opt/dtk/env.sh
export PYTORCH_ROCM_ARCH="gfx906;gfx926"
MAX_JOBS=16
sha=`git rev-parse HEAD`
sed -i "/version=/{s/\(.*=\)['\"]\(.*\)['\"]/\1'\2\+git${sha:0:7}\.abi0.dtk23.04'/}" setup.py
pip3 install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple --trusted-host pypi.tuna.tsinghua.edu.cn
pip3 install wheel -i https://pypi.tuna.tsinghua.edu.cn/simple --trusted-host pypi.tuna.tsinghua.edu.cn
```
#### Build and install
- Run the build commands
```shell
cd apex
CXX=hipcc CC=hipcc python3 setup.py --cpp_ext --cuda_ext bdist_wheel
pip install dist/apex*
```
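After the wheel is installed, a quick sanity check can confirm the build. This is a minimal sketch; `apex.__version__` is populated by the `apex/version.py` module added in this branch and may be absent in other builds, so it is read defensively:

```python
# Run inside the same Python environment the wheel was installed into.
import torch
import apex
from apex import amp  # noqa: F401  # confirms the core Python package imports cleanly

print("torch:", torch.__version__)
print("apex:", getattr(apex, "__version__", "unknown"))  # set by apex/version.py in this build
print("DCU/ROCm device visible:", torch.cuda.is_available())
```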
# Introduction
This repository holds NVIDIA-maintained utilities to streamline mixed precision and distributed training in Pytorch.
Some of the code here will be included in upstream Pytorch eventually.
The intent of Apex is to make up-to-date utilities available to users as quickly as possible.
## Full API Documentation: [https://nvidia.github.io/apex](https://nvidia.github.io/apex)
## [GTC 2019](https://github.com/mcarilli/mixed_precision_references/tree/master/GTC_2019) and [Pytorch DevCon 2019](https://github.com/mcarilli/mixed_precision_references/tree/master/Pytorch_Devcon_2019) Slides
# Contents
## 1. Amp: Automatic Mixed Precision
`apex.amp` is a tool to enable mixed precision training by changing only 3 lines of your script.
Users can easily experiment with different pure and mixed precision training modes by supplying
different flags to `amp.initialize`.
[Webinar introducing Amp](https://info.nvidia.com/webinar-mixed-precision-with-pytorch-reg-page.html)
(The flag `cast_batchnorm` has been renamed to `keep_batchnorm_fp32`).
[API Documentation](https://nvidia.github.io/apex/amp.html)
[Comprehensive Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
[DCGAN example coming soon...](https://github.com/NVIDIA/apex/tree/master/examples/dcgan)
[Moving to the new Amp API](https://nvidia.github.io/apex/amp.html#transition-guide-for-old-api-users) (for users of the deprecated "Amp" and "FP16_Optimizer" APIs)
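As a minimal sketch of the three Amp-specific lines described above (the toy model, data, and choice of `opt_level` are illustrative only):

```python
import torch
from apex import amp

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Line 1: let Amp patch the model/optimizer for the chosen precision mode.
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

data, target = torch.randn(8, 16).cuda(), torch.randn(8, 4).cuda()
for _ in range(10):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(data), target)
    # Lines 2-3: scale the loss and backprop through the scaled value.
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()
```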
## 2. Distributed Training
`apex.parallel.DistributedDataParallel` is a module wrapper, similar to
`torch.nn.parallel.DistributedDataParallel`. It enables convenient multiprocess distributed training,
optimized for NVIDIA's NCCL communication library.
[API Documentation](https://nvidia.github.io/apex/parallel.html)
[Python Source](https://github.com/NVIDIA/apex/tree/master/apex/parallel)
[Example/Walkthrough](https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed)
The [Imagenet example](https://github.com/NVIDIA/apex/tree/master/examples/imagenet)
shows use of `apex.parallel.DistributedDataParallel` along with `apex.amp`.
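A minimal multiprocess sketch (one process per GPU), assuming launch via `torchrun` or an equivalent launcher that sets `LOCAL_RANK` and the rendezvous environment variables:

```python
import os
import torch
import torch.distributed as dist
from apex.parallel import DistributedDataParallel as DDP

local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl", init_method="env://")

model = torch.nn.Linear(16, 4).cuda()
model = DDP(model)  # gradients are allreduced across processes during backward()

# ... build the optimizer (optionally call amp.initialize) and run the usual training loop
```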
### Synchronized Batch Normalization
`apex.parallel.SyncBatchNorm` extends `torch.nn.modules.batchnorm._BatchNorm` to
support synchronized BN.
It allreduces stats across processes during multiprocess (DistributedDataParallel) training.
Synchronous BN has been used in cases where only a small
local minibatch can fit on each GPU.
Allreduced stats increase the effective batch size for the BN layer to the
global batch size across all processes (which, technically, is the correct
formulation).
Synchronous BN has been observed to improve converged accuracy in some of our research models.
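A minimal sketch of converting an existing model; the statistics are only actually synchronized once the converted model runs under multiprocess `DistributedDataParallel`:

```python
import torch
from apex.parallel import convert_syncbn_model

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, kernel_size=3),
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
).cuda()

# Replaces every torch.nn BatchNorm layer in the module tree with apex.parallel.SyncBatchNorm.
model = convert_syncbn_model(model)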
### Checkpointing
To properly save and load your `amp` training, we introduce the `amp.state_dict()`, which contains all `loss_scalers` and their corresponding unskipped steps,
as well as `amp.load_state_dict()` to restore these attributes.
In order to get bitwise accuracy, we recommend the following workflow:
```python
# Initialization
opt_level = 'O1'
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
# Train your model
...
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
...
# Save checkpoint
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...
# Restore
model = ...
optimizer = ...
checkpoint = torch.load('amp_checkpoint.pt')
model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
amp.load_state_dict(checkpoint['amp'])
# Continue training
...
```
Note that we recommend restoring the model using the same `opt_level`. Also note that we recommend calling the `load_state_dict` methods after `amp.initialize`.
# Installation
## Containers
NVIDIA PyTorch Containers are available on NGC: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch.
The containers come with all the custom extensions available at the moment.
See [the NGC documentation](https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/index.html) for details such as:
- how to pull a container
- how to run a pulled container
- release notes
## From Source
To install Apex from source, we recommend using the nightly Pytorch obtainable from https://github.com/pytorch/pytorch.
The latest stable release obtainable from https://pytorch.org should also work.
### ROCm
Apex on ROCm supports both a Python-only build and an extension build.
Note: Pytorch >= 1.5 is recommended for the extension build.
### To install using a Python-only build, run the following command in the apex folder:
```
python setup.py install
```
### To install with extensions enabled, run the following command in the apex folder:
```
python setup.py install --cpp_ext --cuda_ext
```
Note that installing Apex with the `--cuda_ext` flag also enables all of the extensions supported on ROCm, including `--distributed_adam`, `--distributed_lamb`, `--bnp`, `--xentropy`, `--deprecated_fused_adam`, `--deprecated_fused_lamb`, and `--fast_multihead_attn`.
### To install Apex on ROCm using ninja and without cloning the source
```
pip install ninja
pip install -v --install-option="--cpp_ext" --install-option="--cuda_ext" 'git+https://github.com/ROCmSoftwarePlatform/apex.git'
```
### Linux
For performance and full functionality, we recommend installing Apex with
CUDA and C++ extensions via
```bash
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```
Apex also supports a Python-only build via
```bash
pip install -v --disable-pip-version-check --no-cache-dir ./
```
A Python-only build omits:
- Fused kernels required to use `apex.optimizers.FusedAdam`.
- Fused kernels required to use `apex.normalization.FusedLayerNorm` and `apex.normalization.FusedRMSNorm`.
- Fused kernels that improve the performance and numerical stability of `apex.parallel.SyncBatchNorm`.
- Fused kernels that improve the performance of `apex.parallel.DistributedDataParallel` and `apex.amp`.
`DistributedDataParallel`, `amp`, and `SyncBatchNorm` will still be usable, but they may be slower.
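For reference, a short sketch of the modules on that list; with a Python-only install, the fused variants fail at construction because their extension kernels are missing (toy sizes, for illustration only):

```python
import torch
from apex.normalization import FusedLayerNorm
from apex.optimizers import FusedAdam

layer = FusedLayerNorm(256).cuda()                   # requires the fused layer-norm kernel
optimizer = FusedAdam(layer.parameters(), lr=1e-3)   # requires the fused Adam kernel

out = layer(torch.randn(8, 256).cuda())
out.sum().backward()
optimizer.step()
```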
### [Experimental] Windows
`pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .` may work if you were able to build Pytorch from source
on your system. A Python-only build via `pip install -v --no-cache-dir .` is more likely to work.
If you installed Pytorch in a Conda environment, make sure to install Apex in that same environment.
......@@ -49,3 +49,9 @@ def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int
        )
        return False
    return True
try:
    from .version import version, git_hash, git_branch, dtk, abi, torch_version, dcu_version  # noqa: F401
    __version__, __dcu_version__ = version, dcu_version
except ImportError:
    pass
......@@ -42,12 +42,13 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
const int batch_stride_q = head_dim;
const int batch_stride_kv = 2 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// const float alpha = 1.0;
// const float beta = 0.0;
// const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half alpha = 1.0;
const half beta = 0.0;
const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half h_alpha = 1.0;
const half h_beta = 0.0;
const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
......@@ -93,6 +94,81 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Q Fwd
if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_q_dim,
batches_q,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs_q.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta),
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_q_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_q_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear KV Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_kv_dim,
batches_kv,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta),
k_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_kv_dim,
k_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_kv_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
h_scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
static_cast<const half*>(q_lin_results_ptr),
lead_dim_q,
batch_stride_q,
h_beta,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
} else {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
......@@ -113,7 +189,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_q_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -139,7 +215,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
k_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_kv_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -166,6 +242,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
k_seq_len*q_seq_len,
attn_batches,
flags);
}
// Padded Softmax
bool softmax_success = false;
......@@ -199,6 +276,55 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
}
// Matmul2
if (use_fp16) {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
h_alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
batches_q,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
} else {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
......@@ -242,10 +368,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
}
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_q_results,
......@@ -281,12 +409,12 @@ std::vector<torch::Tensor> bwd_cuda(
const int batch_stride_q = head_dim;
const int batch_stride_kv = 2 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// const float alpha = 1.0;
// const float beta = 0.0;
// const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half alpha = 1.0;
const half beta = 0.0;
const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half h_alpha = 1.0;
const half h_beta = 0.0;
const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
// TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code
......@@ -337,6 +465,105 @@ std::vector<torch::Tensor> bwd_cuda(
#endif
#endif
if (use_fp16) {
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches_q,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
embed_dim,
batches_q,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
h_alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
h_beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
h_alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
v_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
v_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
attn_batches,
flags);
} else {
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
......@@ -358,7 +585,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -384,7 +611,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -434,6 +661,7 @@ std::vector<torch::Tensor> bwd_cuda(
batch_stride_kv,
attn_batches,
flags);
}
// Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda<at::Half,float,uint32_t>(
......@@ -452,6 +680,157 @@ std::vector<torch::Tensor> bwd_cuda(
k_seq_len, attn_batches * q_seq_len);
assert(softmax_success);
if (use_fp16) {
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
h_scale,
k_lin_results_ptr,
lead_dim_kv,
batch_stride_kv,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
h_scale,
q_lin_results_ptr,
lead_dim_q,
batch_stride_q,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
attn_batches,
flags);
// Input Linear Q Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches_q,
output_lin_q_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_q_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear Q Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_q_dim,
batches_q,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(inputs_q.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_q_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear KV Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches_kv,
output_lin_kv_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_kv_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear KV Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_kv_dim,
batches_kv,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_kv_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
} else {
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
......@@ -519,7 +898,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -545,7 +924,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -571,7 +950,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -597,10 +976,12 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
}
// TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
input_q_grads,
......
......@@ -51,12 +51,13 @@ std::vector<torch::Tensor> fwd_cuda(
const int batch_stride_q = head_dim;
const int batch_stride_kv = 2 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// const float alpha = 1.0;
// const float beta = 0.0;
// const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half alpha = 1.0;
const half beta = 0.0;
const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half h_alpha = 1.0;
const half h_beta = 0.0;
const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
......@@ -118,6 +119,82 @@ std::vector<torch::Tensor> fwd_cuda(
1.0e-5, static_cast<const at::Half *>(lyr_nrm_gamma_weights.data_ptr()),
static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
if (use_fp16) {
// Input Linear Q Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_q_dim,
batches_q,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
//static_cast<const void*>(inputs_q.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&h_beta),
q_lin_results_ptr,
rocblas_datatype_f16_r /*c_type*/,
output_lin_q_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_q_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear KV Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_kv_dim,
batches_kv,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&h_beta),
k_lin_results_ptr,
rocblas_datatype_f16_r /*c_type*/,
output_lin_kv_dim,
k_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_kv_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
h_scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
static_cast<const half*>(q_lin_results_ptr),
lead_dim_q,
batch_stride_q,
h_beta,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
} else {
// Input Linear Q Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
......@@ -140,7 +217,7 @@ std::vector<torch::Tensor> fwd_cuda(
q_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_q_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -166,7 +243,7 @@ std::vector<torch::Tensor> fwd_cuda(
k_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_kv_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -193,6 +270,8 @@ std::vector<torch::Tensor> fwd_cuda(
attn_batches,
flags);
}
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
......@@ -224,6 +303,57 @@ std::vector<torch::Tensor> fwd_cuda(
(1.0f - dropout_prob));
}
if (use_fp16) {
// Matmul2
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
h_alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()),
//static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
batches_q,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
} else {
// Matmul2
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
......@@ -269,11 +399,13 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
}
// End-of-block Dropout-Add
if (is_training) {
apex_dropout_add_cuda<at::Half, float, uint32_t>(
......@@ -333,12 +465,12 @@ std::vector<torch::Tensor> bwd_cuda(
const int batch_stride_q = head_dim;
const int batch_stride_kv = 2 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// const float alpha = 1.0;
// const float beta = 0.0;
// const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half alpha = 1.0;
const half beta = 0.0;
const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half h_alpha = 1.0;
const half h_beta = 0.0;
const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
// TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code
......@@ -401,6 +533,105 @@ std::vector<torch::Tensor> bwd_cuda(
total_tokens_q,
(1.0 / (1.0 - dropout_prob)));
if (use_fp16) {
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches_q,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
embed_dim,
batches_q,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
h_alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim_kv,
batch_stride_kv,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
h_beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
h_alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
v_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
v_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
attn_batches,
flags);
} else {
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
......@@ -422,7 +653,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -448,7 +679,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -499,6 +730,8 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches,
flags);
}
// Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(matmul2_grads.data_ptr()),
......@@ -516,6 +749,158 @@ std::vector<torch::Tensor> bwd_cuda(
k_seq_len, attn_batches * q_seq_len);
assert(softmax_success);
if (use_fp16) {
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
h_scale,
k_lin_results_ptr,
lead_dim_kv,
batch_stride_kv,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
q_lin_grads_ptr,
lead_dim_q,
batch_stride_q,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
h_scale,
q_lin_results_ptr,
lead_dim_q,
batch_stride_q,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
k_lin_grads_ptr,
lead_dim_kv,
batch_stride_kv,
attn_batches,
flags);
// Input Linear Q Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches_q,
output_lin_q_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights_q.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_q_dim,
static_cast<const void*>(&h_beta),
//static_cast<void*>(input_q_grads.data_ptr()),
static_cast<void*>(input_lin_q_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_lin_q_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear Q Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_q_dim,
batches_q,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(inputs_q.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_q_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear KV Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches_kv,
output_lin_kv_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights_kv.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_kv_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear KV Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_kv_dim,
batches_kv,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(inputs_kv.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(k_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_kv_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
} else {
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
......@@ -584,7 +969,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_lin_q_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -610,7 +995,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_weight_q_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -636,7 +1021,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -662,11 +1047,13 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_weight_kv_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
}
// Fused Layer Norm Bwd with Residual Add
HostLayerNormGradient<half,float>(
static_cast<const half*>(input_lin_q_grads.data_ptr()),
......
......@@ -37,14 +37,14 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// const float alpha = 1.0;
// const float beta_zero = 0.0;
// const float beta_one = 1.0;
// const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half alpha = 1.0;
const half beta_zero = 0.0;
const half beta_one = 1.0;
const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
const float alpha = 1.0;
const float beta_zero = 0.0;
const float beta_one = 1.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half h_alpha = 1.0;
const half h_beta_zero = 0.0;
const half h_beta_one = 1.0;
const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
......@@ -90,6 +90,55 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
// Input Linear Fwd
input_lin_results.copy_(input_biases);
if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
batches,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta_one),
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
h_scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(q_lin_results_ptr),
lead_dim,
batch_stride,
h_beta_zero,
static_cast<half*>(bmm1_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(bmm1_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
} else {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
......@@ -110,7 +159,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -138,6 +187,8 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
attn_batches,
flags);
}
// Padded Softmax
bool softmax_success = false;
if (is_training) {
......@@ -162,6 +213,57 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
}
// Matmul2
if (use_fp16) {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
h_alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta_zero,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
outputs.copy_(output_biases);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta_one),
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
} else {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
......@@ -207,10 +309,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
}
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_results, bmm1_results, dropout_results,
......@@ -235,12 +339,12 @@ std::vector<torch::Tensor> bwd_cuda(
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// const float alpha = 1.0;
// const float beta = 0.0;
// const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half alpha = 1.0;
const half beta = 0.0;
const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half h_alpha = 1.0;
const half h_beta = 0.0;
const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
// TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code
......@@ -288,6 +392,224 @@ std::vector<torch::Tensor> bwd_cuda(
#endif
// Output Linear Dgrad
if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
h_alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
h_beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
h_alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
dispatch_masked_scale_softmax_backward_recompute<half, half, float, false>(
static_cast<half*>(matmul2_grads.data_ptr()),
static_cast<half* const>(matmul2_grads.data_ptr()),
reinterpret_cast<half const*>(bmm1_results.data_ptr()),
reinterpret_cast<half const*>(pad_mask.data_ptr()),
static_cast<uint8_t const*>(dropout_mask.data_ptr()),
1.0/(1.0-dropout_prob),
k_seq_len,
k_seq_len,
attn_batches*q_seq_len/sequences,
attn_batches*q_seq_len,
stream);
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
h_scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
h_scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(input_lin_output_grads.data_ptr()),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_grads, input_weight_grads, output_weight_grads,
input_bias_grads, output_bias_grads};
} else {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
......@@ -308,7 +630,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -334,7 +656,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -468,7 +790,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -494,7 +816,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -504,6 +826,8 @@ std::vector<torch::Tensor> bwd_cuda(
return {input_grads, input_weight_grads, output_weight_grads,
input_bias_grads, output_bias_grads};
}
}
} // end namespace rocblas_gemmex
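The `dispatch_masked_scale_softmax_backward_recompute` call above does not read saved softmax outputs; it rebuilds them from the pre-softmax MatMul1 scores (`bmm1_results`) and the padding mask, then applies the dropout mask, the 1/(1-dropout_prob) scale and the softmax Jacobian in a single pass. The sketch below is a plain PyTorch restatement of that math, not the kernel itself; the additive-mask convention and tensor shapes are assumptions made for illustration.

```python
# Hedged sketch of the recompute-style "masked scale + softmax backward" step.
# `logits` stands in for bmm1_results (pre-softmax attention scores), `pad_mask`
# is assumed to be additive (0 for kept keys, a large negative value for padded
# keys), and `dropout_mask` is the 0/1 byte mask saved by the forward pass.
import torch

def masked_scale_softmax_backward_recompute(grad, logits, pad_mask,
                                            dropout_mask, dropout_p):
    # Recompute the forward softmax instead of loading stored probabilities.
    probs = torch.softmax(logits + pad_mask, dim=-1)
    # Dropout backward: zero dropped positions and rescale by 1/(1-p).
    grad = grad * dropout_mask.to(grad.dtype) / (1.0 - dropout_p)
    # Softmax backward: dL/dx = p * (dL/dp - sum(dL/dp * p)) along the key dim.
    return probs * (grad - (grad * probs).sum(dim=-1, keepdim=True))
```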
......
......@@ -36,14 +36,14 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// const float alpha = 1.0;
// const float beta_zero = 0.0;
// const float beta_one = 1.0;
// const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half alpha = 1.0;
const half beta_zero = 0.0;
const half beta_one = 1.0;
const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
const float alpha = 1.0;
const float beta_zero = 0.0;
const float beta_one = 1.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half h_alpha = 1.0;
const half h_beta_zero = 0.0;
const half h_beta_one = 1.0;
const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
......@@ -88,6 +88,55 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
// Input Linear Fwd
input_lin_results.copy_(input_biases);
if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
batches,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta_one),
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
h_scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(q_lin_results_ptr),
lead_dim,
batch_stride,
h_beta_zero,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
} else {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
......@@ -108,7 +157,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
q_lin_results_ptr,
rocblas_datatype_f16_r,
output_lin_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -136,6 +185,8 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
attn_batches,
flags);
}
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
......@@ -168,6 +219,57 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
}
// Matmul2
if (use_fp16) {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
h_alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
k_seq_len,
k_seq_len*q_seq_len,
h_beta_zero,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
outputs.copy_(output_biases);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta_one),
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
} else {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
......@@ -213,10 +315,12 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
static_cast<void*>(outputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
}
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_results, softmax_results, dropout_results,
......@@ -241,12 +345,12 @@ std::vector<torch::Tensor> bwd_cuda(
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// const float alpha = 1.0;
// const float beta = 0.0;
// const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half alpha = 1.0;
const half beta = 0.0;
const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half h_alpha = 1.0;
const half h_beta = 0.0;
const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
// TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code
......@@ -294,6 +398,218 @@ std::vector<torch::Tensor> bwd_cuda(
#endif
// Output Linear Dgrad
if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
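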
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
h_alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
h_beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
h_alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
dispatch_masked_scale_softmax_backward_stream<half, half, float, false>(
static_cast<half *>(matmul2_grads.data_ptr()),
static_cast<half *>(matmul2_grads.data_ptr()),
reinterpret_cast<half const *>(softmax_results.data_ptr()),
static_cast<uint8_t const *>(dropout_mask.data_ptr()),
1.0 / (1.0 - dropout_prob), k_seq_len, k_seq_len,
attn_batches * q_seq_len, stream);
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
h_scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
h_scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(input_lin_output_grads.data_ptr()),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_grads, input_weight_grads, output_weight_grads,
input_bias_grads, output_bias_grads};
} else {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
......@@ -314,7 +630,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -340,7 +656,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -468,7 +784,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -494,7 +810,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -504,6 +820,8 @@ std::vector<torch::Tensor> bwd_cuda(
return {input_grads, input_weight_grads, output_weight_grads,
input_bias_grads, output_bias_grads};
}
}
} // end namespace rocblas_gemmex
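For orientation, the GEMM chain in the forward path above (packed input linear, MatMul1 scaled by 1/sqrt(head_dim), softmax, dropout, MatMul2, output linear) is ordinary multi-head self-attention. The PyTorch restatement below is illustrative only: the `[seq_len, batch, embed_dim]` layout and the per-head QKV packing are assumptions, not the kernel's exact memory layout.

```python
# Hedged reference for the attention math implemented by the rocblas_gemm_ex /
# gemm_switch_fp32accum calls above; shapes and packing are assumptions.
import math
import torch
import torch.nn.functional as F

def self_attn_reference(inputs, input_weights, input_biases,
                        output_weights, output_biases, heads, dropout_p=0.0):
    seq_len, bsz, embed_dim = inputs.shape
    head_dim = embed_dim // heads

    def split_heads(t):
        return t.reshape(seq_len, bsz * heads, head_dim).transpose(0, 1)

    # Input Linear Fwd: a single GEMM producing packed Q, K, V.
    qkv = F.linear(inputs, input_weights, input_biases)
    q, k, v = (split_heads(t) for t in qkv.chunk(3, dim=-1))
    # MatMul1 of Dot-Product Attention plus scaling by 1/sqrt(head size).
    scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(head_dim)
    probs = F.dropout(torch.softmax(scores, dim=-1), p=dropout_p)
    # Matmul2, then the Output Linear.
    ctx = torch.bmm(probs, v).transpose(0, 1).reshape(seq_len, bsz, embed_dim)
    return F.linear(ctx, output_weights, output_biases)
```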
......
......@@ -236,12 +236,12 @@ std::vector<torch::Tensor> bwd_cuda(
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// const float alpha = 1.0;
// const float beta = 0.0;
// const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half alpha = 1.0;
const half beta = 0.0;
const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half h_alpha = 1.0;
const half h_beta = 0.0;
const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
// TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code
......@@ -289,6 +289,104 @@ std::vector<torch::Tensor> bwd_cuda(
#endif
// Output Linear Dgrad
if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
h_alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
h_beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
h_alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
} else {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
......@@ -309,7 +407,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -335,7 +433,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -386,6 +484,8 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches,
flags);
}
// Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(matmul2_grads.data_ptr()),
......@@ -404,6 +504,104 @@ std::vector<torch::Tensor> bwd_cuda(
assert(softmax_success);
// Matmul1 Dgrad1
if (use_fp16) {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
h_scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
h_scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
} else {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
......@@ -470,7 +668,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -496,10 +694,12 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
}
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
......
......@@ -40,12 +40,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// const float alpha = 1.0;
// const float beta = 0.0;
// const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half alpha = 1.0;
const half beta = 0.0;
const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half h_alpha = 1.0;
const half h_beta = 0.0;
const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
// There is no reason to use more than one stream as every kernel is
// sequentially dependent
......@@ -106,6 +106,56 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
static_cast<const at::Half *>(lyr_nrm_beta_weights.data_ptr()));
// Input Linear Fwd
if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
batches,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
//static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&h_beta),
q_lin_results_ptr,
rocblas_datatype_f16_r /*c_type*/,
output_lin_dim,
q_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
h_scale,
static_cast<const half*>(k_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(q_lin_results_ptr),
lead_dim,
batch_stride,
h_beta,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(softmax_results_ptr),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
} else {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
......@@ -127,7 +177,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
q_lin_results_ptr,
rocblas_datatype_f16_r /*d_type*/,
output_lin_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -155,6 +205,8 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
attn_batches,
flags);
}
// Padded Softmax
bool softmax_success = false;
if (pad_mask == nullptr) {
......@@ -187,6 +239,56 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
}
// Matmul2
if (use_fp16) {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
h_alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
(is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
//static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<half*>(matmul2_results.data_ptr()),
head_dim*attn_batches,
head_dim,
attn_batches,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
} else {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
......@@ -231,11 +333,13 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
static_cast<void*>(output_lin_results.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
}
// End-of-block Dropout-Add
if (is_training) {
......@@ -283,12 +387,12 @@ std::vector<torch::Tensor> bwd_cuda(
const int lead_dim = attn_batches * 3 * head_dim;
const int batch_stride = 3 * head_dim;
const int dropout_elems = attn_batches * q_seq_len * k_seq_len;
// const float alpha = 1.0;
// const float beta = 0.0;
// const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half alpha = 1.0;
const half beta = 0.0;
const half scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
const float alpha = 1.0;
const float beta = 0.0;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
const half h_alpha = 1.0;
const half h_beta = 0.0;
const half h_scale = __float2half(1.0 / sqrt(static_cast<float>(head_dim)));
// TODO: Streams can be used in Backprop but I haven't added more than one
// in my first attempt to create the code
......@@ -347,6 +451,104 @@ std::vector<torch::Tensor> bwd_cuda(
(1.0 / (1.0 - dropout_prob)));
// Output Linear Dgrad
if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(dropout_add_grads.data_ptr()),
rocblas_datatype_f16_r /*b_type*/,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
h_alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
h_beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
h_alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
} else {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
......@@ -367,7 +569,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -393,7 +595,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -444,6 +646,8 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches,
flags);
}
// Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(matmul2_grads.data_ptr()),
......@@ -462,6 +666,106 @@ std::vector<torch::Tensor> bwd_cuda(
assert(softmax_success);
// Matmul1 Dgrad1
if (use_fp16) {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
h_scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
h_scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_dim,
static_cast<const void*>(&h_beta),
//static_cast<void*>(input_grads.data_ptr()),
static_cast<void*>(input_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&h_alpha),
//static_cast<const void*>(inputs.data_ptr()),
static_cast<const void*>(lyr_nrm_results.data_ptr()),
rocblas_datatype_f16_r /*a_type*/,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r /*b_type*/,
output_lin_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*c_type*/,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
} else {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
......@@ -529,7 +833,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_lin_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
......@@ -556,11 +860,13 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r /*d_type*/,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
}
// Fused Layer Norm Bwd with Residual Add
HostLayerNormGradient<half, float>(
static_cast<const half *>(input_lin_grads.data_ptr()),
......
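`HostLayerNormGradient` above produces the gradients of the pre-attention layer norm with respect to its input and its gamma/beta weights, with the residual-add gradient folded into the same call (per the "Fused Layer Norm Bwd with Residual Add" comment). The autograd-based sketch below only restates plain layer-norm backward for orientation; the fused residual handling is not reproduced here.

```python
# Hedged sketch: the input/gamma/beta gradients of a layer norm obtained via
# autograd; shapes are illustrative placeholders.
import torch

x = torch.randn(4, 8, 64, requires_grad=True)        # [seq, batch, embed]
gamma = torch.ones(64, requires_grad=True)
beta = torch.zeros(64, requires_grad=True)

y = torch.nn.functional.layer_norm(x, (64,), gamma, beta)
upstream = torch.randn_like(y)                        # e.g. the input-linear dgrad
y.backward(upstream)
print(x.grad.shape, gamma.grad.shape, beta.grad.shape)
```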
......@@ -10,6 +10,7 @@
//#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include "utils.h"
//#include "cutlass/cutlass.h"
//#include "cutlass/gemm/gemm.h"
......@@ -28,6 +29,8 @@ int32_t solution_index = 0;
rocblas_int flags = 0;
*/
static bool use_fp16 = parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF");
namespace {
cublasOperation_t convertTransToCublasOperation(char trans) {
if (trans == 't')
......@@ -42,44 +45,44 @@ cublasOperation_t convertTransToCublasOperation(char trans) {
}
}
// void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
// float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
// float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
// cublasOperation_t opa = convertTransToCublasOperation(transa);
// cublasOperation_t opb = convertTransToCublasOperation(transb);
// cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
// cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
// cublasSetStream(handle, stream);
// float fAlpha = alpha;
// float fBeta = beta;
// //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle,
// opa, opb, (int)m, (int)n, (int)k,
// (void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
// b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
// (void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
// d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
// (int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
// }
// void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
// float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
// float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_int flags) {
// auto stream = c10::cuda::getCurrentCUDAStream();
// if ( (transa == 't') && (transb == 'n') ) {
// if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
// else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
// } else if ( (transa == 'n') && (transb == 'n') ) {
// if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
// else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
// } else if ( (transa == 'n') && (transb == 't') ) {
// if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
// else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
// } else {
// AT_ASSERTM(false, "TransA and TransB are invalid");
// }
// }
void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
cublasOperation_t opa = convertTransToCublasOperation(transa);
cublasOperation_t opb = convertTransToCublasOperation(transb);
cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
cublasSetStream(handle, stream);
float fAlpha = alpha;
float fBeta = beta;
//THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle,
opa, opb, (int)m, (int)n, (int)k,
(void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
(void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
(int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
}
void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_int flags) {
auto stream = c10::cuda::getCurrentCUDAStream();
if ( (transa == 't') && (transb == 'n') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
} else if ( (transa == 'n') && (transb == 'n') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
} else if ( (transa == 'n') && (transb == 't') ) {
if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
else { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
} else {
AT_ASSERTM(false, "TransA and TransB are invalid");
}
}
void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
half alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
......
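The float-alpha `gemm_switch_fp32accum` above keeps `rocblas_datatype_f32_r` as the compute type even though A, B, C and D are all fp16, while the half-alpha `RocblasStridedBatchedGemm` overload whose signature starts just above (body elided) is the pure-fp16 path selected by `APEX_ROCBLAS_GEMM_ALLOW_HALF`. The toy PyTorch sketch below only illustrates the numerical trade-off between the two accumulation choices; it does not call rocBLAS.

```python
# Hedged sketch: fp16 operands with fp32 accumulation vs. fp16 accumulation.
import torch

torch.manual_seed(0)
a = torch.randn(64, 512).half()
b = torch.randn(512, 64).half()

# fp32 accumulation: upcast once, accumulate partial products in float32.
ref = a.float() @ b.float()

# Emulated fp16 accumulation: the running sum itself stays in float16,
# so rounding error grows with the reduction length (here k = 512).
acc = torch.zeros(a.shape[0], b.shape[1], dtype=torch.float16)
for k in range(a.shape[1]):
    acc += a[:, k].unsqueeze(1) * b[k, :].unsqueeze(0)

print("max |error| of fp16 accumulation:", (acc.float() - ref).abs().max().item())
```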
......@@ -10,6 +10,8 @@
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include "utils.h"
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
// includes cublaslt
#include <cublasLt.h>
......@@ -164,8 +166,9 @@ cublasStatus_t gemm_bias(
at::Half* C,
int ldc) {
#ifdef __HIP_PLATFORM_HCC__
half hAlpha = __float2half(*alpha);
half hBeta = __float2half(*beta);
if (parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF")) {
half h_alpha = __float2half(*alpha);
half h_beta = __float2half(*beta);
return rocblas_gemm_ex(
handle,
transa,
......@@ -173,14 +176,14 @@ cublasStatus_t gemm_bias(
m,
n,
k,
/* alpha */ &hAlpha,
/* alpha */ &h_alpha,
A,
rocblas_datatype_f16_r,
lda,
B,
rocblas_datatype_f16_r,
ldb,
/* beta */ &hBeta,
/* beta */ &h_beta,
C,
rocblas_datatype_f16_r,
ldc,
......@@ -191,6 +194,33 @@ cublasStatus_t gemm_bias(
rocblas_gemm_algo_standard,
0,
0);
} else {
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f16_r,
lda,
B,
rocblas_datatype_f16_r,
ldb,
beta,
C,
rocblas_datatype_f16_r,
ldc,
C,
rocblas_datatype_f16_r,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
0);
}
#else
return cublasGemmEx(
handle,
......
......@@ -11,6 +11,7 @@
/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include "utils.h"
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
// includes cublaslt
......@@ -23,6 +24,7 @@
#define BIAS_RELU_RED_PER_THREAD 16 // backward minimal reduction length per thread
// move to a header later on
#define ILP 4
template<typename T>
......@@ -211,8 +213,9 @@ cublasStatus_t mlp_gemm(
int ldc,
int flag) {
#ifdef __HIP_PLATFORM_HCC__
half hAlpha = __float2half(*alpha);
half hBeta = __float2half(*beta);
if (parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF")) {
half h_alpha = __float2half(*alpha);
half h_beta = __float2half(*beta);
return rocblas_gemm_ex(
handle,
transa,
......@@ -220,14 +223,14 @@ cublasStatus_t mlp_gemm(
m,
n,
k,
/* alpha */ &hAlpha,
/* alpha */ &h_alpha,
A,
rocblas_datatype_f16_r,
lda,
B,
rocblas_datatype_f16_r,
ldb,
/* beta */ &hBeta,
/* beta */ &h_beta,
C,
rocblas_datatype_f16_r,
ldc,
......@@ -238,6 +241,33 @@ cublasStatus_t mlp_gemm(
rocblas_gemm_algo_standard,
0,
flag);
} else {
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f16_r,
lda,
B,
rocblas_datatype_f16_r,
ldb,
beta,
C,
rocblas_datatype_f16_r,
ldc,
C,
rocblas_datatype_f16_r,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
flag);
}
#else
return cublasGemmEx(
handle,
......
#pragma once
#include <cstdlib>
#include <torch/extension.h>
inline bool parseEnvVarFlag(const char* envVarName) {
char* stringValue = std::getenv(envVarName);
if (stringValue != nullptr) {
int val;
try {
val = std::stoi(stringValue);
} catch (std::exception& e) {
TORCH_CHECK(false,
"Invalid value for environment variable: " + std::string(envVarName));
}
if (val == 1) {
return true;
} else if (val == 0) {
return false;
} else {
TORCH_CHECK(false,
"Invalid value for environment variable: " + std::string(envVarName));
}
}
return false;
}
\ No newline at end of file
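Both `gemm_bias` and `mlp_gemm` above gate their fp16 compute type on this helper, and the attention sources read it once into a `static bool` at library load, so the variable has to be set before the extension is imported. Below is a close Python equivalent of `parseEnvVarFlag` (the C++ version goes through `std::stoi`, so e.g. "01" also counts as 1):

```python
# Hedged Python-side equivalent of parseEnvVarFlag: only 0 and 1 are accepted,
# anything else is an error, and an unset variable means "disabled".
import os

def parse_env_var_flag(name: str) -> bool:
    value = os.getenv(name)
    if value is None:
        return False
    try:
        value = int(value)
    except ValueError:
        raise ValueError(f"Invalid value for environment variable: {name}")
    if value in (0, 1):
        return bool(value)
    raise ValueError(f"Invalid value for environment variable: {name}")

# Opting in to the half-precision GEMM paths added in this commit
# (must be set before importing the built extension):
os.environ["APEX_ROCBLAS_GEMM_ALLOW_HALF"] = "1"
assert parse_env_var_flag("APEX_ROCBLAS_GEMM_ALLOW_HALF")
```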
import os
import subprocess
from pathlib import Path

import torch

ROOT_DIR = Path(__file__).parent.resolve()


def _run_cmd(cmd, shell=False):
    try:
        return subprocess.check_output(cmd, cwd=ROOT_DIR, stderr=subprocess.DEVNULL, shell=shell).decode("ascii").strip()
    except Exception:
        return None


def _get_version():
    if os.path.exists(ROOT_DIR / "version.txt"):
        with open(ROOT_DIR / "version.txt", "r") as f:
            version = f.read().strip()
    else:
        version = '0.1'
    if os.getenv("BUILD_VERSION"):
        version = os.getenv("BUILD_VERSION")
    return version


def _make_version_file(version, sha, abi, dtk, torch_version, branch):
    sha = "Unknown" if sha is None else sha
    torch_version = '.'.join(torch_version.split('.')[:2])
    dcu_version = f"{version}+{sha}.abi{abi}.dtk{dtk}.torch{torch_version}"
    version_path = ROOT_DIR / "apex" / "version.py"
    with open(version_path, "w") as f:
        f.write(f"version = '{version}'\n")
        f.write(f"git_hash = '{sha}'\n")
        f.write(f"git_branch = '{branch}'\n")
        f.write(f"abi = 'abi{abi}'\n")
        f.write(f"dtk = '{dtk}'\n")
        f.write(f"torch_version = '{torch_version}'\n")
        f.write(f"dcu_version = '{dcu_version}'\n")
    return dcu_version


def _get_pytorch_version():
    if "PYTORCH_VERSION" in os.environ:
        return f"{os.environ['PYTORCH_VERSION']}"
    return torch.__version__


def get_version(ROCM_HOME):
    sha = _run_cmd(["git", "rev-parse", "HEAD"])
    sha = sha[:7] if sha else None
    branch = _run_cmd(["git", "rev-parse", "--abbrev-ref", "HEAD"])
    tag = _run_cmd(["git", "describe", "--tags", "--exact-match", "@"])
    print("-- Git branch:", branch)
    print("-- Git SHA:", sha)
    print("-- Git tag:", tag)
    torch_version = _get_pytorch_version()
    print("-- PyTorch:", torch_version)
    version = _get_version()
    print("-- Building version", version)
    abi = _run_cmd("echo '#include <string>' | gcc -x c++ -E -dM - | fgrep _GLIBCXX_USE_CXX11_ABI | awk '{print $3}'", shell=True)
    print("-- _GLIBCXX_USE_CXX11_ABI:", abi)
    dtk = _run_cmd(["cat", os.path.join(ROCM_HOME, '.info/rocm_version')])
    dtk = ''.join(dtk.split('.')[:2])
    print("-- DTK:", dtk)
    return _make_version_file(version, sha, abi, dtk, torch_version, branch)
\ No newline at end of file
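`get_version.py` stamps the build with a DCU-specific version string of the form `{version}+{sha}.abi{abi}.dtk{dtk}.torch{major.minor}` and writes the same fields to `apex/version.py`. A hedged example of the generated file, with every concrete value a placeholder rather than the output of a real build:

```python
# apex/version.py (illustrative placeholder values only)
version = '0.1'
git_hash = 'abc1234'
git_branch = 'master'
abi = 'abi1'
dtk = '2310'
torch_version = '2.1'
dcu_version = '0.1+abc1234.abi1.dtk2310.torch2.1'
```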
......@@ -7,6 +7,10 @@ import sys
import warnings
import os
from get_version import get_version
dcu_version = get_version(ROCM_HOME)
# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))
torch_dir = torch.__path__[0]
......@@ -671,7 +675,7 @@ if "--cuda_ext" in sys.argv:
setup(
name="apex",
version="0.1",
version=dcu_version,
packages=find_packages(
exclude=("build", "csrc", "include", "tests", "dist", "docs", "tests", "examples", "apex.egg-info",)
),
......