OpenDAS / apex

Commit 9319318d, authored Oct 29, 2021 by hubertlu-tw
Parent: 83181423

Fix namespace for pybind11

    Fix rocblas_gemmex namespace
    Fix namespace
    Clean up comments

Showing 8 changed files with 76 additions and 1107 deletions (+76 / -1107).
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cpp.cpp                     +3   -3
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu                     +5   -310
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu            +1   -1
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu    +51  -292
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cpp.cpp                  +4   -4
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu                  +6   -250
apex/contrib/csrc/multihead_attn/self_multihead_attn_cpp.cpp                       +3   -3
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu                       +3   -244
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cpp.cpp

@@ -3,7 +3,7 @@
 namespace multihead_attn {
 namespace encdec {
-namespace rocblas_gemm_ex {
+namespace rocblas_gemmex {

 std::vector<torch::Tensor> fwd_cuda(
     bool use_time_mask,

@@ -151,6 +151,6 @@ std::vector<torch::Tensor> bwd(
 } // end namespace multihead_attn

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward",  &multihead_attn::encdec::rocblas_gemm_ex::fwd,  "Encdec Multihead Attention Forward.");
-  m.def("backward", &multihead_attn::encdec::rocblas_gemm_ex::bwd, "Encdec Multihead Attention Backward.");
+  m.def("forward",  &multihead_attn::encdec::rocblas_gemmex::fwd,  "Encdec Multihead Attention Forward.");
+  m.def("backward", &multihead_attn::encdec::rocblas_gemmex::bwd, "Encdec Multihead Attention Backward.");
 }
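This binding change is the crux of the commit: PYBIND11_MODULE registers fully qualified function pointers, so the namespace opened at the top of the translation unit and the qualified names passed to m.def have to agree once the .cu implementations move to rocblas_gemmex, otherwise the extension stops compiling. A minimal sketch of the pattern, with placeholder fwd/bwd bodies and signatures rather than the real apex declarations:

// Minimal sketch of the binding pattern used above.
// The fwd/bwd signatures and bodies here are placeholders, not apex's.
#include <torch/extension.h>
#include <vector>

namespace multihead_attn {
namespace encdec {
namespace rocblas_gemmex {

std::vector<torch::Tensor> fwd(torch::Tensor q, torch::Tensor kv) {
  return {q, kv};   // placeholder body
}
std::vector<torch::Tensor> bwd(torch::Tensor grad) {
  return {grad};    // placeholder body
}

} // end namespace rocblas_gemmex
} // end namespace encdec
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // The qualified names below must match the namespace opened above.
  m.def("forward",  &multihead_attn::encdec::rocblas_gemmex::fwd,  "Encdec Multihead Attention Forward.");
  m.def("backward", &multihead_attn::encdec::rocblas_gemmex::bwd, "Encdec Multihead Attention Backward.");
}

Built as a torch extension, the module then exposes forward(...) and backward(...) to Python under whatever name TORCH_EXTENSION_NAME expands to at build time.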
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu

@@ -9,7 +9,7 @@
 #include <cuda.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 //#include <cuda_profiler_api.h>
 #include "THC/THC.h"
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>

@@ -25,7 +25,7 @@ extern THCState *state;
 namespace multihead_attn {
 namespace encdec {
-namespace cublas_gemmex {
+namespace rocblas_gemmex {

 std::vector<torch::Tensor> fwd_cuda(
     bool use_time_mask,

@@ -88,11 +88,9 @@ std::vector<torch::Tensor> fwd_cuda(
   char a_layout_t{'t'};
   char a_layout_n{'n'};
   char b_layout_n{'n'};
-  // TODO (OK)
-  // THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Input Linear Q Fwd
-  // TODO (OK)
   THCublasCheck(rocblas_gemm_ex(handle,
                                 CUBLAS_OP_T,
                                 CUBLAS_OP_N,

The remaining hunks in this file are pure comment clean-up: each one deletes a leftover "// TODO (OK)" marker together with the commented-out CUDA reference call that had been kept next to the live ROCm call.

@@ -117,28 +115,8 @@   drops the commented-out cublasGemmEx reference after the "Input Linear Q Fwd" call, before "Input Linear KV Fwd"
@@ -163,28 +141,8 @@   drops the commented-out cublasGemmEx reference after "Input Linear KV Fwd", before "MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)"
@@ -206,24 +164,6 @@   drops the commented-out gemm_switch_fp32accum reference after MatMul1, before "Padded Softmax" (bool softmax_success = false; stays)
@@ -267,7 +207,6 @@    drops a "// TODO (OK)" between "// Matmul2" and its gemm_switch_fp32accum call
@@ -289,27 +228,8 @@   drops the commented-out gemm_switch_fp32accum reference after Matmul2, before "Output Linear"
@@ -334,29 +254,6 @@   drops the commented-out cublasGemmEx reference after "Output Linear" and the cublasSetMathMode(CUBLAS_DEFAULT_MATH) comment, before "return { input_lin_q_results, ..."
@@ -435,11 +332,8 @@   start of bwd_cuda: drops the TODO markers and the cublasSetMathMode(CUBLAS_TENSOR_OP_MATH) comment after the layout chars, before "Output Linear Dgrad"
@@ -464,28 +358,8 @@   drops the commented-out cublasGemmEx reference after "Output Linear Dgrad", before "Output Linear Wgrad"
@@ -510,28 +384,8 @@   drops the commented-out cublasGemmEx reference after "Output Linear Wgrad", before "MatMul2 Dgrad1"
@@ -553,27 +407,8 @@   drops the commented-out gemm_switch_fp32accum reference after "MatMul2 Dgrad1", before "Matmul2 Dgrad2"
@@ -595,24 +430,6 @@   drops the commented-out gemm_switch_fp32accum reference after "Matmul2 Dgrad2", before the dropout-mask scaling via apex_masked_scale_cuda<at::Half, float, uint32_t>(...)
@@ -634,7 +451,6 @@    drops a "// TODO (OK)" after assert(softmax_success);, before "Matmul1 Dgrad1"
@@ -656,27 +472,8 @@   drops the commented-out gemm_switch_fp32accum reference after "Matmul1 Dgrad1", before "Matmul1 Dgrad2"
@@ -698,27 +495,8 @@   drops the commented-out gemm_switch_fp32accum reference after "Matmul1 Dgrad2", before "Input Linear Q Dgrad"
@@ -743,29 +521,8 @@   drops the commented-out cublasGemmEx reference after "Input Linear Q Dgrad", before "Input Linear Q Wgrad"
@@ -790,28 +547,8 @@   drops the commented-out cublasGemmEx reference after "Input Linear Q Wgrad", before "Input Linear KV Dgrad"
@@ -836,29 +573,8 @@   drops the commented-out cublasGemmEx reference after "Input Linear KV Dgrad", before "Input Linear KV Wgrad"
@@ -883,27 +599,6 @@   drops the commented-out cublasGemmEx reference after "Input Linear KV Wgrad" and the cublasSetMathMode(CUBLAS_DEFAULT_MATH) comment, before "return { input_q_grads, ..."

@@ -914,6 +609,6 @@ std::vector<torch::Tensor> bwd_cuda(
   };
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemmex
 } // end namespace encdec
 } // end namespace multihead_attn
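What the deleted comments documented is the hipification of each GEMM: the original cublasGemmEx calls had been kept as commented-out references next to their rocblas_gemm_ex replacements, and this commit finally removes them. For readers comparing the two APIs, a hedged sketch of how one such call maps across; the dimension and pointer names below are placeholders, not the actual apex variables:

// Sketch only: how a cublasGemmEx-style call maps onto rocblas_gemm_ex.
// All names (m, n, k, lda, ldb, ldc, the pointers) are placeholders.
#include <rocblas.h>

rocblas_status input_linear_fwd_sketch(rocblas_handle handle,
                                       const void* alpha, const void* beta,
                                       const void* weights_fp16, const void* inputs_fp16,
                                       void* results_fp16,
                                       int m, int n, int k, int lda, int ldb, int ldc) {
  return rocblas_gemm_ex(handle,
                         rocblas_operation_transpose,   // was CUBLAS_OP_T
                         rocblas_operation_none,        // was CUBLAS_OP_N
                         m, n, k,
                         alpha,
                         weights_fp16, rocblas_datatype_f16_r, lda,   // A (was CUDA_R_16F)
                         inputs_fp16,  rocblas_datatype_f16_r, ldb,   // B
                         beta,
                         results_fp16, rocblas_datatype_f16_r, ldc,   // C
                         results_fp16, rocblas_datatype_f16_r, ldc,   // D: rocBLAS also takes the output matrix
                         rocblas_datatype_f32_r,                      // fp32 accumulation (was CUDA_R_32F)
                         rocblas_gemm_algo_standard,
                         0 /* solution_index */, 0 /* flags */);
}

The shape is the same as cublasGemmEx (transposes, m/n/k, scalars, per-matrix types and leading dimensions), with the extra D output operand and the algo/solution_index/flags triple replacing cublasGemmAlgo_t.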
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu

@@ -692,6 +692,6 @@ std::vector<torch::Tensor> bwd_cuda(
   };
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemmex
 } // end namespace encdec_norm_add
 } // end namespace multihead_attn
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu

@@ -84,12 +84,10 @@ std::vector<torch::Tensor> fwd_cuda(
   char a_layout_t{'t'};
   char a_layout_n{'n'};
   char b_layout_n{'n'};
-  // TODO: CUBLAS_TENSOR_OP_MATH (https://github.com/ROCmSoftwarePlatform/apex/commit/1fd257e2cd777f1ef7df37590f6dc6b2a73cc518) (ok)
-  // TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
-  // TODO: cublasGemmEx --> rocblas_gemm_ex (OK)
   // Input Linear Fwd
   input_lin_results.copy_(input_biases);
-  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
+  THCublasCheck(rocblas_gemm_ex(handle,
                                 CUBLAS_OP_T,
                                 CUBLAS_OP_N,
                                 output_lin_dim,

@@ -97,42 +95,23 @@ std::vector<torch::Tensor> fwd_cuda(
                                 embed_dim,
                                 static_cast<const void*>(&alpha),
                                 static_cast<const void*>(input_weights.data_ptr()),
-                                rocblas_datatype_f16_r,  // a_type
+                                rocblas_datatype_f16_r,
                                 embed_dim,
                                 static_cast<const void*>(inputs.data_ptr()),
-                                rocblas_datatype_f16_r,  // b_type
+                                rocblas_datatype_f16_r,
                                 embed_dim,
                                 static_cast<const void*>(&beta_one),
                                 q_lin_results_ptr,
-                                rocblas_datatype_f16_r,  // c_type
+                                rocblas_datatype_f16_r,
                                 output_lin_dim,
                                 q_lin_results_ptr,
-                                rocblas_datatype_f16_r,  // d_type
+                                rocblas_datatype_f16_r,
                                 output_lin_dim,
-                                rocblas_datatype_f32_r,  // compute_type
+                                rocblas_datatype_f32_r,
                                 algo,
                                 solution_index,
                                 flags));
plus removal of the commented-out TORCH_CUDABLAS_CHECK(cublasGemmEx(...)) reference block and of a "// TODO: no matching function for call to gemm_switch_fp32accum (OK)" marker, before "MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)".

The remaining hunks in fwd_cuda repeat that pattern:

@@ -155,26 +134,6 @@   drops the commented-out gemm_switch_fp32accum reference after MatMul1, before "Padded Softmax"
@@ -202,7 +161,6 @@    drops a gemm_switch_fp32accum TODO marker between the softmax branch and "Matmul2"
@@ -225,30 +183,11 @@  drops the commented-out gemm_switch_fp32accum reference after Matmul2; outputs.copy_(output_biases); stays, and the "Output Linear" check changes from TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(...)) to THCublasCheck(rocblas_gemm_ex(...))
@@ -256,44 +195,22 @@  strips the "// a_type" ... "// compute_type" annotations from the "Output Linear" arguments, drops its commented-out cublasGemmEx reference and the cublasSetMathMode(CUBLAS_DEFAULT_MATH) comment, before "return { input_lin_results, ..."

and in bwd_cuda:

@@ -362,12 +279,9 @@   drops the TODO / cublasSetMathMode comments after the layout chars; "Output Linear Dgrad" switches from TORCH_CUDABLAS_CHECK to THCublasCheck
@@ -375,89 +289,50 @@  strips the type annotations from "Output Linear Dgrad" and "Output Linear Wgrad" (also switched to THCublasCheck), drops their commented-out cublasGemmEx references and a gemm_switch_fp32accum TODO marker; auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false); stays
@@ -480,26 +355,7 @@   drops the commented-out gemm_switch_fp32accum reference after "MatMul2 Dgrad1"
@@ -517,29 +373,11 @@  drops the commented-out gemm_switch_fp32accum reference after "Matmul2 Dgrad2", before "Apply Dropout Mask and Scale by Dropout Probability" / "Softmax Grad"
@@ -555,7 +393,7 @@    drops a gemm_switch_fp32accum TODO marker after the softmax-grad call, before "Matmul1 Dgrad1"
@@ -578,26 +416,7 @@   drops the commented-out gemm_switch_fp32accum reference after "Matmul1 Dgrad1", before "Matmul1 Dgrad2"
@@ -615,32 +434,14 @@  drops the commented-out gemm_switch_fp32accum reference after "Matmul1 Dgrad2"; "Input Linear Dgrad" switches from TORCH_CUDABLAS_CHECK to THCublasCheck
@@ -648,92 +449,50 @@  strips the type annotations from "Input Linear Dgrad" and "Input Linear Wgrad" (also switched to THCublasCheck) and drops their commented-out cublasGemmEx references; auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); stays, and the cublasSetMathMode(CUBLAS_DEFAULT_MATH) comment before "return { input_grads, ..." goes away
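The other recurring change in this file is the error-check macro: the live calls move from TORCH_CUDABLAS_CHECK(...) to THCublasCheck(...), matching the spelling the other multihead_attn sources already use. Both wrap a BLAS call, compare the returned status against the success value, and raise otherwise; a rough illustrative stand-in (not the actual THC or ATen macro definition) looks like this:

// Illustrative stand-in for a cuBLAS/rocBLAS status-check macro.
// The real THCublasCheck / TORCH_CUDABLAS_CHECK definitions live in the THC
// and ATen headers; this only shows the shape of the check.
#include <rocblas.h>
#include <stdexcept>
#include <string>

#define CHECK_BLAS_SKETCH(expr)                                               \
  do {                                                                        \
    rocblas_status _st = (expr);                                              \
    if (_st != rocblas_status_success) {                                      \
      throw std::runtime_error(std::string("BLAS call failed with status ")   \
                               + std::to_string(static_cast<int>(_st)));      \
    }                                                                         \
  } while (0)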
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cpp.cpp

@@ -3,7 +3,7 @@
 namespace multihead_attn {
 namespace self_bias {
-namespace rocblas_gemm_ex {
+namespace rocblas_gemmex {

 std::vector<torch::Tensor> fwd_cuda(
     bool use_time_mask,

@@ -128,12 +128,12 @@ std::vector<torch::Tensor> bwd(
     );
 }
-} // end namespace rocblas_gemm_ex
+} // end namespace rocblas_gemmex
 } // end namespace self
 } // end namespace multihead_attn

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward",  &multihead_attn::self_bias::rocblas_gemm_ex::fwd,  "Self Multihead Attention with Bias -- Forward.");
-  m.def("backward", &multihead_attn::self_bias::rocblas_gemm_ex::bwd, "Self Multihead Attention with Bias -- Backward.");
+  m.def("forward",  &multihead_attn::self_bias::rocblas_gemmex::fwd,  "Self Multihead Attention with Bias -- Forward.");
+  m.def("backward", &multihead_attn::self_bias::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
 }
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu

@@ -83,11 +83,9 @@ std::vector<torch::Tensor> fwd_cuda(
   char a_layout_t{'t'};
   char a_layout_n{'n'};
   char b_layout_n{'n'};
-  // TODO (OK)
-  // THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
   // Input Linear Fwd
   input_lin_results.copy_(input_biases);
-  // TODO (OK)
   THCublasCheck(rocblas_gemm_ex(handle,
                                 CUBLAS_OP_T,
                                 CUBLAS_OP_N,

@@ -105,36 +103,15 @@ std::vector<torch::Tensor> fwd_cuda(
                                 q_lin_results_ptr,
                                 rocblas_datatype_f16_r,
                                 output_lin_dim,
-                                q_lin_results_ptr,      //
-                                rocblas_datatype_f16_r, //
-                                output_lin_dim,         //
+                                q_lin_results_ptr,
+                                rocblas_datatype_f16_r,
+                                output_lin_dim,
                                 rocblas_datatype_f32_r,
                                 algo,
                                 solution_index,
                                 flags));
plus removal of the commented-out cublasGemmEx reference block and a TODO marker before "MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)".

The remaining hunks repeat the comment clean-up around each GEMM:

@@ -156,24 +133,7 @@   drops the commented-out gemm_switch_fp32accum reference before "Padded Softmax" (bool softmax_success = false; if (pad_mask == nullptr) { ... stays)
@@ -214,7 +174,6 @@    drops a "// TODO (OK)" before "Matmul2"
@@ -236,29 +195,10 @@  drops the commented-out gemm_switch_fp32accum reference after Matmul2; outputs.copy_(output_biases); and the "Output Linear" call stay
@@ -283,28 +223,6 @@   drops the commented-out cublasGemmEx reference after "Output Linear" and the cublasSetMathMode(CUBLAS_DEFAULT_MATH) comment, before "return { input_lin_results, ..."
@@ -372,11 +290,8 @@   start of bwd_cuda: drops the TODO / cublasSetMathMode comments after the layout chars, before "Output Linear Dgrad"
@@ -401,28 +316,8 @@   drops the commented-out cublasGemmEx reference after "Output Linear Dgrad", before "Output Linear Wgrad"
@@ -447,29 +342,9 @@   drops the commented-out cublasGemmEx reference after "Output Linear Wgrad"; auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false); stays, before "MatMul2 Dgrad1"
@@ -491,27 +366,8 @@   drops the commented-out gemm_switch_fp32accum reference after "MatMul2 Dgrad1", before "Matmul2 Dgrad2"
@@ -533,24 +389,6 @@   drops the commented-out gemm_switch_fp32accum reference after "Matmul2 Dgrad2", before "Apply Dropout Mask and Scale by Dropout Probability" / "Softmax Grad"
@@ -565,7 +403,6 @@    drops a "// TODO (OK)" before "Matmul1 Dgrad1"
@@ -587,27 +424,8 @@   drops the commented-out gemm_switch_fp32accum reference after "Matmul1 Dgrad1", before "Matmul1 Dgrad2"
@@ -629,26 +447,7 @@   drops the commented-out gemm_switch_fp32accum reference after "Matmul1 Dgrad2", before "Input Linear Dgrad"
@@ -673,30 +472,8 @@   drops the commented-out cublasGemmEx reference after "Input Linear Dgrad", before "Input Linear Wgrad"
@@ -721,29 +498,8 @@   drops the commented-out cublasGemmEx reference after "Input Linear Wgrad"; auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); stays, and the cublasSetMathMode(CUBLAS_DEFAULT_MATH) comment before "return { input_grads, ..." goes away

@@ -754,6 +510,6 @@ std::vector<torch::Tensor> bwd_cuda(
   };
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemmex
 } // end namespace self
 } // end namespace multihead_attn
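gemm_switch_fp32accum, which these files keep calling, is an apex-internal helper (defined elsewhere in apex/contrib/csrc) that issues the strided-batched attention GEMMs with fp16 operands and fp32 accumulation, dispatching on the layout characters ('t'/'n') declared above; the commented-out calls deleted here show its argument order (layouts, m/n/k, alpha, A/lda/strideA, B/ldb/strideB, beta, C/ldc/strideC, batch count). As a point of reference only, the first attention GEMM could be written directly against rocBLAS roughly like this; the names are placeholders taken from the deleted comments, not the helper's real implementation:

// Sketch: what one batched attention GEMM amounts to if written directly
// against rocBLAS. Placeholder names; not apex's gemm_switch_fp32accum.
#include <rocblas.h>

rocblas_status bmm1_scores_sketch(rocblas_handle handle,
                                  const float* scale, const float* beta,
                                  const void* k_lin, const void* q_lin, void* scores,
                                  int k_seq_len, int q_seq_len, int head_dim,
                                  int lead_dim, long long batch_stride,
                                  int attn_batches) {
  // scores = scale * K^T * Q per attention batch, accumulated in fp32.
  return rocblas_gemm_strided_batched_ex(
      handle,
      rocblas_operation_transpose, rocblas_operation_none,
      k_seq_len, q_seq_len, head_dim,
      scale,
      k_lin,  rocblas_datatype_f16_r, lead_dim, batch_stride,
      q_lin,  rocblas_datatype_f16_r, lead_dim, batch_stride,
      beta,
      scores, rocblas_datatype_f16_r, k_seq_len, (long long)k_seq_len * q_seq_len,
      scores, rocblas_datatype_f16_r, k_seq_len, (long long)k_seq_len * q_seq_len,
      attn_batches,
      rocblas_datatype_f32_r,              // fp32 accumulation
      rocblas_gemm_algo_standard, 0, 0);
}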
apex/contrib/csrc/multihead_attn/self_multihead_attn_cpp.cpp

@@ -3,7 +3,7 @@
 namespace multihead_attn {
 namespace self {
-namespace rocblas_gemm_ex {
+namespace rocblas_gemmex {

 std::vector<torch::Tensor> fwd_cuda(
     bool use_time_mask,

@@ -126,7 +126,7 @@ std::vector<torch::Tensor> bwd(
 } // end namespace multihead_attn

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward",  &multihead_attn::self::rocblas_gemm_ex::fwd,  "Self Multihead Attention Forward.");
-  m.def("backward", &multihead_attn::self::rocblas_gemm_ex::bwd, "Self Multihead Attention Backward.");
+  m.def("forward",  &multihead_attn::self::rocblas_gemmex::fwd,  "Self Multihead Attention Forward.");
+  m.def("backward", &multihead_attn::self::rocblas_gemmex::bwd, "Self Multihead Attention Backward.");
 }
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu
View file @
9319318d
...
@@ -24,7 +24,7 @@ extern THCState *state;
...
@@ -24,7 +24,7 @@ extern THCState *state;
namespace
multihead_attn
{
namespace
multihead_attn
{
namespace
self
{
namespace
self
{
namespace
c
u
blas_gemmex
{
namespace
ro
cblas_gemmex
{
std
::
vector
<
torch
::
Tensor
>
fwd_cuda
(
std
::
vector
<
torch
::
Tensor
>
fwd_cuda
(
bool
use_time_mask
,
bool
use_time_mask
,
...
@@ -80,10 +80,8 @@ std::vector<torch::Tensor> fwd_cuda(
...
@@ -80,10 +80,8 @@ std::vector<torch::Tensor> fwd_cuda(
char
a_layout_t
{
't'
};
char
a_layout_t
{
't'
};
char
a_layout_n
{
'n'
};
char
a_layout_n
{
'n'
};
char
b_layout_n
{
'n'
};
char
b_layout_n
{
'n'
};
// TODO (OK)
// THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd
// Input Linear Fwd
// TODO (OK)
THCublasCheck
(
rocblas_gemm_ex
(
handle
,
THCublasCheck
(
rocblas_gemm_ex
(
handle
,
CUBLAS_OP_T
,
CUBLAS_OP_T
,
CUBLAS_OP_N
,
CUBLAS_OP_N
,
...
@@ -108,28 +106,8 @@ std::vector<torch::Tensor> fwd_cuda(
...
@@ -108,28 +106,8 @@ std::vector<torch::Tensor> fwd_cuda(
algo
,
algo
,
solution_index
,
solution_index
,
flags
));
flags
));
// THCublasCheck(cublasGemmEx(handle,
// CUBLAS_OP_T,
// CUBLAS_OP_N,
// output_lin_dim,
// batches,
// embed_dim,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(input_weights.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(inputs.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(&beta),
// q_lin_results_ptr,
// CUDA_R_16F,
// output_lin_dim,
// CUDA_R_32F,
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
// TODO (OK)
gemm_switch_fp32accum
(
state
,
gemm_switch_fp32accum
(
state
,
a_layout_t
,
a_layout_t
,
b_layout_n
,
b_layout_n
,
...
@@ -151,24 +129,6 @@ std::vector<torch::Tensor> fwd_cuda(
...
@@ -151,24 +129,6 @@ std::vector<torch::Tensor> fwd_cuda(
k_seq_len
,
k_seq_len
,
k_seq_len
*
q_seq_len
,
k_seq_len
*
q_seq_len
,
attn_batches
);
attn_batches
);
// gemm_switch_fp32accum( state,
// a_layout_t,
// b_layout_n,
// k_seq_len,
// q_seq_len,
// head_dim,
// scale,
// static_cast<const half*>(k_lin_results_ptr),
// lead_dim,
// batch_stride,
// static_cast<const half*>(q_lin_results_ptr),
// lead_dim,
// batch_stride,
// beta,
// static_cast<half*>(softmax_results_ptr),
// k_seq_len,
// k_seq_len*q_seq_len,
// attn_batches);
// Padded Softmax
// Padded Softmax
bool
softmax_success
=
false
;
bool
softmax_success
=
false
;
...
@@ -212,7 +172,6 @@ std::vector<torch::Tensor> fwd_cuda(
...
@@ -212,7 +172,6 @@ std::vector<torch::Tensor> fwd_cuda(
}
}
// Matmul2
// Matmul2
// TODO (OK)
gemm_switch_fp32accum
(
state
,
gemm_switch_fp32accum
(
state
,
a_layout_n
,
a_layout_n
,
b_layout_n
,
b_layout_n
,
@@ -234,27 +193,8 @@ std::vector<torch::Tensor> fwd_cuda(
                        head_dim*attn_batches,
                        head_dim,
                        attn_batches);
-  // gemm_switch_fp32accum( state,
-  //                        a_layout_n,
-  //                        b_layout_n,
-  //                        head_dim,
-  //                        q_seq_len,
-  //                        k_seq_len,
-  //                        alpha,
-  //                        static_cast<const half*>(v_lin_results_ptr),
-  //                        lead_dim,
-  //                        batch_stride,
-  //                        (is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
-  //                        k_seq_len,
-  //                        k_seq_len*q_seq_len,
-  //                        beta,
-  //                        static_cast<half*>(matmul2_results.data_ptr()),
-  //                        head_dim*attn_batches,
-  //                        head_dim,
-  //                        attn_batches);
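Again for reference only: "Matmul2" applies the (optionally dropout-masked) softmax weights to the values, i.e. per attention batch matmul2_results[d, q] = sum_k V[d, k] * P[k, q], where P is dropout_results during training and softmax_results otherwise, as the removed comment shows. A simplified sketch under the same contiguous-layout assumption as above:

#include <cstddef>
#include <vector>

// Naive reference for "Matmul2": per attention batch, context = V * P,
// V is head_dim x k_seq_len, P is k_seq_len x q_seq_len, both column-major.
void attn_context_ref(const std::vector<float>& v,
                      const std::vector<float>& probs,
                      std::vector<float>& context,   // head_dim x q_seq_len per batch
                      std::size_t attn_batches, std::size_t head_dim,
                      std::size_t q_seq_len, std::size_t k_seq_len) {
  for (std::size_t b = 0; b < attn_batches; ++b) {
    const float* vb = &v[b * head_dim * k_seq_len];
    const float* pb = &probs[b * k_seq_len * q_seq_len];
    float* cb = &context[b * head_dim * q_seq_len];
    for (std::size_t qi = 0; qi < q_seq_len; ++qi) {
      for (std::size_t d = 0; d < head_dim; ++d) {
        float acc = 0.0f;
        for (std::size_t ki = 0; ki < k_seq_len; ++ki) {
          acc += vb[ki * head_dim + d] * pb[qi * k_seq_len + ki];
        }
        cb[qi * head_dim + d] = acc;
      }
    }
  }
}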
  // Output Linear
-  // TODO
  THCublasCheck(rocblas_gemm_ex(handle,
                                CUBLAS_OP_T,
                                CUBLAS_OP_N,
@@ -279,27 +219,6 @@ std::vector<torch::Tensor> fwd_cuda(
                                algo,
                                solution_index,
                                flags));
-  // THCublasCheck(cublasGemmEx(handle,
-  //                            CUBLAS_OP_T,
-  //                            CUBLAS_OP_N,
-  //                            embed_dim,
-  //                            batches,
-  //                            embed_dim,
-  //                            static_cast<const void*>(&alpha),
-  //                            static_cast<const void*>(output_weights.data_ptr()),
-  //                            CUDA_R_16F,
-  //                            embed_dim,
-  //                            static_cast<const void*>(matmul2_results.data_ptr()),
-  //                            CUDA_R_16F,
-  //                            embed_dim,
-  //                            static_cast<const void*>(&beta),
-  //                            static_cast<void*>(outputs.data_ptr()),
-  //                            CUDA_R_16F,
-  //                            embed_dim,
-  //                            CUDA_R_32F,
-  //                            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-  // TODO (OK)
-  // THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
  return {
    input_lin_results,
@@ -367,11 +286,8 @@ std::vector<torch::Tensor> bwd_cuda(
  char a_layout_t{'t'};
  char b_layout_n{'n'};
  char b_layout_t{'t'};
-  // TODO (OK)
-  // THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  // Output Linear Dgrad
-  // TODO (OK)
  THCublasCheck(rocblas_gemm_ex(handle,
                                CUBLAS_OP_N,
                                CUBLAS_OP_N,
@@ -396,28 +312,8 @@ std::vector<torch::Tensor> bwd_cuda(
                                algo,
                                solution_index,
                                flags));
-  // THCublasCheck(cublasGemmEx(handle,
-  //                            CUBLAS_OP_N,
-  //                            CUBLAS_OP_N,
-  //                            embed_dim,
-  //                            batches,
-  //                            embed_dim,
-  //                            static_cast<const void*>(&alpha),
-  //                            static_cast<const void*>(output_weights.data_ptr()),
-  //                            CUDA_R_16F,
-  //                            embed_dim,
-  //                            static_cast<const void*>(output_grads.data_ptr()),
-  //                            CUDA_R_16F,
-  //                            embed_dim,
-  //                            static_cast<const void*>(&beta),
-  //                            static_cast<void*>(output_lin_grads.data_ptr()),
-  //                            CUDA_R_16F,
-  //                            embed_dim,
-  //                            CUDA_R_32F,
-  //                            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
  // Output Linear Wgrad
-  // TODO (OOK)
  THCublasCheck(rocblas_gemm_ex(handle,
                                CUBLAS_OP_N,
                                CUBLAS_OP_T,
@@ -442,28 +338,8 @@ std::vector<torch::Tensor> bwd_cuda(
                                algo,
                                solution_index,
                                flags));
-  // THCublasCheck(cublasGemmEx(handle,
-  //                            CUBLAS_OP_N,
-  //                            CUBLAS_OP_T,
-  //                            embed_dim,
-  //                            embed_dim,
-  //                            batches,
-  //                            static_cast<const void*>(&alpha),
-  //                            static_cast<const void*>(matmul2_results.data_ptr()),
-  //                            CUDA_R_16F,
-  //                            embed_dim,
-  //                            static_cast<const void*>(output_grads.data_ptr()),
-  //                            CUDA_R_16F,
-  //                            embed_dim,
-  //                            static_cast<const void*>(&beta),
-  //                            static_cast<void*>(output_weight_grads.data_ptr()),
-  //                            CUDA_R_16F,
-  //                            embed_dim,
-  //                            CUDA_R_32F,
-  //                            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
  // MatMul2 Dgrad1
-  // TODO (OK)
  gemm_switch_fp32accum(state,
                        a_layout_t,
                        b_layout_n,
@@ -485,27 +361,8 @@ std::vector<torch::Tensor> bwd_cuda(
                        k_seq_len,
                        k_seq_len*q_seq_len,
                        attn_batches);
-  // gemm_switch_fp32accum( state,
-  //                        a_layout_t,
-  //                        b_layout_n,
-  //                        k_seq_len,
-  //                        q_seq_len,
-  //                        head_dim,
-  //                        alpha,
-  //                        static_cast<const half*>(v_lin_results_ptr),
-  //                        lead_dim,
-  //                        batch_stride,
-  //                        static_cast<const half*>(output_lin_grads.data_ptr()),
-  //                        head_dim*attn_batches,
-  //                        head_dim,
-  //                        beta,
-  //                        static_cast<half*>(matmul2_grads.data_ptr()),
-  //                        k_seq_len,
-  //                        k_seq_len*q_seq_len,
-  //                        attn_batches);
  // Matmul2 Dgrad2
-  // TODO (OK)
  gemm_switch_fp32accum(state,
                        a_layout_n,
                        b_layout_t,
@@ -527,24 +384,6 @@ std::vector<torch::Tensor> bwd_cuda(
                        lead_dim,
                        batch_stride,
                        attn_batches);
-  // gemm_switch_fp32accum( state,
-  //                        a_layout_n,
-  //                        b_layout_t,
-  //                        head_dim,
-  //                        k_seq_len,
-  //                        q_seq_len,
-  //                        alpha,
-  //                        static_cast<const half*>(output_lin_grads.data_ptr()),
-  //                        head_dim*attn_batches,
-  //                        head_dim,
-  //                        static_cast<const half*>(dropout_results.data_ptr()),
-  //                        k_seq_len,
-  //                        k_seq_len*q_seq_len,
-  //                        beta,
-  //                        v_lin_grads_ptr,
-  //                        lead_dim,
-  //                        batch_stride,
-  //                        attn_batches);
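A note on the two GEMMs above (reference, not part of the commit): with context = V * P from the forward pass, "MatMul2 Dgrad1" forms dP = V^T * dContext into matmul2_grads, and "MatMul2 Dgrad2" forms dV = dContext * P^T into the V slice of the input-projection gradient buffer. In index form, per attention batch and under the simplified contiguous layout used in the sketches above:

// Shapes per attention batch (column-major, simplified layout):
//   V        : head_dim  x k_seq_len   (v_lin_results_ptr)
//   P        : k_seq_len x q_seq_len   (dropout_results during training)
//   dContext : head_dim  x q_seq_len   (output_lin_grads, viewed per head)
//
// MatMul2 Dgrad1:  dP[k, q] = sum_d V[d, k]        * dContext[d, q]
// MatMul2 Dgrad2:  dV[d, k] = sum_q dContext[d, q] * P[k, q]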
  // Apply Dropout Mask and Scale by Dropout Probability
  apex_masked_scale_cuda<at::Half,float,uint32_t>(
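For reference (not part of the commit): apex_masked_scale_cuda re-applies the dropout mask saved in the forward pass to the incoming gradient and rescales it, keeping inverted dropout consistent between forward and backward. A minimal element-wise sketch, assuming the usual 1/(1 - dropout_prob) scale (the actual scale argument lies outside the visible hunk):

#include <cstddef>
#include <cstdint>
#include <vector>

// Element-wise reference for the masked-scale step:
// grad_in[i] = grad_out[i] * mask[i] * scale, scale = 1 / (1 - dropout_prob).
void masked_scale_ref(const std::vector<float>& grad_out,
                      const std::vector<std::uint8_t>& mask,  // 0 = dropped, 1 = kept
                      std::vector<float>& grad_in,
                      float dropout_prob) {
  const float scale = 1.0f / (1.0f - dropout_prob);
  for (std::size_t i = 0; i < grad_out.size(); ++i) {
    grad_in[i] = grad_out[i] * static_cast<float>(mask[i]) * scale;
  }
}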
@@ -566,7 +405,6 @@ std::vector<torch::Tensor> bwd_cuda(
  assert(softmax_success);
  // Matmul1 Dgrad1
-  // TODO (OK)
  gemm_switch_fp32accum(state,
                        a_layout_n,
                        b_layout_n,
@@ -588,27 +426,8 @@ std::vector<torch::Tensor> bwd_cuda(
                        lead_dim,
                        batch_stride,
                        attn_batches);
-  // gemm_switch_fp32accum( state,
-  //                        a_layout_n,
-  //                        b_layout_n,
-  //                        head_dim,
-  //                        q_seq_len,
-  //                        k_seq_len,
-  //                        scale,
-  //                        k_lin_results_ptr,
-  //                        lead_dim,
-  //                        batch_stride,
-  //                        static_cast<half*>(matmul2_grads.data_ptr()),
-  //                        k_seq_len,
-  //                        k_seq_len*q_seq_len,
-  //                        beta,
-  //                        q_lin_grads_ptr,
-  //                        lead_dim,
-  //                        batch_stride,
-  //                        attn_batches);
  // Matmul1 Dgrad2
-  // TODO (OK)
  gemm_switch_fp32accum(state,
                        a_layout_n,
                        b_layout_t,
@@ -630,27 +449,8 @@ std::vector<torch::Tensor> bwd_cuda(
                        lead_dim,
                        batch_stride,
                        attn_batches);
-  // gemm_switch_fp32accum( state,
-  //                        a_layout_n,
-  //                        b_layout_t,
-  //                        head_dim,
-  //                        k_seq_len,
-  //                        q_seq_len,
-  //                        scale,
-  //                        q_lin_results_ptr,
-  //                        lead_dim,
-  //                        batch_stride,
-  //                        static_cast<half*>(matmul2_grads.data_ptr()),
-  //                        k_seq_len,
-  //                        k_seq_len*q_seq_len,
-  //                        beta,
-  //                        k_lin_grads_ptr,
-  //                        lead_dim,
-  //                        batch_stride,
-  //                        attn_batches);
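Reference note (not part of the commit): since the forward scores were S = scale * K^T * Q, the two GEMMs above propagate the softmax-backward result dS (held in matmul2_grads at this point) to the query and key projections: "Matmul1 Dgrad1" computes dQ = scale * K * dS and "Matmul1 Dgrad2" computes dK = scale * Q * dS^T. In index form, per attention batch:

// Shapes per attention batch (column-major, simplified layout):
//   Q, K : head_dim  x seq_len       (q_lin_results_ptr / k_lin_results_ptr)
//   dS   : k_seq_len x q_seq_len     (matmul2_grads after softmax backward)
//
// Matmul1 Dgrad1:  dQ[d, q] = scale * sum_k K[d, k] * dS[k, q]
// Matmul1 Dgrad2:  dK[d, k] = scale * sum_q Q[d, q] * dS[k, q]
// with scale = 1 / sqrt(head_dim), matching the forward scaling.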
  // Input Linear Dgrad
-  // TODO (OK)
  THCublasCheck(rocblas_gemm_ex(handle,
                                CUBLAS_OP_N,
                                CUBLAS_OP_N,
@@ -675,28 +475,8 @@ std::vector<torch::Tensor> bwd_cuda(
                                algo,
                                solution_index,
                                flags));
-  // THCublasCheck(cublasGemmEx(handle,
-  //                            CUBLAS_OP_N,
-  //                            CUBLAS_OP_N,
-  //                            embed_dim,
-  //                            batches,
-  //                            output_lin_dim,
-  //                            static_cast<const void*>(&alpha),
-  //                            static_cast<const void*>(input_weights.data_ptr()),
-  //                            CUDA_R_16F,
-  //                            embed_dim,
-  //                            static_cast<const void*>(q_lin_grads_ptr),
-  //                            CUDA_R_16F,
-  //                            output_lin_dim,
-  //                            static_cast<const void*>(&beta),
-  //                            static_cast<void*>(input_grads.data_ptr()),
-  //                            CUDA_R_16F,
-  //                            embed_dim,
-  //                            CUDA_R_32F,
-  //                            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
  // Input Linear Wgrad
-  // TODO (OK)
  THCublasCheck(rocblas_gemm_ex(handle,
                                CUBLAS_OP_N,
                                CUBLAS_OP_T,
@@ -721,27 +501,6 @@ std::vector<torch::Tensor> bwd_cuda(
                                algo,
                                solution_index,
                                flags));
-  // THCublasCheck(cublasGemmEx(handle,
-  //                            CUBLAS_OP_N,
-  //                            CUBLAS_OP_T,
-  //                            embed_dim,
-  //                            output_lin_dim,
-  //                            batches,
-  //                            static_cast<const void*>(&alpha),
-  //                            static_cast<const void*>(inputs.data_ptr()),
-  //                            CUDA_R_16F,
-  //                            embed_dim,
-  //                            static_cast<const void*>(q_lin_grads_ptr),
-  //                            CUDA_R_16F,
-  //                            output_lin_dim,
-  //                            static_cast<const void*>(&beta),
-  //                            static_cast<void*>(input_weight_grads.data_ptr()),
-  //                            CUDA_R_16F,
-  //                            embed_dim,
-  //                            CUDA_R_32F,
-  //                            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-  // TODO (OK)
-  // THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
  return {
    input_grads,
@@ -750,6 +509,6 @@ std::vector<torch::Tensor> bwd_cuda(
  };
}
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemmex
} // end namespace self
} // end namespace multihead_attn