Commit 03204b84 authored by flyingdown
parents f8b650c8 8fc9b21f
# APEX
## Introduction
[Introduction](README_ORIGIN.md)
## Installation
### System Requirements
- Linux.
- Python 3.7, 3.8, 3.9
- (**Recommended**) Upgrade pip
```
python3 -m pip install --upgrade pip #--user
```
### Installing with pip (using dtk-23.04 as an example)
The latest apex release can be downloaded from the AI ecosystem packages of the Guanghe (光合) [Developer Community](https://developer.hpccube.com/tool/#sdk); make sure the wheel matches your DCU Toolkit and Python versions.
```bash
python3 -m pip install apex-0.1+git2d8b360.abi0.dtk2304-cp37-cp37m-linux_x86_64.whl
```
### Installing from source
#### Preparing the build environment (using dtk-23.04 as an example)
- Clone the apex repository
```
git clone -b dtk-23.04 http://developer.hpccube.com/codes/aicomponent/apex.git
```
- Download DTK-23.04 from the DCU Toolkit section of the [Developer Community](https://developer.hpccube.com/tool/#sdk), extract it under /opt/, and create a symlink
```
cd /opt && ln -s dtk-23.04 dtk
```
- Download the matching pytorch release from the AI ecosystem packages of the Guanghe (光合) [Developer Community](https://developer.hpccube.com/tool/#sdk); make sure the wheel matches your DCU Toolkit and Python versions
```bash
python3 -m pip install torch-1.13.1a0+git4c8a1fe.abi0.dtk2304-cp37-cp37m-linux_x86_64.whl
```
- Set up the environment variables and install the required dependencies
```bash
source /opt/dtk/env.sh
export PYTORCH_ROCM_ARCH="gfx906;gfx926"
export MAX_JOBS=16
pip3 install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple --trusted-host pypi.tuna.tsinghua.edu.cn
pip3 install wheel -i https://pypi.tuna.tsinghua.edu.cn/simple --trusted-host pypi.tuna.tsinghua.edu.cn
```
#### Building and installing
- Run the build commands
```shell
cd apex
CXX=hipcc CC=hipcc python3 setup.py --cpp_ext --cuda_ext bdist_wheel
pip install dist/apex*
```
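Once the wheel is installed, a quick import check helps confirm that the C++/CUDA extensions were actually compiled in. The sketch below is only a minimal smoke test, assuming the wheel was built with `--cpp_ext --cuda_ext` as above; `amp_C` and `fused_layer_norm_cuda` are extension modules normally produced by those flags, so adjust the names if your build differs.
```python
# Minimal post-install smoke test (assumes apex was built with --cpp_ext and --cuda_ext).
import torch
import apex

print("torch:", torch.__version__, "| device available:", torch.cuda.is_available())

# These modules only exist when the C++/CUDA extensions were compiled.
import amp_C                  # fused multi-tensor kernels
import fused_layer_norm_cuda  # fused LayerNorm kernels

# High-level APIs backed by the fused kernels.
from apex.optimizers import FusedAdam
from apex.normalization import FusedLayerNorm

print("apex C++/CUDA extensions imported successfully")
```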
......@@ -124,15 +124,13 @@ python setup.py install
### To install with extensions enabled, use the following command in the apex folder:
```
# if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key...
pip install -v --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
# otherwise
python setup.py install --cpp_ext --cuda_ext
```
Note that using --cuda_ext flag to install Apex will also enable all the extensions supported on ROCm including "--distributed_adam", "--distributed_lamb", "--bnp", "--xentropy", "--deprecated_fused_adam", "--deprecated_fused_lamb", and "--fast_multihead_attn".
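A quick way to check which of these extensions actually ended up in your install is to probe their Python modules. The module names in the sketch below (`amp_C`, `xentropy_cuda`, `fast_multihead_attn`, `bnp`, `fused_adam_cuda`, `fused_lamb_cuda`) are the ones these flags usually build, but they can vary between apex versions, so treat an import failure simply as "extension not built".
```python
# Hedged probe: report which optional apex extension modules are importable.
# The module names are the usual ones built by the corresponding flags; exact
# names may differ across apex versions, so a failure only means "not found here".
import importlib

candidates = {
    "--cuda_ext (amp)":        "amp_C",
    "--xentropy":              "xentropy_cuda",
    "--fast_multihead_attn":   "fast_multihead_attn",
    "--bnp":                   "bnp",
    "--deprecated_fused_adam": "fused_adam_cuda",
    "--deprecated_fused_lamb": "fused_lamb_cuda",
}

for flag, module in candidates.items():
    try:
        importlib.import_module(module)
        status = "OK"
    except ImportError:
        status = "missing"
    print(f"{flag:26s} -> {module:22s} {status}")
```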
### To install Apex on ROCm using ninja and without cloning the source
```
pip install ninja
pip install -v --install-option="--cpp_ext" --install-option="--cuda_ext" 'git+https://github.com/ROCmSoftwarePlatform/apex.git'
```
Note that using --cuda_ext flag to install Apex will also enable all the extensions supported on ROCm including "--distributed_adam", "--distributed_lamb", "--bnp", "--xentropy", "--deprecated_fused_adam", "--deprecated_fused_lamb", and "--fast_multihead_attn".
### Linux
For performance and full functionality, we recommend installing Apex with
......@@ -140,12 +138,15 @@ CUDA and C++ extensions via
```bash
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
# if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key...
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
# otherwise
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" ./
```
Apex also supports a Python-only build via
```bash
pip install -v --disable-pip-version-check --no-cache-dir ./
pip install -v --disable-pip-version-check --no-build-isolation --no-cache-dir ./
```
A Python-only build omits:
- Fused kernels required to use `apex.optimizers.FusedAdam`.
......@@ -158,4 +159,4 @@ A Python-only build omits:
### [Experimental] Windows
`pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .` may work if you were able to build Pytorch from source
on your system. A Python-only build via `pip install -v --no-cache-dir .` is more likely to work.
If you installed Pytorch in a Conda environment, make sure to install Apex in that same environment.
If you installed Pytorch in a Conda environment, make sure to install Apex in that same environment.
\ No newline at end of file
......@@ -95,9 +95,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
// Input Linear Q Fwd
if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_q_dim,
batches_q,
embed_dim,
......@@ -118,12 +118,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Input Linear KV Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_kv_dim,
batches_kv,
embed_dim,
......@@ -144,7 +144,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
......@@ -169,9 +169,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
attn_batches,
flags);
} else {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_q_dim,
batches_q,
embed_dim,
......@@ -192,12 +192,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Input Linear KV Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_kv_dim,
batches_kv,
embed_dim,
......@@ -218,7 +218,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
......@@ -300,9 +300,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_q,
embed_dim,
......@@ -323,7 +323,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
} else {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
......@@ -348,9 +348,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_q,
embed_dim,
......@@ -371,7 +371,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
}
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
......@@ -467,9 +467,9 @@ std::vector<torch::Tensor> bwd_cuda(
if (use_fp16) {
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_q,
embed_dim,
......@@ -490,12 +490,12 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
embed_dim,
batches_q,
......@@ -516,7 +516,7 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
......@@ -565,9 +565,9 @@ std::vector<torch::Tensor> bwd_cuda(
flags);
} else {
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_q,
embed_dim,
......@@ -588,12 +588,12 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
embed_dim,
batches_q,
......@@ -614,7 +614,7 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
......@@ -728,9 +728,9 @@ std::vector<torch::Tensor> bwd_cuda(
flags);
// Input Linear Q Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_q,
output_lin_q_dim,
......@@ -751,12 +751,12 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Input Linear Q Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_q_dim,
batches_q,
......@@ -777,12 +777,12 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Input Linear KV Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_kv,
output_lin_kv_dim,
......@@ -803,12 +803,12 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Input Linear KV Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_kv_dim,
batches_kv,
......@@ -829,7 +829,7 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
} else {
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
......@@ -878,9 +878,9 @@ std::vector<torch::Tensor> bwd_cuda(
flags);
// Input Linear Q Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_q,
output_lin_q_dim,
......@@ -901,12 +901,12 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Input Linear Q Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_q_dim,
batches_q,
......@@ -927,12 +927,12 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Input Linear KV Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_kv,
output_lin_kv_dim,
......@@ -953,12 +953,12 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Input Linear KV Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_kv_dim,
batches_kv,
......@@ -979,7 +979,7 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
}
// TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
......
......@@ -121,9 +121,9 @@ std::vector<torch::Tensor> fwd_cuda(
if (use_fp16) {
// Input Linear Q Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_q_dim,
batches_q,
embed_dim,
......@@ -145,12 +145,12 @@ std::vector<torch::Tensor> fwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Input Linear KV Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_kv_dim,
batches_kv,
embed_dim,
......@@ -171,7 +171,7 @@ std::vector<torch::Tensor> fwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
......@@ -196,9 +196,9 @@ std::vector<torch::Tensor> fwd_cuda(
flags);
} else {
// Input Linear Q Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_q_dim,
batches_q,
embed_dim,
......@@ -220,12 +220,12 @@ std::vector<torch::Tensor> fwd_cuda(
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Input Linear KV Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_kv_dim,
batches_kv,
embed_dim,
......@@ -246,7 +246,7 @@ std::vector<torch::Tensor> fwd_cuda(
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
......@@ -329,9 +329,9 @@ std::vector<torch::Tensor> fwd_cuda(
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_q,
embed_dim,
......@@ -352,7 +352,7 @@ std::vector<torch::Tensor> fwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
} else {
// Matmul2
gemm_switch_fp32accum( a_layout_n,
......@@ -379,9 +379,9 @@ std::vector<torch::Tensor> fwd_cuda(
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_q,
embed_dim,
......@@ -402,7 +402,7 @@ std::vector<torch::Tensor> fwd_cuda(
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
}
......@@ -535,9 +535,9 @@ std::vector<torch::Tensor> bwd_cuda(
if (use_fp16) {
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_q,
embed_dim,
......@@ -558,12 +558,12 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
embed_dim,
batches_q,
......@@ -584,7 +584,7 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
......@@ -633,9 +633,9 @@ std::vector<torch::Tensor> bwd_cuda(
flags);
} else {
// Output Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_q,
embed_dim,
......@@ -656,12 +656,12 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
embed_dim,
batches_q,
......@@ -682,7 +682,7 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
......@@ -797,9 +797,9 @@ std::vector<torch::Tensor> bwd_cuda(
flags);
// Input Linear Q Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_q,
output_lin_q_dim,
......@@ -821,12 +821,12 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Input Linear Q Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_q_dim,
batches_q,
......@@ -847,12 +847,12 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Input Linear KV Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_kv,
output_lin_kv_dim,
......@@ -873,12 +873,12 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Input Linear KV Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_kv_dim,
batches_kv,
......@@ -899,7 +899,7 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
} else {
// Matmul1 Dgrad1
gemm_switch_fp32accum( a_layout_n,
......@@ -948,9 +948,9 @@ std::vector<torch::Tensor> bwd_cuda(
flags);
// Input Linear Q Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_q,
output_lin_q_dim,
......@@ -972,12 +972,12 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Input Linear Q Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_q_dim,
batches_q,
......@@ -998,12 +998,12 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Input Linear KV Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches_kv,
output_lin_kv_dim,
......@@ -1024,12 +1024,12 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Input Linear KV Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_kv_dim,
batches_kv,
......@@ -1050,7 +1050,7 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
}
......@@ -1080,4 +1080,4 @@ std::vector<torch::Tensor> bwd_cuda(
} // end namespace rocblas_gemmex
} // end namespace encdec_norm_add
} // end namespace multihead_attn
\ No newline at end of file
} // end namespace multihead_attn
......@@ -91,9 +91,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
// Input Linear Fwd
input_lin_results.copy_(input_biases);
if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_dim,
batches,
embed_dim,
......@@ -114,7 +114,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
......@@ -139,9 +139,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
attn_batches,
flags);
} else {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_dim,
batches,
embed_dim,
......@@ -162,7 +162,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
......@@ -239,9 +239,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
outputs.copy_(output_biases);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
......@@ -262,7 +262,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
} else {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
......@@ -289,9 +289,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
outputs.copy_(output_biases);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
......@@ -312,7 +312,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
}
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
......@@ -393,9 +393,9 @@ std::vector<torch::Tensor> bwd_cuda(
// Output Linear Dgrad
if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
......@@ -416,12 +416,12 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
embed_dim,
batches,
......@@ -442,7 +442,7 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
// MatMul2 Dgrad1
......@@ -553,9 +553,9 @@ std::vector<torch::Tensor> bwd_cuda(
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
output_lin_dim,
......@@ -576,12 +576,12 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_dim,
batches,
......@@ -602,7 +602,7 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
......@@ -610,9 +610,9 @@ std::vector<torch::Tensor> bwd_cuda(
return {input_grads, input_weight_grads, output_weight_grads,
input_bias_grads, output_bias_grads};
} else {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
......@@ -633,12 +633,12 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
embed_dim,
batches,
......@@ -659,7 +659,7 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
// MatMul2 Dgrad1
......@@ -770,9 +770,9 @@ std::vector<torch::Tensor> bwd_cuda(
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
output_lin_dim,
......@@ -793,12 +793,12 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_dim,
batches,
......@@ -819,7 +819,7 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
......
......@@ -89,9 +89,9 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
// Input Linear Fwd
input_lin_results.copy_(input_biases);
if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_dim,
batches,
embed_dim,
......@@ -112,7 +112,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
......@@ -137,9 +137,9 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
attn_batches,
flags);
} else {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_dim,
batches,
embed_dim,
......@@ -160,7 +160,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
......@@ -245,9 +245,9 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
outputs.copy_(output_biases);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
......@@ -268,7 +268,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
} else {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
......@@ -295,9 +295,9 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
outputs.copy_(output_biases);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
......@@ -318,7 +318,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
}
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
......@@ -399,9 +399,9 @@ std::vector<torch::Tensor> bwd_cuda(
// Output Linear Dgrad
if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
......@@ -422,12 +422,12 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
embed_dim,
batches,
......@@ -448,7 +448,7 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
// MatMul2 Dgrad1
......@@ -553,9 +553,9 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
output_lin_dim,
......@@ -576,12 +576,12 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_dim,
batches,
......@@ -602,7 +602,7 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
......@@ -610,9 +610,9 @@ std::vector<torch::Tensor> bwd_cuda(
return {input_grads, input_weight_grads, output_weight_grads,
input_bias_grads, output_bias_grads};
} else {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
......@@ -633,12 +633,12 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
embed_dim,
batches,
......@@ -659,7 +659,7 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
// MatMul2 Dgrad1
......@@ -764,9 +764,9 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
output_lin_dim,
......@@ -787,12 +787,12 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_dim,
batches,
......@@ -813,7 +813,7 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
......@@ -826,4 +826,4 @@ std::vector<torch::Tensor> bwd_cuda(
} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn
\ No newline at end of file
} // end namespace multihead_attn
......@@ -85,9 +85,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_dim,
batches,
embed_dim,
......@@ -108,7 +108,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
......@@ -188,9 +188,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
......@@ -211,7 +211,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_results, softmax_results, dropout_results,
......@@ -290,9 +290,9 @@ std::vector<torch::Tensor> bwd_cuda(
// Output Linear Dgrad
if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
......@@ -313,12 +313,12 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
embed_dim,
batches,
......@@ -339,7 +339,7 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
......@@ -387,9 +387,9 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches,
flags);
} else {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
......@@ -410,12 +410,12 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
embed_dim,
batches,
......@@ -551,9 +551,9 @@ std::vector<torch::Tensor> bwd_cuda(
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
output_lin_dim,
......@@ -574,12 +574,12 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_dim,
batches,
......@@ -600,7 +600,7 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
} else {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
......@@ -648,9 +648,9 @@ std::vector<torch::Tensor> bwd_cuda(
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
output_lin_dim,
......@@ -671,12 +671,12 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_dim,
batches,
......
......@@ -107,9 +107,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
// Input Linear Fwd
if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_dim,
batches,
embed_dim,
......@@ -131,7 +131,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
......@@ -156,9 +156,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
attn_batches,
flags);
} else {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_dim,
batches,
embed_dim,
......@@ -180,7 +180,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
......@@ -264,9 +264,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
......@@ -287,7 +287,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
} else {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
......@@ -313,9 +313,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
......@@ -336,7 +336,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
}
......@@ -452,9 +452,9 @@ std::vector<torch::Tensor> bwd_cuda(
// Output Linear Dgrad
if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
......@@ -475,12 +475,12 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
embed_dim,
batches,
......@@ -501,7 +501,7 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
......@@ -549,9 +549,9 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches,
flags);
} else {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
......@@ -572,12 +572,12 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
embed_dim,
batches,
......@@ -598,7 +598,7 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
......@@ -713,9 +713,9 @@ std::vector<torch::Tensor> bwd_cuda(
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
output_lin_dim,
......@@ -737,12 +737,12 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_dim,
batches,
......@@ -764,7 +764,7 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
} else {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
......@@ -812,9 +812,9 @@ std::vector<torch::Tensor> bwd_cuda(
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
output_lin_dim,
......@@ -836,12 +836,12 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_dim,
batches,
......@@ -863,7 +863,7 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
flags)));
}
......@@ -889,4 +889,4 @@ std::vector<torch::Tensor> bwd_cuda(
} // end namespace rocblas_gemmex
} // end namespace self_norm_add
} // end namespace multihead_attn
\ No newline at end of file
} // end namespace multihead_attn
......@@ -7,6 +7,8 @@
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <rocblas/rocblas.h>
//#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
......@@ -45,6 +47,52 @@ cublasOperation_t convertTransToCublasOperation(char trans) {
}
}
// needed to work around calling rocblas API instead of hipblas API
static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op)
{
switch(op)
{
case HIPBLAS_OP_N:
return rocblas_operation_none;
case HIPBLAS_OP_T:
return rocblas_operation_transpose;
case HIPBLAS_OP_C:
return rocblas_operation_conjugate_transpose;
}
AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
{
switch(error)
{
case rocblas_status_size_unchanged:
case rocblas_status_size_increased:
case rocblas_status_success:
case rocblas_status_continue:
return HIPBLAS_STATUS_SUCCESS;
case rocblas_status_invalid_handle:
return HIPBLAS_STATUS_NOT_INITIALIZED;
case rocblas_status_not_implemented:
case rocblas_status_excluded_from_build:
return HIPBLAS_STATUS_NOT_SUPPORTED;
case rocblas_status_invalid_pointer:
case rocblas_status_invalid_size:
case rocblas_status_invalid_value:
case rocblas_status_size_query_mismatch:
return HIPBLAS_STATUS_INVALID_VALUE;
case rocblas_status_memory_error:
return HIPBLAS_STATUS_ALLOC_FAILED;
case rocblas_status_internal_error:
case rocblas_status_perf_degraded:
case rocblas_status_check_numerics_fail:
return HIPBLAS_STATUS_INTERNAL_ERROR;
case rocblas_status_arch_mismatch:
return HIPBLAS_STATUS_ARCH_MISMATCH;
}
AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
......@@ -57,13 +105,13 @@ void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
float fAlpha = alpha;
float fBeta = beta;
//THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle,
opa, opb, (int)m, (int)n, (int)k,
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_strided_batched_ex((rocblas_handle)handle,
hipOperationToRocOperation(opa), hipOperationToRocOperation(opb), (int)m, (int)n, (int)k,
(void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
(void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
(int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
(int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags)));
}
void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
......
......@@ -11,11 +11,21 @@
#include <cuda_runtime.h>
#include "utils.h"
#include <rocblas/rocblas.h>
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
// includes cublaslt
#include <cublasLt.h>
#endif
// until we use hipblas v2
// hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
// however hipblas v1 is still using its custom type
#define HIP_R_64F HIPBLAS_R_64F
#define HIP_R_32F HIPBLAS_R_32F
#define HIP_R_16F HIPBLAS_R_16F
// FP64 Wrapper around cublas GEMMEx
cublasStatus_t gemm_bias(
cublasHandle_t handle,
......@@ -32,33 +42,6 @@ cublasStatus_t gemm_bias(
const float* beta,
double* C,
int ldc) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f64_r,
lda,
B,
rocblas_datatype_f64_r,
ldb,
beta,
C,
rocblas_datatype_f64_r,
ldc,
C,
rocblas_datatype_f64_r,
ldc,
rocblas_datatype_f64_r,
rocblas_gemm_algo_standard,
0,
0);
#else
return cublasGemmEx(
handle,
transa,
......@@ -79,7 +62,6 @@ cublasStatus_t gemm_bias(
ldc,
CUDA_R_64F,
CUBLAS_GEMM_DEFAULT);
#endif
}
// FP32 Wrapper around cublas GEMMEx
......@@ -98,34 +80,6 @@ cublasStatus_t gemm_bias(
const float* beta,
float* C,
int ldc) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f32_r,
lda,
B,
rocblas_datatype_f32_r,
ldb,
beta,
C,
rocblas_datatype_f32_r,
ldc,
C,
rocblas_datatype_f32_r,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
0);
#else
return cublasGemmEx(
handle,
transa,
......@@ -146,7 +100,6 @@ cublasStatus_t gemm_bias(
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT);
#endif
}
// FP16 Tensor core wrapper around cublas GEMMEx
......@@ -165,11 +118,10 @@ cublasStatus_t gemm_bias(
const float* beta,
at::Half* C,
int ldc) {
#ifdef __HIP_PLATFORM_HCC__
if (parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF")) {
half h_alpha = __float2half(*alpha);
half h_beta = __float2half(*beta);
return rocblas_gemm_ex(
return cublasGemmEx(
handle,
transa,
transb,
......@@ -178,24 +130,19 @@ cublasStatus_t gemm_bias(
k,
/* alpha */ &h_alpha,
A,
rocblas_datatype_f16_r,
CUDA_R_16F,
lda,
B,
rocblas_datatype_f16_r,
CUDA_R_16F,
ldb,
/* beta */ &h_beta,
C,
rocblas_datatype_f16_r,
CUDA_R_16F,
ldc,
C,
rocblas_datatype_f16_r,
ldc,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard,
0,
0);
} else {
return rocblas_gemm_ex(
/* CUDA_R_32F */ CUDA_R_16F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
} else {
return cublasGemmEx(
handle,
transa,
transb,
......@@ -204,45 +151,18 @@ cublasStatus_t gemm_bias(
k,
alpha,
A,
rocblas_datatype_f16_r,
CUDA_R_16F,
lda,
B,
rocblas_datatype_f16_r,
CUDA_R_16F,
ldb,
beta,
C,
rocblas_datatype_f16_r,
ldc,
C,
rocblas_datatype_f16_r,
CUDA_R_16F,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
0);
}
#else
return cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_16F,
lda,
B,
CUDA_R_16F,
ldb,
beta,
C,
CUDA_R_16F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
}
......
......@@ -13,6 +13,8 @@
#include <cuda_runtime.h>
#include "utils.h"
#include <rocblas/rocblas.h>
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
// includes cublaslt
#include <cublasLt.h>
......@@ -60,6 +62,52 @@ __device__ __inline__ float sigmoid(float a) {
return (retf);
}
// needed to work around calling rocblas API instead of hipblas API
static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op)
{
switch(op)
{
case HIPBLAS_OP_N:
return rocblas_operation_none;
case HIPBLAS_OP_T:
return rocblas_operation_transpose;
case HIPBLAS_OP_C:
return rocblas_operation_conjugate_transpose;
}
AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
{
switch(error)
{
case rocblas_status_size_unchanged:
case rocblas_status_size_increased:
case rocblas_status_success:
case rocblas_status_continue:
return HIPBLAS_STATUS_SUCCESS;
case rocblas_status_invalid_handle:
return HIPBLAS_STATUS_NOT_INITIALIZED;
case rocblas_status_not_implemented:
case rocblas_status_excluded_from_build:
return HIPBLAS_STATUS_NOT_SUPPORTED;
case rocblas_status_invalid_pointer:
case rocblas_status_invalid_size:
case rocblas_status_invalid_value:
case rocblas_status_size_query_mismatch:
return HIPBLAS_STATUS_INVALID_VALUE;
case rocblas_status_memory_error:
return HIPBLAS_STATUS_ALLOC_FAILED;
case rocblas_status_internal_error:
case rocblas_status_perf_degraded:
case rocblas_status_check_numerics_fail:
return HIPBLAS_STATUS_INTERNAL_ERROR;
case rocblas_status_arch_mismatch:
return HIPBLAS_STATUS_ARCH_MISMATCH;
}
AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
// FP64 Wrapper around cublas GEMMEx
cublasStatus_t mlp_gemm(
cublasHandle_t handle,
......@@ -78,10 +126,10 @@ cublasStatus_t mlp_gemm(
int ldc,
int flag) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
transa,
transb,
return rocBLASStatusToHIPStatus(rocblas_gemm_ex(
(rocblas_handle) handle,
hipOperationToRocOperation(transa),
hipOperationToRocOperation(transb),
m,
n,
k,
......@@ -102,7 +150,7 @@ cublasStatus_t mlp_gemm(
rocblas_datatype_f64_r,
rocblas_gemm_algo_standard,
0,
flag);
flag));
#else
return cublasGemmEx(
handle,
......@@ -145,10 +193,10 @@ cublasStatus_t mlp_gemm(
int ldc,
int flag) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
transa,
transb,
return rocBLASStatusToHIPStatus(rocblas_gemm_ex(
(rocblas_handle) handle,
hipOperationToRocOperation(transa),
hipOperationToRocOperation(transb),
m,
n,
k,
......@@ -169,7 +217,7 @@ cublasStatus_t mlp_gemm(
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
flag);
flag));
#else
return cublasGemmEx(
......@@ -216,10 +264,10 @@ cublasStatus_t mlp_gemm(
if (parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF")) {
half h_alpha = __float2half(*alpha);
half h_beta = __float2half(*beta);
return rocblas_gemm_ex(
handle,
transa,
transb,
return rocBLASStatusToHIPStatus(rocblas_gemm_ex(
(rocblas_handle) handle,
hipOperationToRocOperation(transa),
hipOperationToRocOperation(transb),
m,
n,
k,
......@@ -242,10 +290,10 @@ cublasStatus_t mlp_gemm(
0,
flag);
} else {
return rocblas_gemm_ex(
handle,
transa,
transb,
return rocBLASStatusToHIPStatus(rocblas_gemm_ex(
(rocblas_handle) handle,
hipOperationToRocOperation(transa),
hipOperationToRocOperation(transb),
m,
n,
k,
......
[build-system]
requires = [
"setuptools",
"wheel",
]
build-backend = "setuptools.build_meta"