OpenDAS / apex · Commit ba0e5fa5

Hipify self_multihead_attn_bias_additive_mask.

Authored Oct 28, 2021 by hubertlu-tw
Parent: c3ec9351
Showing 2 changed files with 352 additions and 69 deletions:

  apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cpp.cpp   (+4, -4)
  apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu   (+348, -65)
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cpp.cpp

@@ -4,7 +4,7 @@
 namespace multihead_attn {
 namespace self_bias_additive_mask {
-namespace cublas_gemmex {
+namespace rocblas_gemmex {
 std::vector<torch::Tensor> fwd_cuda(
     bool use_time_mask,
@@ -132,12 +132,12 @@ std::vector<torch::Tensor> bwd(
   );
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemmex
 } // end namespace self
 } // end namespace multihead_attn

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
-  m.def("backward", &multihead_attn::self_bias_additive_mask::cublas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
+  m.def("forward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::fwd, "Self Multihead Attention with Bias -- Forward.");
+  m.def("backward", &multihead_attn::self_bias_additive_mask::rocblas_gemmex::bwd, "Self Multihead Attention with Bias -- Backward.");
 }
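Note: the binding file only swaps the namespace, so the extension's Python-facing entry points ("forward" and "backward") are untouched. As a hedged sketch (not part of this commit), the same PYBIND11_MODULE block could be kept backend-agnostic with a compile-time namespace alias; the `backend` alias and the `__HIP_PLATFORM_AMD__` guard below are illustrative assumptions, not code from this repository.

// Hedged sketch, not in the commit: a backend-agnostic binding block.
#ifdef __HIP_PLATFORM_AMD__
namespace backend = multihead_attn::self_bias_additive_mask::rocblas_gemmex;
#else
namespace backend = multihead_attn::self_bias_additive_mask::cublas_gemmex;
#endif

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",  &backend::fwd, "Self Multihead Attention with Bias -- Forward.");
  m.def("backward", &backend::bwd, "Self Multihead Attention with Bias -- Backward.");
}

The commit instead renames the namespace outright, which keeps the CUDA and ROCm trees separate but leaves the Python API identical.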
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu

#include <vector>
#include <math.h>
#include <iostream>
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h"
#include "softmax.h"
@@ -21,7 +21,7 @@ extern THCState *state;
 namespace multihead_attn {
 namespace self_bias_additive_mask {
-namespace cublas_gemmex {
+namespace rocblas_gemmex {
 std::vector<torch::Tensor> fwd_cuda(
     bool use_time_mask,
@@ -48,8 +48,8 @@ std::vector<torch::Tensor> fwd_cuda(
   const int   batch_stride  = 3 * head_dim;
   const int   dropout_elems = attn_batches * q_seq_len * k_seq_len;
   const float alpha         = 1.0;
-  const float beta_zero     = 0.0;
-  const float beta_one      = 1.0;
+  const float beta_zero     = 0.0;
+  const float beta_one      = 1.0;
   const float scale         = 1.0 / sqrt(static_cast<float>(head_dim));
   // There is no reason to use more than one stream as every kernel is
@@ -81,11 +81,12 @@ std::vector<torch::Tensor> fwd_cuda(
   char a_layout_t{'t'};
   char a_layout_n{'n'};
   char b_layout_n{'n'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+  // TODO: CUBLAS_TENSOR_OP_MATH (https://github.com/ROCmSoftwarePlatform/apex/commit/1fd257e2cd777f1ef7df37590f6dc6b2a73cc518) (ok)
+  // TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+  // TODO: cublasGemmEx --> rocblas_gemm_ex (OK)
   // Input Linear Fwd
   input_lin_results.copy_(input_biases);
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
                CUBLAS_OP_T,
                CUBLAS_OP_N,
                output_lin_dim,
@@ -93,18 +94,42 @@ std::vector<torch::Tensor> fwd_cuda(
                embed_dim,
                static_cast<const void*>(&alpha),
                static_cast<const void*>(input_weights.data_ptr()),
-               CUDA_R_16F,
+               rocblas_datatype_f16_r,  // a_type
                embed_dim,
                static_cast<const void*>(inputs.data_ptr()),
-               CUDA_R_16F,
+               rocblas_datatype_f16_r,  // b_type
                embed_dim,
                static_cast<const void*>(&beta_one),
                q_lin_results_ptr,
-               CUDA_R_16F,
+               rocblas_datatype_f16_r,  // c_type
                output_lin_dim,
-               CUDA_R_32F,
-               CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+               q_lin_results_ptr,
+               rocblas_datatype_f16_r,  // d_type
+               output_lin_dim,
+               rocblas_datatype_f32_r,  // compute_type
+               algo,
+               solution_index,
+               flags));
+  // TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
+  //              CUBLAS_OP_T,
+  //              CUBLAS_OP_N,
+  //              output_lin_dim,
+  //              batches,
+  //              embed_dim,
+  //              static_cast<const void*>(&alpha),
+  //              static_cast<const void*>(input_weights.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              static_cast<const void*>(inputs.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              static_cast<const void*>(&beta_one),
+  //              q_lin_results_ptr,
+  //              CUDA_R_16F,
+  //              output_lin_dim,
+  //              CUDA_R_32F,
+  //              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+  // TODO: no matching function for call to "gemm_switch_fp32accum" (OK)
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum(state,
                a_layout_t,
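Note: the call above follows the general cublasGemmEx to rocblas_gemm_ex mapping used throughout this commit: the A/B/C pointers and leading dimensions are kept but take rocblas_datatype_* enums, an explicit output matrix D is added (here aliased to C for an in-place result), the fp32 computeType becomes rocblas_datatype_f32_r, and the single cublasGemmAlgo_t argument is replaced by separate algo/solution_index/flags arguments (the `algo`, `solution_index`, and `flags` variables are presumably defined earlier in the hipified file, outside the hunks shown). A hedged wrapper sketch of that mapping, not part of the commit; header path and enum spellings follow recent rocBLAS releases and may differ between ROCm versions:

// Hedged sketch, not in the commit: fp16 GEMM with fp32 accumulation via rocblas_gemm_ex.
#include <rocblas.h>

static rocblas_status gemm_ex_f16_f32acc(rocblas_handle handle,
                                          rocblas_operation op_a, rocblas_operation op_b,
                                          int m, int n, int k,
                                          const float* alpha,
                                          const void* a, int lda,
                                          const void* b, int ldb,
                                          const float* beta,
                                          void* c, int ldc) {
  return rocblas_gemm_ex(handle, op_a, op_b, m, n, k,
                         alpha,
                         a, rocblas_datatype_f16_r, lda,  // a_type
                         b, rocblas_datatype_f16_r, ldb,  // b_type
                         beta,
                         c, rocblas_datatype_f16_r, ldc,  // c_type
                         c, rocblas_datatype_f16_r, ldc,  // d_type: D aliases C (in-place)
                         rocblas_datatype_f32_r,          // compute_type: accumulate in fp32
                         rocblas_gemm_algo_standard,      // algo
                         0,                               // solution_index
                         0);                              // flags
}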
@@ -123,7 +148,31 @@ std::vector<torch::Tensor> fwd_cuda(
-               static_cast<half*>(bmm1_results_ptr),
-               k_seq_len,
-               k_seq_len*q_seq_len,
+               static_cast<half*>(bmm1_results_ptr),
+               k_seq_len,
+               k_seq_len*q_seq_len,
                attn_batches);
+  // gemm_switch_fp32accum( state,
+  //              a_layout_t,
+  //              b_layout_n,
+  //              k_seq_len,
+  //              q_seq_len,
+  //              head_dim,
+  //              scale,
+  //              static_cast<const half*>(k_lin_results_ptr),
+  //              lead_dim,
+  //              batch_stride,
+  //              static_cast<const half*>(q_lin_results_ptr),
+  //              lead_dim,
+  //              batch_stride,
+  //              beta_zero,
+  //              static_cast<half*>(bmm1_results_ptr),
+  //              k_seq_len,
+  //              k_seq_len*q_seq_len,
+  //              attn_batches);
   // Padded Softmax
   bool softmax_success = false;
   if (is_training) {
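Note: the commented-out call above shows the full argument list of gemm_switch_fp32accum, the helper behind the "no matching function" TODOs; it performs the per-head attention-score GEMM with fp32 accumulation. As a hedged sketch (not in the commit), the same batched GEMM could be expressed directly against rocBLAS with rocblas_gemm_strided_batched_ex. Variable names (handle, k_lin_results_ptr, q_lin_results_ptr, bmm1_results_ptr, lead_dim, batch_stride, scale, beta_zero) are taken from the surrounding code, and algo/solution_index/flags are assumed to be defined as for the GemmEx calls above.

// Hedged sketch, not in the commit: MatMul1 (scores = scale * K^T * Q per head in
// column-major terms, i.e. Q * K^T per head in row-major terms) via rocBLAS.
TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(
    handle,
    rocblas_operation_transpose,                 // K^T  ('t' layout above)
    rocblas_operation_none,                      // Q    ('n' layout above)
    k_seq_len, q_seq_len, head_dim,
    static_cast<const void*>(&scale),
    static_cast<const void*>(k_lin_results_ptr), rocblas_datatype_f16_r, lead_dim, batch_stride,
    static_cast<const void*>(q_lin_results_ptr), rocblas_datatype_f16_r, lead_dim, batch_stride,
    static_cast<const void*>(&beta_zero),
    static_cast<const void*>(bmm1_results_ptr),  rocblas_datatype_f16_r, k_seq_len, k_seq_len * q_seq_len,
    static_cast<void*>(bmm1_results_ptr),        rocblas_datatype_f16_r, k_seq_len, k_seq_len * q_seq_len,
    attn_batches,
    rocblas_datatype_f32_r,                      // accumulate in fp32
    algo, solution_index, flags));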
@@ -150,6 +199,7 @@ std::vector<torch::Tensor> fwd_cuda(
                attn_batches*q_seq_len/sequences);
   }
+  // TODO: no matching function for call to "gemm_switch_fp32accum" (OK)
   // Matmul2
   gemm_switch_fp32accum(state,
                a_layout_n,
@@ -168,12 +218,34 @@ std::vector<torch::Tensor> fwd_cuda(
-               static_cast<half*>(matmul2_results.data_ptr()),
-               head_dim*attn_batches,
-               head_dim,
+               static_cast<half*>(matmul2_results.data_ptr()),
+               head_dim*attn_batches,
+               head_dim,
                attn_batches);
+  // gemm_switch_fp32accum( state,
+  //              a_layout_n,
+  //              b_layout_n,
+  //              head_dim,
+  //              q_seq_len,
+  //              k_seq_len,
+  //              alpha,
+  //              static_cast<const half*>(v_lin_results_ptr),
+  //              lead_dim,
+  //              batch_stride,
+  //              static_cast<const half*>(dropout_results.data_ptr()),
+  //              k_seq_len,
+  //              k_seq_len*q_seq_len,
+  //              beta_zero,
+  //              static_cast<half*>(matmul2_results.data_ptr()),
+  //              head_dim*attn_batches,
+  //              head_dim,
+  //              attn_batches);
   outputs.copy_(output_biases);
+  // TODO: cublasGemmEx --> rocblas_gemm_ex (OK)
   // Output Linear
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
                CUBLAS_OP_T,
                CUBLAS_OP_N,
                embed_dim,
@@ -181,20 +253,44 @@ std::vector<torch::Tensor> fwd_cuda(
                embed_dim,
                static_cast<const void*>(&alpha),
                static_cast<const void*>(output_weights.data_ptr()),
-               CUDA_R_16F,
+               rocblas_datatype_f16_r,  // a_type
                embed_dim,
                static_cast<const void*>(matmul2_results.data_ptr()),
-               CUDA_R_16F,
+               rocblas_datatype_f16_r,  // b_type
                embed_dim,
                static_cast<const void*>(&beta_one),
                static_cast<void*>(outputs.data_ptr()),
-               CUDA_R_16F,
+               rocblas_datatype_f16_r,  // c_type
                embed_dim,
-               CUDA_R_32F,
-               //CUBLAS_GEMM_ALGO1_TENSOR_OP));
-               CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+               static_cast<void*>(outputs.data_ptr()),
+               rocblas_datatype_f16_r,  // d_type
+               embed_dim,
+               rocblas_datatype_f32_r,  // compute_type
+               algo,
+               solution_index,
+               flags));
+  // TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
+  //              CUBLAS_OP_T,
+  //              CUBLAS_OP_N,
+  //              embed_dim,
+  //              batches,
+  //              embed_dim,
+  //              static_cast<const void*>(&alpha),
+  //              static_cast<const void*>(output_weights.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              static_cast<const void*>(matmul2_results.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              static_cast<const void*>(&beta_one),
+  //              static_cast<void*>(outputs.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              CUDA_R_32F,
+  //              //CUBLAS_GEMM_ALGO1_TENSOR_OP));
+  //              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+  // TODO: CUBLAS_DEFAULT_MATH (https://github.com/ROCmSoftwarePlatform/apex/commit/1fd257e2cd777f1ef7df37590f6dc6b2a73cc518) (ok)
+  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {
     input_lin_results,
@@ -263,11 +359,12 @@ std::vector<torch::Tensor> bwd_cuda(
   char a_layout_t{'t'};
   char b_layout_n{'n'};
   char b_layout_t{'t'};
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+  // TODO: CUBLAS_TENSOR_OP_MATH (https://github.com/ROCmSoftwarePlatform/apex/commit/1fd257e2cd777f1ef7df37590f6dc6b2a73cc518) (ok)
+  // TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
+  // TODO: cublasGemmEx --> rocblas_gemm_ex (OK)
   // Output Linear Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
                CUBLAS_OP_N,
                CUBLAS_OP_N,
                embed_dim,
@@ -275,39 +372,89 @@ std::vector<torch::Tensor> bwd_cuda(
                embed_dim,
                static_cast<const void*>(&alpha),
                static_cast<const void*>(output_weights.data_ptr()),
-               CUDA_R_16F,
+               rocblas_datatype_f16_r,  // a_type
                embed_dim,
                static_cast<const void*>(output_grads.data_ptr()),
-               CUDA_R_16F,
+               rocblas_datatype_f16_r,  // b_type
                embed_dim,
                static_cast<const void*>(&beta),
                static_cast<void*>(output_lin_grads.data_ptr()),
-               CUDA_R_16F,
+               rocblas_datatype_f16_r,  // c_type
                embed_dim,
-               CUDA_R_32F,
-               CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+               static_cast<void*>(output_lin_grads.data_ptr()),
+               rocblas_datatype_f16_r,  // d_type
+               embed_dim,
+               rocblas_datatype_f32_r,  // compute_type
+               algo,
+               solution_index,
+               flags));
+  // TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
+  //              CUBLAS_OP_N,
+  //              CUBLAS_OP_N,
+  //              embed_dim,
+  //              batches,
+  //              embed_dim,
+  //              static_cast<const void*>(&alpha),
+  //              static_cast<const void*>(output_weights.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              static_cast<const void*>(output_grads.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              static_cast<const void*>(&beta),
+  //              static_cast<void*>(output_lin_grads.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              CUDA_R_32F,
+  //              CUBLAS_GEMM_DEFAULT_TENSOR_OP)); // TODO: CUBLAS_GEMM_DEFAULT_TENSOR_OP
+  // TODO: cublasGemmEx --> rocblas_gemm_ex (OK)
   // Output Linear Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
                CUBLAS_OP_N,
                CUBLAS_OP_T,
                embed_dim,
-               embed_dim,
-               batches,
+               embed_dim,
+               batches,
                static_cast<const void*>(&alpha),
                static_cast<const void*>(matmul2_results.data_ptr()),
-               CUDA_R_16F,
+               rocblas_datatype_f16_r,  // a_type
                embed_dim,
                static_cast<const void*>(output_grads.data_ptr()),
-               CUDA_R_16F,
+               rocblas_datatype_f16_r,  // b_type
                embed_dim,
                static_cast<const void*>(&beta),
                static_cast<void*>(output_weight_grads.data_ptr()),
-               CUDA_R_16F,
+               rocblas_datatype_f16_r,  // c_type
                embed_dim,
-               CUDA_R_32F,
-               CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+               static_cast<void*>(output_weight_grads.data_ptr()),
+               rocblas_datatype_f16_r,  // d_type
+               embed_dim,
+               rocblas_datatype_f32_r,  // compute_type
+               algo,
+               solution_index,
+               flags));
+  // TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
+  //              CUBLAS_OP_N,
+  //              CUBLAS_OP_T,
+  //              embed_dim,
+  //              embed_dim,
+  //              batches,
+  //              static_cast<const void*>(&alpha),
+  //              static_cast<const void*>(matmul2_results.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              static_cast<const void*>(output_grads.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              static_cast<const void*>(&beta),
+  //              static_cast<void*>(output_weight_grads.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              CUDA_R_32F,
+  //              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
+  // TODO: no matching function for call to "gemm_switch_fp32accum" (OK)
   // MatMul2 Dgrad1
   gemm_switch_fp32accum(state,
                a_layout_t,
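Note: in row-major PyTorch terms the two GEMMs above implement the usual linear-layer backward for y = x * W^T: the Dgrad call computes dX = dY * W, the Wgrad call computes dW = dY^T * X, and the bias gradient is the separate row-sum on the view(...).sum(0, false) line. A hedged ATen reference (not in the commit) that a port can be checked against; tensor names follow the surrounding code and shapes assume [*, embed_dim] row-major activations:

// Hedged sketch, not in the commit: reference values for the two GEMMs above.
auto dY = output_grads.reshape({-1, embed_dim});      // [batches, embed_dim]
auto X  = matmul2_results.reshape({-1, embed_dim});   // [batches, embed_dim]
auto dX_ref = torch::matmul(dY, output_weights);      // Output Linear Dgrad: dX = dY * W
auto dW_ref = torch::matmul(dY.t(), X);               // Output Linear Wgrad: dW = dY^T * X
auto db_ref = dY.sum(0);                              // matches output_bias_grads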
@@ -326,8 +473,30 @@ std::vector<torch::Tensor> bwd_cuda(
-               static_cast<half*>(matmul2_grads.data_ptr()),
-               k_seq_len,
-               k_seq_len*q_seq_len,
+               static_cast<half*>(matmul2_grads.data_ptr()),
+               k_seq_len,
+               k_seq_len*q_seq_len,
                attn_batches);
+  // gemm_switch_fp32accum( state,
+  //              a_layout_t,
+  //              b_layout_n,
+  //              k_seq_len,
+  //              q_seq_len,
+  //              head_dim,
+  //              alpha,
+  //              static_cast<const half*>(v_lin_results_ptr),
+  //              lead_dim,
+  //              batch_stride,
+  //              static_cast<const half*>(output_lin_grads.data_ptr()),
+  //              head_dim*attn_batches,
+  //              head_dim,
+  //              beta,
+  //              static_cast<half*>(matmul2_grads.data_ptr()),
+  //              k_seq_len,
+  //              k_seq_len*q_seq_len,
+  //              attn_batches);
+  // TODO: no matching function for call to "gemm_switch_fp32accum" (OK)
   // Matmul2 Dgrad2
   gemm_switch_fp32accum(state,
                a_layout_n,
@@ -345,8 +514,29 @@ std::vector<torch::Tensor> bwd_cuda(
                beta,
-               v_lin_grads_ptr,
-               lead_dim,
-               batch_stride,
+               v_lin_grads_ptr,
+               lead_dim,
+               batch_stride,
                attn_batches);
+  // gemm_switch_fp32accum( state,
+  //              a_layout_n,
+  //              b_layout_t,
+  //              head_dim,
+  //              k_seq_len,
+  //              q_seq_len,
+  //              alpha,
+  //              static_cast<const half*>(output_lin_grads.data_ptr()),
+  //              head_dim*attn_batches,
+  //              head_dim,
+  //              static_cast<const half*>(dropout_results.data_ptr()),
+  //              k_seq_len,
+  //              k_seq_len*q_seq_len,
+  //              beta,
+  //              v_lin_grads_ptr,
+  //              lead_dim,
+  //              batch_stride,
+  //              attn_batches);
   // Apply Dropout Mask and Scale by Dropout Probability
   // Softmax Grad
@@ -362,7 +552,7 @@ std::vector<torch::Tensor> bwd_cuda(
-               attn_batches*q_seq_len/sequences,
+               attn_batches*q_seq_len,
                stream);
+  // TODO: no matching function for call to "gemm_switch_fp32accum" (OK)
   // Matmul1 Dgrad1
   gemm_switch_fp32accum(state,
                a_layout_n,
@@ -381,8 +571,30 @@ std::vector<torch::Tensor> bwd_cuda(
-               q_lin_grads_ptr,
-               lead_dim,
-               batch_stride,
+               q_lin_grads_ptr,
+               lead_dim,
+               batch_stride,
                attn_batches);
+  // gemm_switch_fp32accum( state,
+  //              a_layout_n,
+  //              b_layout_n,
+  //              head_dim,
+  //              q_seq_len,
+  //              k_seq_len,
+  //              scale,
+  //              k_lin_results_ptr,
+  //              lead_dim,
+  //              batch_stride,
+  //              static_cast<half*>(matmul2_grads.data_ptr()),
+  //              k_seq_len,
+  //              k_seq_len*q_seq_len,
+  //              beta,
+  //              q_lin_grads_ptr,
+  //              lead_dim,
+  //              batch_stride,
+  //              attn_batches);
+  // TODO: no matching function for call to "gemm_switch_fp32accum" (OK)
   // Matmul1 Dgrad2
   gemm_switch_fp32accum(state,
                a_layout_n,
@@ -400,10 +612,32 @@ std::vector<torch::Tensor> bwd_cuda(
                beta,
-               k_lin_grads_ptr,
-               lead_dim,
-               batch_stride,
+               k_lin_grads_ptr,
+               lead_dim,
+               batch_stride,
                attn_batches);
+  // gemm_switch_fp32accum( state,
+  //              a_layout_n,
+  //              b_layout_t,
+  //              head_dim,
+  //              k_seq_len,
+  //              q_seq_len,
+  //              scale,
+  //              q_lin_results_ptr,
+  //              lead_dim,
+  //              batch_stride,
+  //              static_cast<half*>(matmul2_grads.data_ptr()),
+  //              k_seq_len,
+  //              k_seq_len*q_seq_len,
+  //              beta,
+  //              k_lin_grads_ptr,
+  //              lead_dim,
+  //              batch_stride,
+  //              attn_batches);
+  // TODO: cublasGemmEx --> rocblas_gemm_ex (ok)
   // Input Linear Dgrad
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
                CUBLAS_OP_N,
                CUBLAS_OP_N,
                embed_dim,
@@ -411,43 +645,92 @@ std::vector<torch::Tensor> bwd_cuda(
                output_lin_dim,
                static_cast<const void*>(&alpha),
                static_cast<const void*>(input_weights.data_ptr()),
-               CUDA_R_16F,
+               rocblas_datatype_f16_r,  // a_type
                embed_dim,
-               static_cast<const void*>(input_lin_output_grads.data_ptr()),
-               //static_cast<const void*>(q_lin_grads_ptr),
-               CUDA_R_16F,
+               static_cast<const void*>(input_lin_output_grads.data_ptr()),
+               rocblas_datatype_f16_r,  // b_type
                output_lin_dim,
                static_cast<const void*>(&beta),
                static_cast<void*>(input_grads.data_ptr()),
-               CUDA_R_16F,
+               rocblas_datatype_f16_r,  // c_type
                embed_dim,
-               CUDA_R_32F,
-               //CUBLAS_GEMM_ALGO10_TENSOR_OP));
-               CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+               static_cast<void*>(input_grads.data_ptr()),
+               rocblas_datatype_f16_r,  // d_type
+               embed_dim,
+               rocblas_datatype_f32_r,  // compute_type
+               algo,
+               solution_index,
+               flags));
+  // TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
+  //              CUBLAS_OP_N,
+  //              CUBLAS_OP_N,
+  //              embed_dim,
+  //              batches,
+  //              output_lin_dim,
+  //              static_cast<const void*>(&alpha),
+  //              static_cast<const void*>(input_weights.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              static_cast<const void*>(input_lin_output_grads.data_ptr()),
+  //              //static_cast<const void*>(q_lin_grads_ptr),
+  //              CUDA_R_16F,
+  //              output_lin_dim,
+  //              static_cast<const void*>(&beta),
+  //              static_cast<void*>(input_grads.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              CUDA_R_32F,
+  //              //CUBLAS_GEMM_ALGO10_TENSOR_OP));
+  //              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+  // TODO: cublasGemmEx --> rocblas_gemm_ex (OK)
   // Input Linear Wgrad
-  THCublasCheck(cublasGemmEx(handle,
+  TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
                CUBLAS_OP_N,
                CUBLAS_OP_T,
-               embed_dim,
-               output_lin_dim,
-               batches,
+               embed_dim,
+               output_lin_dim,
+               batches,
                static_cast<const void*>(&alpha),
                static_cast<const void*>(inputs.data_ptr()),
-               CUDA_R_16F,
+               rocblas_datatype_f16_r,  // a_type
                embed_dim,
                static_cast<const void*>(q_lin_grads_ptr),
-               CUDA_R_16F,
-               output_lin_dim,
+               rocblas_datatype_f16_r,  // b_type
+               output_lin_dim,
                static_cast<const void*>(&beta),
                static_cast<void*>(input_weight_grads.data_ptr()),
-               CUDA_R_16F,
+               rocblas_datatype_f16_r,  // c_type
                embed_dim,
-               CUDA_R_32F,
-               CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+               static_cast<void*>(input_weight_grads.data_ptr()),
+               rocblas_datatype_f16_r,  // d_type
+               embed_dim,
+               rocblas_datatype_f32_r,  // compute_type
+               algo,
+               solution_index,
+               flags));
+  // TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
+  //              CUBLAS_OP_N,
+  //              CUBLAS_OP_T,
+  //              embed_dim,
+  //              output_lin_dim,
+  //              batches,
+  //              static_cast<const void*>(&alpha),
+  //              static_cast<const void*>(inputs.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              static_cast<const void*>(q_lin_grads_ptr),
+  //              CUDA_R_16F,
+  //              output_lin_dim,
+  //              static_cast<const void*>(&beta),
+  //              static_cast<void*>(input_weight_grads.data_ptr()),
+  //              CUDA_R_16F,
+  //              embed_dim,
+  //              CUDA_R_32F,
+  //              CUBLAS_GEMM_DEFAULT_TENSOR_OP));
   auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
-  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
+  // TODO: CUBLAS_DEFAULT_MATH (https://github.com/ROCmSoftwarePlatform/apex/commit/1fd257e2cd777f1ef7df37590f6dc6b2a73cc518) (ok)
+  //TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
   return {
     input_grads,
@@ -458,6 +741,6 @@ std::vector<torch::Tensor> bwd_cuda(
   };
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemmex
 } // end namespace self
 } // end namespace multihead_attn
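Note: rocBLAS has no global math-mode toggle, so the commit comments out both the CUBLAS_TENSOR_OP_MATH and CUBLAS_DEFAULT_MATH calls rather than translating them; tensor/matrix-core usage on ROCm is controlled per call through the gemm_ex arguments. As a hedged sketch (not in the commit), a single source that still wanted the cuBLAS toggle on CUDA builds could guard those calls instead; the `__HIP_PLATFORM_AMD__` guard is an assumption, not code from this repository.

// Hedged sketch, not in the commit: keep the cuBLAS math-mode toggle for CUDA builds only.
#ifndef __HIP_PLATFORM_AMD__
  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  // ... GEMM calls ...
  TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
#endif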