"server/text_generation_server/models/model.py" did not exist on "31d76e238df7654157ab1e372b7d57ef859daaa7"
Commit 03204b84 authored by flyingdown
parents f8b650c8 8fc9b21f
@@ -124,15 +124,13 @@ python setup.py install
 ### To install using extensions enabled use the following command in apex folder:
 ```
+# if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key...
+pip install -v --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
+# otherwise
 python setup.py install --cpp_ext --cuda_ext
-```
-Note that using --cuda_ext flag to install Apex will also enable all the extensions supported on ROCm including "--distributed_adam", "--distributed_lamb", "--bnp", "--xentropy", "--deprecated_fused_adam", "--deprecated_fused_lamb", and "--fast_multihead_attn".
-### To install Apex on ROCm using ninja and without cloning the source
-```
-pip install ninja
-pip install -v --install-option="--cpp_ext" --install-option="--cuda_ext" 'git+https://github.com/ROCmSoftwarePlatform/apex.git'
 ```
+Note that using --cuda_ext flag to install Apex will also enable all the extensions supported on ROCm including "--distributed_adam", "--distributed_lamb", "--bnp", "--xentropy", "--deprecated_fused_adam", "--deprecated_fused_lamb", and "--fast_multihead_attn".
 ### Linux
 For performance and full functionality, we recommend installing Apex with
@@ -140,12 +138,15 @@ CUDA and C++ extensions via
 ```bash
 git clone https://github.com/NVIDIA/apex
 cd apex
-pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
+# if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key...
+pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
+# otherwise
+pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --global-option="--cpp_ext" --global-option="--cuda_ext" ./
 ```
 Apex also supports a Python-only build via
 ```bash
-pip install -v --disable-pip-version-check --no-cache-dir ./
+pip install -v --disable-pip-version-check --no-build-isolation --no-cache-dir ./
 ```
 A Python-only build omits:
 - Fused kernels required to use `apex.optimizers.FusedAdam`.
...
@@ -95,9 +95,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 // Input Linear Q Fwd
 if (use_fp16) {
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_T,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_T),
+hipOperationToRocOperation(CUBLAS_OP_N),
 output_lin_q_dim,
 batches_q,
 embed_dim,
@@ -118,12 +118,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Input Linear KV Fwd
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_T,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_T),
+hipOperationToRocOperation(CUBLAS_OP_N),
 output_lin_kv_dim,
 batches_kv,
 embed_dim,
@@ -144,7 +144,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
 gemm_switch_fp32accum( a_layout_t,
@@ -169,9 +169,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 attn_batches,
 flags);
 } else {
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_T,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_T),
+hipOperationToRocOperation(CUBLAS_OP_N),
 output_lin_q_dim,
 batches_q,
 embed_dim,
@@ -192,12 +192,12 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Input Linear KV Fwd
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_T,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_T),
+hipOperationToRocOperation(CUBLAS_OP_N),
 output_lin_kv_dim,
 batches_kv,
 embed_dim,
@@ -218,7 +218,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
 gemm_switch_fp32accum( a_layout_t,
@@ -300,9 +300,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 flags);
 // Output Linear
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_T,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_T),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches_q,
 embed_dim,
@@ -323,7 +323,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 } else {
 gemm_switch_fp32accum( a_layout_n,
 b_layout_n,
@@ -348,9 +348,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 flags);
 // Output Linear
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_T,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_T),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches_q,
 embed_dim,
@@ -371,7 +371,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 }
 //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
@@ -467,9 +467,9 @@ std::vector<torch::Tensor> bwd_cuda(
 if (use_fp16) {
 // Output Linear Dgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches_q,
 embed_dim,
@@ -490,12 +490,12 @@ std::vector<torch::Tensor> bwd_cuda(
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Output Linear Wgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_T,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_T),
 embed_dim,
 embed_dim,
 batches_q,
@@ -516,7 +516,7 @@ std::vector<torch::Tensor> bwd_cuda(
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // MatMul2 Dgrad1
 gemm_switch_fp32accum( a_layout_t,
@@ -565,9 +565,9 @@ std::vector<torch::Tensor> bwd_cuda(
 flags);
 } else {
 // Output Linear Dgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches_q,
 embed_dim,
@@ -588,12 +588,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Output Linear Wgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_T,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_T),
 embed_dim,
 embed_dim,
 batches_q,
@@ -614,7 +614,7 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // MatMul2 Dgrad1
 gemm_switch_fp32accum( a_layout_t,
@@ -728,9 +728,9 @@ std::vector<torch::Tensor> bwd_cuda(
 flags);
 // Input Linear Q Dgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches_q,
 output_lin_q_dim,
@@ -751,12 +751,12 @@ std::vector<torch::Tensor> bwd_cuda(
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Input Linear Q Wgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_T,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_T),
 embed_dim,
 output_lin_q_dim,
 batches_q,
@@ -777,12 +777,12 @@ std::vector<torch::Tensor> bwd_cuda(
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Input Linear KV Dgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches_kv,
 output_lin_kv_dim,
@@ -803,12 +803,12 @@ std::vector<torch::Tensor> bwd_cuda(
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Input Linear KV Wgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_T,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_T),
 embed_dim,
 output_lin_kv_dim,
 batches_kv,
@@ -829,7 +829,7 @@ std::vector<torch::Tensor> bwd_cuda(
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 } else {
 // Matmul1 Dgrad1
 gemm_switch_fp32accum( a_layout_n,
@@ -878,9 +878,9 @@ std::vector<torch::Tensor> bwd_cuda(
 flags);
 // Input Linear Q Dgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches_q,
 output_lin_q_dim,
@@ -901,12 +901,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Input Linear Q Wgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_T,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_T),
 embed_dim,
 output_lin_q_dim,
 batches_q,
@@ -927,12 +927,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Input Linear KV Dgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches_kv,
 output_lin_kv_dim,
@@ -953,12 +953,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Input Linear KV Wgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_T,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_T),
 embed_dim,
 output_lin_kv_dim,
 batches_kv,
@@ -979,7 +979,7 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 }
 // TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
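The same mechanical change repeats through every hunk above: each call is routed through rocBLAS's native types. The cuBLAS-style handle is cast to `rocblas_handle`, the `CUBLAS_OP_*` enums are converted with `hipOperationToRocOperation`, and the returned `rocblas_status` is converted back with `rocBLASStatusToHIPStatus` so that `TORCH_CUDABLAS_CHECK` still receives the hipBLAS/cuBLAS-style status it expects. The two converters are not defined in this diff; the sketch below shows plausible definitions only (the enum mappings and include paths are assumptions, not taken from this commit).

```cpp
// Illustrative sketch: plausible definitions of the two converters used above.
// The real ones live in the ROCm/hipBLAS glue code these .cu files include;
// mappings and header paths here are assumptions.
#include <hipblas.h>   // hipblasOperation_t, hipblasStatus_t (path varies by ROCm version)
#include <rocblas.h>   // rocblas_operation, rocblas_status

// In PyTorch's hipified build, CUBLAS_OP_N/CUBLAS_OP_T resolve to HIPBLAS_OP_* enums.
inline rocblas_operation hipOperationToRocOperation(hipblasOperation_t op) {
  switch (op) {
    case HIPBLAS_OP_N: return rocblas_operation_none;
    case HIPBLAS_OP_T: return rocblas_operation_transpose;
    case HIPBLAS_OP_C: return rocblas_operation_conjugate_transpose;
  }
  return rocblas_operation_none;  // unreachable for valid input
}

// Convert the rocBLAS return code into a hipBLAS status so the (hipified)
// TORCH_CUDABLAS_CHECK macro can report it.
inline hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status status) {
  switch (status) {
    case rocblas_status_success:         return HIPBLAS_STATUS_SUCCESS;
    case rocblas_status_invalid_handle:  return HIPBLAS_STATUS_NOT_INITIALIZED;
    case rocblas_status_not_implemented: return HIPBLAS_STATUS_NOT_SUPPORTED;
    case rocblas_status_invalid_pointer:
    case rocblas_status_invalid_size:    return HIPBLAS_STATUS_INVALID_VALUE;
    case rocblas_status_memory_error:    return HIPBLAS_STATUS_ALLOC_FAILED;
    default:                             return HIPBLAS_STATUS_INTERNAL_ERROR;
  }
}
```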
...
@@ -121,9 +121,9 @@ std::vector<torch::Tensor> fwd_cuda(
 if (use_fp16) {
 // Input Linear Q Fwd
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_T,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_T),
+hipOperationToRocOperation(CUBLAS_OP_N),
 output_lin_q_dim,
 batches_q,
 embed_dim,
@@ -145,12 +145,12 @@ std::vector<torch::Tensor> fwd_cuda(
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Input Linear KV Fwd
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_T,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_T),
+hipOperationToRocOperation(CUBLAS_OP_N),
 output_lin_kv_dim,
 batches_kv,
 embed_dim,
@@ -171,7 +171,7 @@ std::vector<torch::Tensor> fwd_cuda(
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
 gemm_switch_fp32accum( a_layout_t,
 b_layout_n,
@@ -196,9 +196,9 @@ std::vector<torch::Tensor> fwd_cuda(
 flags);
 } else {
 // Input Linear Q Fwd
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_T,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_T),
+hipOperationToRocOperation(CUBLAS_OP_N),
 output_lin_q_dim,
 batches_q,
 embed_dim,
@@ -220,12 +220,12 @@ std::vector<torch::Tensor> fwd_cuda(
 rocblas_datatype_f32_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Input Linear KV Fwd
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_T,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_T),
+hipOperationToRocOperation(CUBLAS_OP_N),
 output_lin_kv_dim,
 batches_kv,
 embed_dim,
@@ -246,7 +246,7 @@ std::vector<torch::Tensor> fwd_cuda(
 rocblas_datatype_f32_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
 gemm_switch_fp32accum( a_layout_t,
 b_layout_n,
@@ -329,9 +329,9 @@ std::vector<torch::Tensor> fwd_cuda(
 flags);
 // Output Linear
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_T,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_T),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches_q,
 embed_dim,
@@ -352,7 +352,7 @@ std::vector<torch::Tensor> fwd_cuda(
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 } else {
 // Matmul2
 gemm_switch_fp32accum( a_layout_n,
@@ -379,9 +379,9 @@ std::vector<torch::Tensor> fwd_cuda(
 flags);
 // Output Linear
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_T,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_T),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches_q,
 embed_dim,
@@ -402,7 +402,7 @@ std::vector<torch::Tensor> fwd_cuda(
 rocblas_datatype_f32_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 }
@@ -535,9 +535,9 @@ std::vector<torch::Tensor> bwd_cuda(
 if (use_fp16) {
 // Output Linear Dgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches_q,
 embed_dim,
@@ -558,12 +558,12 @@ std::vector<torch::Tensor> bwd_cuda(
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Output Linear Wgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_T,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_T),
 embed_dim,
 embed_dim,
 batches_q,
@@ -584,7 +584,7 @@ std::vector<torch::Tensor> bwd_cuda(
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // MatMul2 Dgrad1
 gemm_switch_fp32accum( a_layout_t,
@@ -633,9 +633,9 @@ std::vector<torch::Tensor> bwd_cuda(
 flags);
 } else {
 // Output Linear Dgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches_q,
 embed_dim,
@@ -656,12 +656,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Output Linear Wgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_T,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_T),
 embed_dim,
 embed_dim,
 batches_q,
@@ -682,7 +682,7 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // MatMul2 Dgrad1
 gemm_switch_fp32accum( a_layout_t,
@@ -797,9 +797,9 @@ std::vector<torch::Tensor> bwd_cuda(
 flags);
 // Input Linear Q Dgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches_q,
 output_lin_q_dim,
@@ -821,12 +821,12 @@ std::vector<torch::Tensor> bwd_cuda(
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Input Linear Q Wgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_T,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_T),
 embed_dim,
 output_lin_q_dim,
 batches_q,
@@ -847,12 +847,12 @@ std::vector<torch::Tensor> bwd_cuda(
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Input Linear KV Dgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches_kv,
 output_lin_kv_dim,
@@ -873,12 +873,12 @@ std::vector<torch::Tensor> bwd_cuda(
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Input Linear KV Wgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_T,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_T),
 embed_dim,
 output_lin_kv_dim,
 batches_kv,
@@ -899,7 +899,7 @@ std::vector<torch::Tensor> bwd_cuda(
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 } else {
 // Matmul1 Dgrad1
 gemm_switch_fp32accum( a_layout_n,
@@ -948,9 +948,9 @@ std::vector<torch::Tensor> bwd_cuda(
 flags);
 // Input Linear Q Dgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches_q,
 output_lin_q_dim,
@@ -972,12 +972,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Input Linear Q Wgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_T,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_T),
 embed_dim,
 output_lin_q_dim,
 batches_q,
@@ -998,12 +998,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Input Linear KV Dgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches_kv,
 output_lin_kv_dim,
@@ -1024,12 +1024,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Input Linear KV Wgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_T,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_T),
 embed_dim,
 output_lin_kv_dim,
 batches_kv,
@@ -1050,7 +1050,7 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r /*compute_type*/,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 }
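Every hunk elides the middle of the argument list, which can make the fragments hard to read in isolation. Below is a small, self-contained example of a plain `rocblas_gemm_ex` call showing the full argument order (transA/transB, m/n/k, alpha, A/lda, B/ldb, beta, C/ldc, D/ldd, compute type, algo, solution index, flags). It uses fp32 throughout and made-up 2x2 column-major matrices; the buffers and sizes are placeholders for illustration only and are not the per-layer pointers or leading dimensions collapsed out of the hunks above.

```cpp
// Standalone sketch of a rocblas_gemm_ex call (hypothetical buffers, fp32 only).
// Include paths differ across ROCm versions; newer releases use <rocblas/rocblas.h>.
#include <hip/hip_runtime.h>
#include <rocblas/rocblas.h>
#include <cstdio>
#include <vector>

int main() {
  const int m = 2, n = 2, k = 2;            // C(m x n) = alpha * A(m x k) * B(k x n) + beta * C
  const float alpha = 1.0f, beta = 0.0f;
  std::vector<float> A = {1, 2, 3, 4};      // column-major
  std::vector<float> B = {5, 6, 7, 8};
  std::vector<float> C(m * n, 0.0f);

  float *dA, *dB, *dC;
  hipMalloc((void**)&dA, A.size() * sizeof(float));
  hipMalloc((void**)&dB, B.size() * sizeof(float));
  hipMalloc((void**)&dC, C.size() * sizeof(float));
  hipMemcpy(dA, A.data(), A.size() * sizeof(float), hipMemcpyHostToDevice);
  hipMemcpy(dB, B.data(), B.size() * sizeof(float), hipMemcpyHostToDevice);
  hipMemcpy(dC, C.data(), C.size() * sizeof(float), hipMemcpyHostToDevice);

  rocblas_handle handle;
  rocblas_create_handle(&handle);

  // The trailing arguments mirror what the hunks above show: compute type,
  // algo, solution_index and flags come last.
  rocblas_status st = rocblas_gemm_ex(handle,
                                      rocblas_operation_none,          // transA
                                      rocblas_operation_none,          // transB
                                      m, n, k,
                                      &alpha,
                                      dA, rocblas_datatype_f32_r, m,   // A, a_type, lda
                                      dB, rocblas_datatype_f32_r, k,   // B, b_type, ldb
                                      &beta,
                                      dC, rocblas_datatype_f32_r, m,   // C, c_type, ldc
                                      dC, rocblas_datatype_f32_r, m,   // D, d_type, ldd (in-place)
                                      rocblas_datatype_f32_r,          // compute_type
                                      rocblas_gemm_algo_standard,      // algo
                                      0,                               // solution_index
                                      0);                              // flags

  hipMemcpy(C.data(), dC, C.size() * sizeof(float), hipMemcpyDeviceToHost);
  std::printf("status=%d C=[%g %g; %g %g]\n", (int)st, C[0], C[2], C[1], C[3]);

  rocblas_destroy_handle(handle);
  hipFree(dA); hipFree(dB); hipFree(dC);
  return 0;
}
```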
...
@@ -91,9 +91,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 // Input Linear Fwd
 input_lin_results.copy_(input_biases);
 if (use_fp16) {
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_T,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_T),
+hipOperationToRocOperation(CUBLAS_OP_N),
 output_lin_dim,
 batches,
 embed_dim,
@@ -114,7 +114,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
 gemm_switch_fp32accum( a_layout_t,
@@ -139,9 +139,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 attn_batches,
 flags);
 } else {
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_T,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_T),
+hipOperationToRocOperation(CUBLAS_OP_N),
 output_lin_dim,
 batches,
 embed_dim,
@@ -162,7 +162,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
 gemm_switch_fp32accum( a_layout_t,
@@ -239,9 +239,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 outputs.copy_(output_biases);
 // Output Linear
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_T,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_T),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches,
 embed_dim,
@@ -262,7 +262,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 } else {
 gemm_switch_fp32accum( a_layout_n,
 b_layout_n,
@@ -289,9 +289,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 outputs.copy_(output_biases);
 // Output Linear
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_T,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_T),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches,
 embed_dim,
@@ -312,7 +312,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 }
 //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
@@ -393,9 +393,9 @@ std::vector<torch::Tensor> bwd_cuda(
 // Output Linear Dgrad
 if (use_fp16) {
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches,
 embed_dim,
@@ -416,12 +416,12 @@ std::vector<torch::Tensor> bwd_cuda(
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Output Linear Wgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_T,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_T),
 embed_dim,
 embed_dim,
 batches,
@@ -442,7 +442,7 @@ std::vector<torch::Tensor> bwd_cuda(
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
 // MatMul2 Dgrad1
@@ -553,9 +553,9 @@ std::vector<torch::Tensor> bwd_cuda(
 flags);
 // Input Linear Dgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches,
 output_lin_dim,
@@ -576,12 +576,12 @@ std::vector<torch::Tensor> bwd_cuda(
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Input Linear Wgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_T,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_T),
 embed_dim,
 output_lin_dim,
 batches,
@@ -602,7 +602,7 @@ std::vector<torch::Tensor> bwd_cuda(
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
 //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
@@ -610,9 +610,9 @@ std::vector<torch::Tensor> bwd_cuda(
 return {input_grads, input_weight_grads, output_weight_grads,
 input_bias_grads, output_bias_grads};
 } else {
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches,
 embed_dim,
@@ -633,12 +633,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Output Linear Wgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_T,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_T),
 embed_dim,
 embed_dim,
 batches,
@@ -659,7 +659,7 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
 // MatMul2 Dgrad1
@@ -770,9 +770,9 @@ std::vector<torch::Tensor> bwd_cuda(
 flags);
 // Input Linear Dgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches,
 output_lin_dim,
@@ -793,12 +793,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Input Linear Wgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_T,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_T),
 embed_dim,
 output_lin_dim,
 batches,
@@ -819,7 +819,7 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
 //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
...
@@ -89,9 +89,9 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
 // Input Linear Fwd
 input_lin_results.copy_(input_biases);
 if (use_fp16) {
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_T,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_T),
+hipOperationToRocOperation(CUBLAS_OP_N),
 output_lin_dim,
 batches,
 embed_dim,
@@ -112,7 +112,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
 gemm_switch_fp32accum( a_layout_t,
@@ -137,9 +137,9 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
 attn_batches,
 flags);
 } else {
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_T,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_T),
+hipOperationToRocOperation(CUBLAS_OP_N),
 output_lin_dim,
 batches,
 embed_dim,
@@ -160,7 +160,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
 gemm_switch_fp32accum( a_layout_t,
@@ -245,9 +245,9 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
 outputs.copy_(output_biases);
 // Output Linear
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_T,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_T),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches,
 embed_dim,
@@ -268,7 +268,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 } else {
 gemm_switch_fp32accum( a_layout_n,
 b_layout_n,
@@ -295,9 +295,9 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
 outputs.copy_(output_biases);
 // Output Linear
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_T,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_T),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches,
 embed_dim,
@@ -318,7 +318,7 @@ fwd_cuda(bool use_time_mask, bool is_training, int heads,
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 }
 //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
@@ -399,9 +399,9 @@ std::vector<torch::Tensor> bwd_cuda(
 // Output Linear Dgrad
 if (use_fp16) {
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches,
 embed_dim,
@@ -422,12 +422,12 @@ std::vector<torch::Tensor> bwd_cuda(
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Output Linear Wgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_T,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_T),
 embed_dim,
 embed_dim,
 batches,
@@ -448,7 +448,7 @@ std::vector<torch::Tensor> bwd_cuda(
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
 // MatMul2 Dgrad1
@@ -553,9 +553,9 @@ std::vector<torch::Tensor> bwd_cuda(
 attn_batches,
 flags);
 // Input Linear Dgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches,
 output_lin_dim,
@@ -576,12 +576,12 @@ std::vector<torch::Tensor> bwd_cuda(
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Input Linear Wgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_T,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_T),
 embed_dim,
 output_lin_dim,
 batches,
@@ -602,7 +602,7 @@ std::vector<torch::Tensor> bwd_cuda(
 /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
 //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
@@ -610,9 +610,9 @@ std::vector<torch::Tensor> bwd_cuda(
 return {input_grads, input_weight_grads, output_weight_grads,
 input_bias_grads, output_bias_grads};
 } else {
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches,
 embed_dim,
@@ -633,12 +633,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Output Linear Wgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_T,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_T),
 embed_dim,
 embed_dim,
 batches,
@@ -659,7 +659,7 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 auto output_bias_grads = output_grads.view({-1, embed_dim}) .sum(0, false);
 // MatMul2 Dgrad1
@@ -764,9 +764,9 @@ std::vector<torch::Tensor> bwd_cuda(
 attn_batches,
 flags);
 // Input Linear Dgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_N,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_N),
 embed_dim,
 batches,
 output_lin_dim,
@@ -787,12 +787,12 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 // Input Linear Wgrad
-TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
-CUBLAS_OP_N,
-CUBLAS_OP_T,
+TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
+hipOperationToRocOperation(CUBLAS_OP_N),
+hipOperationToRocOperation(CUBLAS_OP_T),
 embed_dim,
 output_lin_dim,
 batches,
@@ -813,7 +813,7 @@ std::vector<torch::Tensor> bwd_cuda(
 rocblas_datatype_f32_r,
 rocblas_gemm_algo_standard /*algo*/,
 0 /*solution_index*/,
-flags));
+flags)));
 auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
 //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
......
...@@ -85,9 +85,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -85,9 +85,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd // Input Linear Fwd
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
CUBLAS_OP_T, hipOperationToRocOperation(CUBLAS_OP_T),
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_dim, output_lin_dim,
batches, batches,
embed_dim, embed_dim,
...@@ -108,7 +108,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -108,7 +108,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/, rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/, 0 /*solution_index*/,
flags)); flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t, gemm_switch_fp32accum( a_layout_t,
...@@ -188,9 +188,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -188,9 +188,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
flags); flags);
// Output Linear // Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
CUBLAS_OP_T, hipOperationToRocOperation(CUBLAS_OP_T),
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim, embed_dim,
batches, batches,
embed_dim, embed_dim,
...@@ -211,7 +211,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -211,7 +211,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/, rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/, 0 /*solution_index*/,
flags)); flags)));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_results, softmax_results, dropout_results, return {input_lin_results, softmax_results, dropout_results,
...@@ -290,9 +290,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -290,9 +290,9 @@ std::vector<torch::Tensor> bwd_cuda(
// Output Linear Dgrad // Output Linear Dgrad
if (use_fp16) { if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim, embed_dim,
batches, batches,
embed_dim, embed_dim,
...@@ -313,12 +313,12 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -313,12 +313,12 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/, rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/, 0 /*solution_index*/,
flags)); flags)));
// Output Linear Wgrad // Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
CUBLAS_OP_T, hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim, embed_dim,
embed_dim, embed_dim,
batches, batches,
...@@ -339,7 +339,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -339,7 +339,7 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/, rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/, 0 /*solution_index*/,
flags)); flags)));
// MatMul2 Dgrad1 // MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t, gemm_switch_fp32accum( a_layout_t,
...@@ -387,9 +387,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -387,9 +387,9 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches, attn_batches,
flags); flags);
} else { } else {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim, embed_dim,
batches, batches,
embed_dim, embed_dim,
...@@ -410,12 +410,12 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -410,12 +410,12 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/, rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/, 0 /*solution_index*/,
flags)); flags)));
// Output Linear Wgrad // Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
CUBLAS_OP_T, hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim, embed_dim,
embed_dim, embed_dim,
batches, batches,
...@@ -551,9 +551,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -551,9 +551,9 @@ std::vector<torch::Tensor> bwd_cuda(
flags); flags);
// Input Linear Dgrad // Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim, embed_dim,
batches, batches,
output_lin_dim, output_lin_dim,
...@@ -574,12 +574,12 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -574,12 +574,12 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/, rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/, 0 /*solution_index*/,
flags)); flags)));
// Input Linear Wgrad // Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
CUBLAS_OP_T, hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim, embed_dim,
output_lin_dim, output_lin_dim,
batches, batches,
...@@ -600,7 +600,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -600,7 +600,7 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/, rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/, 0 /*solution_index*/,
flags)); flags)));
} else { } else {
gemm_switch_fp32accum( a_layout_n, gemm_switch_fp32accum( a_layout_n,
b_layout_n, b_layout_n,
...@@ -648,9 +648,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -648,9 +648,9 @@ std::vector<torch::Tensor> bwd_cuda(
flags); flags);
// Input Linear Dgrad // Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim, embed_dim,
batches, batches,
output_lin_dim, output_lin_dim,
...@@ -671,12 +671,12 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -671,12 +671,12 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/, rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/, 0 /*solution_index*/,
flags)); flags)));
// Input Linear Wgrad // Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
CUBLAS_OP_T, hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim, embed_dim,
output_lin_dim, output_lin_dim,
batches, batches,
......
...@@ -107,9 +107,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -107,9 +107,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
// Input Linear Fwd // Input Linear Fwd
if (use_fp16) { if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
CUBLAS_OP_T, hipOperationToRocOperation(CUBLAS_OP_T),
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_dim, output_lin_dim,
batches, batches,
embed_dim, embed_dim,
...@@ -131,7 +131,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -131,7 +131,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/, /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/, rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/, 0 /*solution_index*/,
flags)); flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t, gemm_switch_fp32accum( a_layout_t,
...@@ -156,9 +156,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -156,9 +156,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
attn_batches, attn_batches,
flags); flags);
} else { } else {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
CUBLAS_OP_T, hipOperationToRocOperation(CUBLAS_OP_T),
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
output_lin_dim, output_lin_dim,
batches, batches,
embed_dim, embed_dim,
...@@ -180,7 +180,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -180,7 +180,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
rocblas_datatype_f32_r /*compute_type*/, rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/, rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/, 0 /*solution_index*/,
flags)); flags)));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size) // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t, gemm_switch_fp32accum( a_layout_t,
...@@ -264,9 +264,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -264,9 +264,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
flags); flags);
// Output Linear // Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
CUBLAS_OP_T, hipOperationToRocOperation(CUBLAS_OP_T),
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim, embed_dim,
batches, batches,
embed_dim, embed_dim,
...@@ -287,7 +287,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -287,7 +287,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/, /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/, rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/, 0 /*solution_index*/,
flags)); flags)));
} else { } else {
gemm_switch_fp32accum( a_layout_n, gemm_switch_fp32accum( a_layout_n,
b_layout_n, b_layout_n,
...@@ -313,9 +313,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -313,9 +313,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
flags); flags);
// Output Linear // Output Linear
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
CUBLAS_OP_T, hipOperationToRocOperation(CUBLAS_OP_T),
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim, embed_dim,
batches, batches,
embed_dim, embed_dim,
...@@ -336,7 +336,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training, ...@@ -336,7 +336,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
rocblas_datatype_f32_r /*compute_type*/, rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/, rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/, 0 /*solution_index*/,
flags)); flags)));
} }
...@@ -452,9 +452,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -452,9 +452,9 @@ std::vector<torch::Tensor> bwd_cuda(
// Output Linear Dgrad // Output Linear Dgrad
if (use_fp16) { if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim, embed_dim,
batches, batches,
embed_dim, embed_dim,
...@@ -475,12 +475,12 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -475,12 +475,12 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/, /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/, rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/, 0 /*solution_index*/,
flags)); flags)));
// Output Linear Wgrad // Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
CUBLAS_OP_T, hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim, embed_dim,
embed_dim, embed_dim,
batches, batches,
...@@ -501,7 +501,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -501,7 +501,7 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/, /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/, rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/, 0 /*solution_index*/,
flags)); flags)));
// MatMul2 Dgrad1 // MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t, gemm_switch_fp32accum( a_layout_t,
...@@ -549,9 +549,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -549,9 +549,9 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches, attn_batches,
flags); flags);
} else { } else {
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim, embed_dim,
batches, batches,
embed_dim, embed_dim,
...@@ -572,12 +572,12 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -572,12 +572,12 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r /*compute_type*/, rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/, rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/, 0 /*solution_index*/,
flags)); flags)));
// Output Linear Wgrad // Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
CUBLAS_OP_T, hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim, embed_dim,
embed_dim, embed_dim,
batches, batches,
...@@ -598,7 +598,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -598,7 +598,7 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r /*compute_type*/, rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/, rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/, 0 /*solution_index*/,
flags)); flags)));
// MatMul2 Dgrad1 // MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t, gemm_switch_fp32accum( a_layout_t,
...@@ -713,9 +713,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -713,9 +713,9 @@ std::vector<torch::Tensor> bwd_cuda(
flags); flags);
// Input Linear Dgrad // Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim, embed_dim,
batches, batches,
output_lin_dim, output_lin_dim,
...@@ -737,12 +737,12 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -737,12 +737,12 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/, /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/, rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/, 0 /*solution_index*/,
flags)); flags)));
// Input Linear Wgrad // Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
CUBLAS_OP_T, hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim, embed_dim,
output_lin_dim, output_lin_dim,
batches, batches,
...@@ -764,7 +764,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -764,7 +764,7 @@ std::vector<torch::Tensor> bwd_cuda(
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/, /* rocblas_datatype_f32_r */ rocblas_datatype_f16_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/, rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/, 0 /*solution_index*/,
flags)); flags)));
} else { } else {
gemm_switch_fp32accum( a_layout_n, gemm_switch_fp32accum( a_layout_n,
b_layout_n, b_layout_n,
...@@ -812,9 +812,9 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -812,9 +812,9 @@ std::vector<torch::Tensor> bwd_cuda(
flags); flags);
// Input Linear Dgrad // Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim, embed_dim,
batches, batches,
output_lin_dim, output_lin_dim,
...@@ -836,12 +836,12 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -836,12 +836,12 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r /*compute_type*/, rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/, rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/, 0 /*solution_index*/,
flags)); flags)));
// Input Linear Wgrad // Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
CUBLAS_OP_N, hipOperationToRocOperation(CUBLAS_OP_N),
CUBLAS_OP_T, hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim, embed_dim,
output_lin_dim, output_lin_dim,
batches, batches,
...@@ -863,7 +863,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -863,7 +863,7 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r /*compute_type*/, rocblas_datatype_f32_r /*compute_type*/,
rocblas_gemm_algo_standard /*algo*/, rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/, 0 /*solution_index*/,
flags)); flags)));
} }
......
...@@ -7,6 +7,8 @@ ...@@ -7,6 +7,8 @@
//#include <cuda_profiler_api.h> //#include <cuda_profiler_api.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <rocblas/rocblas.h>
//#include <ATen/ATen.h> //#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h> #include <ATen/cuda/Exceptions.h>
...@@ -45,6 +47,52 @@ cublasOperation_t convertTransToCublasOperation(char trans) { ...@@ -45,6 +47,52 @@ cublasOperation_t convertTransToCublasOperation(char trans) {
} }
} }
// needed to work around calling rocblas API instead of hipblas API
static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op)
{
switch(op)
{
case HIPBLAS_OP_N:
return rocblas_operation_none;
case HIPBLAS_OP_T:
return rocblas_operation_transpose;
case HIPBLAS_OP_C:
return rocblas_operation_conjugate_transpose;
}
AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
{
switch(error)
{
case rocblas_status_size_unchanged:
case rocblas_status_size_increased:
case rocblas_status_success:
case rocblas_status_continue:
return HIPBLAS_STATUS_SUCCESS;
case rocblas_status_invalid_handle:
return HIPBLAS_STATUS_NOT_INITIALIZED;
case rocblas_status_not_implemented:
case rocblas_status_excluded_from_build:
return HIPBLAS_STATUS_NOT_SUPPORTED;
case rocblas_status_invalid_pointer:
case rocblas_status_invalid_size:
case rocblas_status_invalid_value:
case rocblas_status_size_query_mismatch:
return HIPBLAS_STATUS_INVALID_VALUE;
case rocblas_status_memory_error:
return HIPBLAS_STATUS_ALLOC_FAILED;
case rocblas_status_internal_error:
case rocblas_status_perf_degraded:
case rocblas_status_check_numerics_fail:
return HIPBLAS_STATUS_INTERNAL_ERROR;
case rocblas_status_arch_mismatch:
return HIPBLAS_STATUS_ARCH_MISMATCH;
}
AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
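Taken together, the pattern used throughout this commit is: reinterpret the hipblas handle that PyTorch provides as a `rocblas_handle`, convert the (hipified) `CUBLAS_OP_*` enums with `hipOperationToRocOperation`, call `rocblas_gemm_ex` directly, and translate the resulting `rocblas_status` back through `rocBLASStatusToHIPStatus` so `TORCH_CUDABLAS_CHECK` can consume it. Below is a minimal sketch of that pattern; `example_fp16_gemm` and all of its arguments are illustrative placeholders, not part of the diff.

```cpp
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <rocblas/rocblas.h>

// Hedged sketch: a half-precision GEMM routed through rocblas using the two
// conversion helpers defined above. All parameters are supplied by the caller.
static void example_fp16_gemm(cublasHandle_t handle,
                              int m, int n, int k,
                              const float* alpha, const at::Half* A, int lda,
                              const at::Half* B, int ldb,
                              const float* beta, at::Half* C, int ldc) {
  TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex(
      (rocblas_handle)handle,                   // reinterpret the hipblas handle as rocblas
      hipOperationToRocOperation(CUBLAS_OP_T),  // hipify rewrites CUBLAS_OP_* to HIPBLAS_OP_*
      hipOperationToRocOperation(CUBLAS_OP_N),
      m, n, k,
      alpha,
      A, rocblas_datatype_f16_r, lda,
      B, rocblas_datatype_f16_r, ldb,
      beta,
      C, rocblas_datatype_f16_r, ldc,
      C, rocblas_datatype_f16_r, ldc,           // D aliases C (in-place result)
      rocblas_datatype_f32_r,                   // accumulate in fp32
      rocblas_gemm_algo_standard,
      0 /*solution_index*/, 0 /*flags*/)));
}
```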
void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k, void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB, float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) { float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
...@@ -57,13 +105,13 @@ void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k, ...@@ -57,13 +105,13 @@ void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
float fAlpha = alpha; float fAlpha = alpha;
float fBeta = beta; float fBeta = beta;
//THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle, TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_strided_batched_ex((rocblas_handle)handle,
opa, opb, (int)m, (int)n, (int)k, hipOperationToRocOperation(opa), hipOperationToRocOperation(opb), (int)m, (int)n, (int)k,
(void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA, (void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB, b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
(void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC, (void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD, d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
(int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags)); (int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags)));
} }
void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k, void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
......
...@@ -11,11 +11,21 @@ ...@@ -11,11 +11,21 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include "utils.h" #include "utils.h"
#include <rocblas/rocblas.h>
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
// includes cublaslt // includes cublaslt
#include <cublasLt.h> #include <cublasLt.h>
#endif #endif
// until we use hipblas v2
// hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
// however hipblas v1 is still using its custom type
#define HIP_R_64F HIPBLAS_R_64F
#define HIP_R_32F HIPBLAS_R_32F
#define HIP_R_16F HIPBLAS_R_16F
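To make the comment above concrete: hipify rewrites the CUDA datatype enumerators in this file, but hipblas v1 does not declare the rewritten names, which is what the three aliases paper over. A schematic before/after (comments only, no new API calls):

```cpp
// Written in the CUDA source:
//   cublasGemmEx(handle, ..., A, CUDA_R_16F, lda, ...);
// After hipify for ROCm it becomes:
//   hipblasGemmEx(handle, ..., A, HIP_R_16F, lda, ...);
// hipblas v1 only declares HIPBLAS_R_16F / HIPBLAS_R_32F / HIPBLAS_R_64F,
// so the #define aliases above let the hipified names resolve.
```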
// FP64 Wrapper around cublas GEMMEx // FP64 Wrapper around cublas GEMMEx
cublasStatus_t gemm_bias( cublasStatus_t gemm_bias(
cublasHandle_t handle, cublasHandle_t handle,
...@@ -32,33 +42,6 @@ cublasStatus_t gemm_bias( ...@@ -32,33 +42,6 @@ cublasStatus_t gemm_bias(
const float* beta, const float* beta,
double* C, double* C,
int ldc) { int ldc) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f64_r,
lda,
B,
rocblas_datatype_f64_r,
ldb,
beta,
C,
rocblas_datatype_f64_r,
ldc,
C,
rocblas_datatype_f64_r,
ldc,
rocblas_datatype_f64_r,
rocblas_gemm_algo_standard,
0,
0);
#else
return cublasGemmEx( return cublasGemmEx(
handle, handle,
transa, transa,
...@@ -79,7 +62,6 @@ cublasStatus_t gemm_bias( ...@@ -79,7 +62,6 @@ cublasStatus_t gemm_bias(
ldc, ldc,
CUDA_R_64F, CUDA_R_64F,
CUBLAS_GEMM_DEFAULT); CUBLAS_GEMM_DEFAULT);
#endif
} }
// FP32 Wrapper around cublas GEMMEx // FP32 Wrapper around cublas GEMMEx
...@@ -98,34 +80,6 @@ cublasStatus_t gemm_bias( ...@@ -98,34 +80,6 @@ cublasStatus_t gemm_bias(
const float* beta, const float* beta,
float* C, float* C,
int ldc) { int ldc) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f32_r,
lda,
B,
rocblas_datatype_f32_r,
ldb,
beta,
C,
rocblas_datatype_f32_r,
ldc,
C,
rocblas_datatype_f32_r,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
0);
#else
return cublasGemmEx( return cublasGemmEx(
handle, handle,
transa, transa,
...@@ -146,7 +100,6 @@ cublasStatus_t gemm_bias( ...@@ -146,7 +100,6 @@ cublasStatus_t gemm_bias(
ldc, ldc,
CUDA_R_32F, CUDA_R_32F,
CUBLAS_GEMM_DEFAULT); CUBLAS_GEMM_DEFAULT);
#endif
} }
// FP16 Tensor core wrapper around cublas GEMMEx // FP16 Tensor core wrapper around cublas GEMMEx
...@@ -165,11 +118,10 @@ cublasStatus_t gemm_bias( ...@@ -165,11 +118,10 @@ cublasStatus_t gemm_bias(
const float* beta, const float* beta,
at::Half* C, at::Half* C,
int ldc) { int ldc) {
#ifdef __HIP_PLATFORM_HCC__
if (parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF")) { if (parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF")) {
half h_alpha = __float2half(*alpha); half h_alpha = __float2half(*alpha);
half h_beta = __float2half(*beta); half h_beta = __float2half(*beta);
return rocblas_gemm_ex( return cublasGemmEx(
handle, handle,
transa, transa,
transb, transb,
...@@ -178,50 +130,18 @@ cublasStatus_t gemm_bias( ...@@ -178,50 +130,18 @@ cublasStatus_t gemm_bias(
k, k,
/* alpha */ &h_alpha, /* alpha */ &h_alpha,
A, A,
rocblas_datatype_f16_r, CUDA_R_16F,
lda, lda,
B, B,
rocblas_datatype_f16_r, CUDA_R_16F,
ldb, ldb,
/* beta */ &h_beta, /* beta */ &h_beta,
C, C,
rocblas_datatype_f16_r, CUDA_R_16F,
ldc,
C,
rocblas_datatype_f16_r,
ldc, ldc,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, /* CUDA_R_32F */ CUDA_R_16F,
rocblas_gemm_algo_standard, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
0,
0);
} else { } else {
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f16_r,
lda,
B,
rocblas_datatype_f16_r,
ldb,
beta,
C,
rocblas_datatype_f16_r,
ldc,
C,
rocblas_datatype_f16_r,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
0);
}
#else
return cublasGemmEx( return cublasGemmEx(
handle, handle,
transa, transa,
...@@ -242,7 +162,7 @@ cublasStatus_t gemm_bias( ...@@ -242,7 +162,7 @@ cublasStatus_t gemm_bias(
ldc, ldc,
CUDA_R_32F, CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP); CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif }
} }
......
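The half-precision branches above are gated on the `APEX_ROCBLAS_GEMM_ALLOW_HALF` environment flag, which switches the GEMM compute type from fp32 to fp16 accumulation. The helper below is only an illustration of how such a flag could be parsed; the repository's actual `parseEnvVarFlag` (presumably provided by the included `utils.h`) may use different semantics.

```cpp
#include <cstdlib>
#include <cstring>

// Illustrative stand-in for an environment-flag check: any value other than
// an empty string or "0" counts as enabled. (The real parseEnvVarFlag may differ.)
static bool example_parse_env_flag(const char* name) {
  const char* value = std::getenv(name);
  return value != nullptr && value[0] != '\0' && std::strcmp(value, "0") != 0;
}

// Sketch of how the flag would select the accumulation precision for fp16 GEMMs:
// bool allow_half = example_parse_env_flag("APEX_ROCBLAS_GEMM_ALLOW_HALF");
// rocblas_datatype compute_type = allow_half ? rocblas_datatype_f16_r
//                                            : rocblas_datatype_f32_r;
```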
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include "utils.h" #include "utils.h"
#include <rocblas/rocblas.h>
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
// includes cublaslt // includes cublaslt
#include <cublasLt.h> #include <cublasLt.h>
...@@ -60,6 +62,52 @@ __device__ __inline__ float sigmoid(float a) { ...@@ -60,6 +62,52 @@ __device__ __inline__ float sigmoid(float a) {
return (retf); return (retf);
} }
// needed to work around calling rocblas API instead of hipblas API
static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op)
{
switch(op)
{
case HIPBLAS_OP_N:
return rocblas_operation_none;
case HIPBLAS_OP_T:
return rocblas_operation_transpose;
case HIPBLAS_OP_C:
return rocblas_operation_conjugate_transpose;
}
AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
{
switch(error)
{
case rocblas_status_size_unchanged:
case rocblas_status_size_increased:
case rocblas_status_success:
case rocblas_status_continue:
return HIPBLAS_STATUS_SUCCESS;
case rocblas_status_invalid_handle:
return HIPBLAS_STATUS_NOT_INITIALIZED;
case rocblas_status_not_implemented:
case rocblas_status_excluded_from_build:
return HIPBLAS_STATUS_NOT_SUPPORTED;
case rocblas_status_invalid_pointer:
case rocblas_status_invalid_size:
case rocblas_status_invalid_value:
case rocblas_status_size_query_mismatch:
return HIPBLAS_STATUS_INVALID_VALUE;
case rocblas_status_memory_error:
return HIPBLAS_STATUS_ALLOC_FAILED;
case rocblas_status_internal_error:
case rocblas_status_perf_degraded:
case rocblas_status_check_numerics_fail:
return HIPBLAS_STATUS_INTERNAL_ERROR;
case rocblas_status_arch_mismatch:
return HIPBLAS_STATUS_ARCH_MISMATCH;
}
AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
// FP64 Wrapper around cublas GEMMEx // FP64 Wrapper around cublas GEMMEx
cublasStatus_t mlp_gemm( cublasStatus_t mlp_gemm(
cublasHandle_t handle, cublasHandle_t handle,
...@@ -78,10 +126,10 @@ cublasStatus_t mlp_gemm( ...@@ -78,10 +126,10 @@ cublasStatus_t mlp_gemm(
int ldc, int ldc,
int flag) { int flag) {
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex( return rocBLASStatusToHIPStatus(rocblas_gemm_ex(
handle, (rocblas_handle) handle,
transa, hipOperationToRocOperation(transa),
transb, hipOperationToRocOperation(transb),
m, m,
n, n,
k, k,
...@@ -102,7 +150,7 @@ cublasStatus_t mlp_gemm( ...@@ -102,7 +150,7 @@ cublasStatus_t mlp_gemm(
rocblas_datatype_f64_r, rocblas_datatype_f64_r,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
0, 0,
flag); flag));
#else #else
return cublasGemmEx( return cublasGemmEx(
handle, handle,
...@@ -145,10 +193,10 @@ cublasStatus_t mlp_gemm( ...@@ -145,10 +193,10 @@ cublasStatus_t mlp_gemm(
int ldc, int ldc,
int flag) { int flag) {
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex( return rocBLASStatusToHIPStatus(rocblas_gemm_ex(
handle, (rocblas_handle) handle,
transa, hipOperationToRocOperation(transa),
transb, hipOperationToRocOperation(transb),
m, m,
n, n,
k, k,
...@@ -169,7 +217,7 @@ cublasStatus_t mlp_gemm( ...@@ -169,7 +217,7 @@ cublasStatus_t mlp_gemm(
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
0, 0,
flag); flag));
#else #else
return cublasGemmEx( return cublasGemmEx(
...@@ -216,10 +264,10 @@ cublasStatus_t mlp_gemm( ...@@ -216,10 +264,10 @@ cublasStatus_t mlp_gemm(
if (parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF")) { if (parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF")) {
half h_alpha = __float2half(*alpha); half h_alpha = __float2half(*alpha);
half h_beta = __float2half(*beta); half h_beta = __float2half(*beta);
return rocblas_gemm_ex( return rocBLASStatusToHIPStatus(rocblas_gemm_ex(
handle, (rocblas_handle) handle,
transa, hipOperationToRocOperation(transa),
transb, hipOperationToRocOperation(transb),
m, m,
n, n,
k, k,
...@@ -242,10 +290,10 @@ cublasStatus_t mlp_gemm( ...@@ -242,10 +290,10 @@ cublasStatus_t mlp_gemm(
0, 0,
flag); flag));
} else { } else {
return rocblas_gemm_ex( return rocBLASStatusToHIPStatus(rocblas_gemm_ex(
handle, (rocblas_handle) handle,
transa, hipOperationToRocOperation(transa),
transb, hipOperationToRocOperation(transb),
m, m,
n, n,
k, k,
......
[build-system]
requires = [
"setuptools",
"wheel",
]
build-backend = "setuptools.build_meta"