OpenDAS / apex · Commit 61416180

Hipify self_multihead_attn_bias

Fix some spacing

Authored Oct 28, 2021 by hubertlu-tw
Parent: 8bdbb502

Showing 3 changed files with 384 additions and 99 deletions (+384, -99)
  apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cpp.cpp      +4    -4
  apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu      +334  -49
  apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu  +46   -46
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cpp.cpp
@@ -3,7 +3,7 @@
 namespace multihead_attn {
 namespace self_bias {
-namespace cublas_gemmex {
+namespace rocblas_gemm_ex {

 std::vector<torch::Tensor> fwd_cuda(
                                bool use_time_mask,
@@ -128,12 +128,12 @@ std::vector<torch::Tensor> bwd(
                   );
 }
-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemm_ex
 } // end namespace self
 } // end namespace multihead_attn

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward",  &multihead_attn::self_bias::cublas_gemmex::fwd,  "Self Multihead Attention with Bias -- Forward.");
-  m.def("backward", &multihead_attn::self_bias::cublas_gemmex::bwd,  "Self Multihead Attention with Bias -- Backward.");
+  m.def("forward",  &multihead_attn::self_bias::rocblas_gemm_ex::fwd,  "Self Multihead Attention with Bias -- Forward.");
+  m.def("backward", &multihead_attn::self_bias::rocblas_gemm_ex::bwd,  "Self Multihead Attention with Bias -- Backward.");
 }
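The bindings above now point at rocBLAS-backed kernels, and the bulk of the next file swaps cublasGemmEx for rocblas_gemm_ex. As a rough sketch of the argument mapping involved (this assumes the stock cuBLAS and rocBLAS GEMM-Ex signatures and uses an illustrative helper name and shapes; it is not code from this commit):

// Hypothetical illustration of the cublasGemmEx -> rocblas_gemm_ex mapping.
// Error handling trimmed for brevity.
#include <rocblas.h>

// C = alpha * op(A) * op(B) + beta * C, FP16 storage with FP32 accumulation.
void gemm_fp16_fp32acc(rocblas_handle handle,
                       int m, int n, int k,
                       const float* alpha, const void* A, int lda,
                       const void* B, int ldb,
                       const float* beta, void* C, int ldc) {
  // cuBLAS version, for comparison:
  //   cublasGemmEx(handle, CUBLAS_OP_T, CUBLAS_OP_N, m, n, k,
  //                alpha, A, CUDA_R_16F, lda,
  //                       B, CUDA_R_16F, ldb,
  //                beta,  C, CUDA_R_16F, ldc,
  //                CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
  //
  // rocBLAS splits the output into C (input) and D (output) and adds
  // algo / solution_index / flags selectors; passing the same buffer for
  // C and D reproduces the in-place cuBLAS behaviour, which is what the
  // hipified calls in the next file do.
  rocblas_gemm_ex(handle,
                  rocblas_operation_transpose, rocblas_operation_none,
                  m, n, k,
                  alpha,
                  A, rocblas_datatype_f16_r, lda,
                  B, rocblas_datatype_f16_r, ldb,
                  beta,
                  C, rocblas_datatype_f16_r, ldc,   // C
                  C, rocblas_datatype_f16_r, ldc,   // D (same buffer)
                  rocblas_datatype_f32_r,           // compute type
                  rocblas_gemm_algo_standard,
                  0 /* solution_index */, 0 /* flags */);
}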
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu
@@ -21,7 +21,7 @@ extern THCState *state;
 namespace multihead_attn {
 namespace self_bias {
-namespace cublas_gemmex {
+namespace rocblas_gemmex {

 std::vector<torch::Tensor> fwd_cuda(
                                bool use_time_mask,
@@ -80,11 +80,12 @@ std::vector<torch::Tensor> fwd_cuda(
  char a_layout_t{'t'};
  char a_layout_n{'n'};
  char b_layout_n{'n'};
  // TODO (OK)
  // THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  // Input Linear Fwd
  input_lin_results.copy_(input_biases);
  // TODO (OK)
  THCublasCheck(rocblas_gemm_ex(handle,
        CUBLAS_OP_T,
        CUBLAS_OP_N,
        output_lin_dim,
@@ -92,19 +93,45 @@ std::vector<torch::Tensor> fwd_cuda(
        embed_dim,
        static_cast<const void*>(&alpha),
        static_cast<const void*>(input_weights.data_ptr()),
        rocblas_datatype_f16_r,
        embed_dim,
        static_cast<const void*>(inputs.data_ptr()),
        rocblas_datatype_f16_r,
        embed_dim,
        static_cast<const void*>(&beta_one),
        q_lin_results_ptr,
        rocblas_datatype_f16_r,
        output_lin_dim,
        q_lin_results_ptr,
        rocblas_datatype_f16_r,
        output_lin_dim,
        rocblas_datatype_f32_r,
        algo,
        solution_index,
        flags));
// THCublasCheck(cublasGemmEx(handle,
// CUBLAS_OP_T,
// CUBLAS_OP_N,
// output_lin_dim,
// batches,
// embed_dim,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(input_weights.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(inputs.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(&beta_one),
// q_lin_results_ptr,
// CUDA_R_16F,
// output_lin_dim,
// CUDA_R_32F,
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
// TODO (OK)
  gemm_switch_fp32accum(state,
        a_layout_t,
        b_layout_n,
@@ -122,7 +149,28 @@ std::vector<torch::Tensor> fwd_cuda(
        static_cast<half*>(softmax_results_ptr),
        k_seq_len,
        k_seq_len*q_seq_len,
        static_cast<half*>(softmax_results_ptr),
        k_seq_len,
        k_seq_len*q_seq_len,
        attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_t,
// b_layout_n,
// k_seq_len,
// q_seq_len,
// head_dim,
// scale,
// static_cast<const half*>(k_lin_results_ptr),
// lead_dim,
// batch_stride,
// static_cast<const half*>(q_lin_results_ptr),
// lead_dim,
// batch_stride,
// beta_zero,
// static_cast<half*>(softmax_results_ptr),
// k_seq_len,
// k_seq_len*q_seq_len,
// attn_batches);
// Padded Softmax
  bool softmax_success = false;
  if (pad_mask == nullptr) {
@@ -163,6 +211,7 @@ std::vector<torch::Tensor> fwd_cuda(
}
// Matmul2
// TODO (OK)
  gemm_switch_fp32accum(state,
        a_layout_n,
        b_layout_n,
@@ -180,12 +229,34 @@ std::vector<torch::Tensor> fwd_cuda(
        static_cast<half*>(matmul2_results.data_ptr()),
        head_dim*attn_batches,
        head_dim,
        static_cast<half*>(matmul2_results.data_ptr()),
        head_dim*attn_batches,
        head_dim,
        attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_n,
// b_layout_n,
// head_dim,
// q_seq_len,
// k_seq_len,
// alpha,
// static_cast<const half*>(v_lin_results_ptr),
// lead_dim,
// batch_stride,
// (is_training) ? static_cast<const half*>(dropout_results.data_ptr()) : static_cast<const half*>(softmax_results.data_ptr()) ,
// k_seq_len,
// k_seq_len*q_seq_len,
// beta_zero,
// static_cast<half*>(matmul2_results.data_ptr()),
// head_dim*attn_batches,
// head_dim,
// attn_batches);
  outputs.copy_(output_biases);
// Output Linear
  // TODO (OK)
  THCublasCheck(rocblas_gemm_ex(handle,
        CUBLAS_OP_T,
        CUBLAS_OP_N,
        embed_dim,
@@ -193,20 +264,44 @@ std::vector<torch::Tensor> fwd_cuda(
        embed_dim,
        static_cast<const void*>(&alpha),
        static_cast<const void*>(output_weights.data_ptr()),
        rocblas_datatype_f16_r,
        embed_dim,
        static_cast<const void*>(matmul2_results.data_ptr()),
        rocblas_datatype_f16_r,
        embed_dim,
        static_cast<const void*>(&beta_one),
        static_cast<void*>(outputs.data_ptr()),
        rocblas_datatype_f16_r,
        embed_dim,
        static_cast<void*>(outputs.data_ptr()),
        rocblas_datatype_f16_r,
        embed_dim,
        rocblas_datatype_f32_r,
        algo,
        solution_index,
        flags));
// THCublasCheck(cublasGemmEx(handle,
// CUBLAS_OP_T,
// CUBLAS_OP_N,
// embed_dim,
// batches,
// embed_dim,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(output_weights.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(matmul2_results.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(&beta_one),
// static_cast<void*>(outputs.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// CUDA_R_32F,
// //CUBLAS_GEMM_ALGO1_TENSOR_OP));
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// TODO (OK)
// THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
  return {input_lin_results,
@@ -274,11 +369,12 @@ std::vector<torch::Tensor> bwd_cuda(
  char a_layout_t{'t'};
  char b_layout_n{'n'};
  char b_layout_t{'t'};
  // TODO (OK)
  // THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  // Output Linear Dgrad
  // TODO (OK)
  THCublasCheck(rocblas_gemm_ex(handle,
        CUBLAS_OP_N,
        CUBLAS_OP_N,
        embed_dim,
@@ -286,19 +382,45 @@ std::vector<torch::Tensor> bwd_cuda(
        embed_dim,
        static_cast<const void*>(&alpha),
        static_cast<const void*>(output_weights.data_ptr()),
        rocblas_datatype_f16_r,
        embed_dim,
        static_cast<const void*>(output_grads.data_ptr()),
        rocblas_datatype_f16_r,
        embed_dim,
        static_cast<const void*>(&beta),
        static_cast<void*>(output_lin_grads.data_ptr()),
        rocblas_datatype_f16_r,
        embed_dim,
        static_cast<void*>(output_lin_grads.data_ptr()),
        rocblas_datatype_f16_r,
        embed_dim,
        rocblas_datatype_f32_r,
        algo,
        solution_index,
        flags));
// THCublasCheck(cublasGemmEx(handle,
// CUBLAS_OP_N,
// CUBLAS_OP_N,
// embed_dim,
// batches,
// embed_dim,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(output_weights.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(output_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(&beta),
// static_cast<void*>(output_lin_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// CUDA_R_32F,
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Output Linear Wgrad
  // TODO (OK)
  THCublasCheck(rocblas_gemm_ex(handle,
        CUBLAS_OP_N,
        CUBLAS_OP_T,
        embed_dim,
@@ -306,20 +428,45 @@ std::vector<torch::Tensor> bwd_cuda(
        batches,
        static_cast<const void*>(&alpha),
        static_cast<const void*>(matmul2_results.data_ptr()),
        rocblas_datatype_f16_r,
        embed_dim,
        static_cast<const void*>(output_grads.data_ptr()),
        rocblas_datatype_f16_r,
        embed_dim,
        static_cast<const void*>(&beta),
        static_cast<void*>(output_weight_grads.data_ptr()),
        rocblas_datatype_f16_r,
        embed_dim,
        static_cast<void*>(output_weight_grads.data_ptr()),
        rocblas_datatype_f16_r,
        embed_dim,
        rocblas_datatype_f32_r,
        algo,
        solution_index,
        flags));
// THCublasCheck(cublasGemmEx(handle,
// CUBLAS_OP_N,
// CUBLAS_OP_T,
// embed_dim,
// embed_dim,
// batches,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(matmul2_results.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(output_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(&beta),
// static_cast<void*>(output_weight_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// CUDA_R_32F,
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
  auto output_bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);
// MatMul2 Dgrad1
// TODO (OK)
  gemm_switch_fp32accum(state,
        a_layout_t,
        b_layout_n,
@@ -337,9 +484,31 @@ std::vector<torch::Tensor> bwd_cuda(
        static_cast<half*>(matmul2_grads.data_ptr()),
        k_seq_len,
        k_seq_len*q_seq_len,
        static_cast<half*>(matmul2_grads.data_ptr()),
        k_seq_len,
        k_seq_len*q_seq_len,
        attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_t,
// b_layout_n,
// k_seq_len,
// q_seq_len,
// head_dim,
// alpha,
// static_cast<const half*>(v_lin_results_ptr),
// lead_dim,
// batch_stride,
// static_cast<const half*>(output_lin_grads.data_ptr()),
// head_dim*attn_batches,
// head_dim,
// beta,
// static_cast<half*>(matmul2_grads.data_ptr()),
// k_seq_len,
// k_seq_len*q_seq_len,
// attn_batches);
// Matmul2 Dgrad2
// TODO (OK)
  gemm_switch_fp32accum(state,
        a_layout_n,
        b_layout_t,
@@ -357,7 +526,28 @@ std::vector<torch::Tensor> bwd_cuda(
        v_lin_grads_ptr,
        lead_dim,
        batch_stride,
        v_lin_grads_ptr,
        lead_dim,
        batch_stride,
        attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_n,
// b_layout_t,
// head_dim,
// k_seq_len,
// q_seq_len,
// alpha,
// static_cast<const half*>(output_lin_grads.data_ptr()),
// head_dim*attn_batches,
// head_dim,
// static_cast<const half*>(dropout_results.data_ptr()),
// k_seq_len,
// k_seq_len*q_seq_len,
// beta,
// v_lin_grads_ptr,
// lead_dim,
// batch_stride,
// attn_batches);
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
@@ -372,6 +562,7 @@ std::vector<torch::Tensor> bwd_cuda(
        attn_batches*q_seq_len,
        stream);
// Matmul1 Dgrad1
// TODO (OK)
  gemm_switch_fp32accum(state,
        a_layout_n,
        b_layout_n,
@@ -385,13 +576,35 @@ std::vector<torch::Tensor> bwd_cuda(
        static_cast<half*>(matmul2_grads.data_ptr()),
        k_seq_len,
        k_seq_len*q_seq_len,
        beta,
        q_lin_grads_ptr,
        lead_dim,
        batch_stride,
        q_lin_grads_ptr,
        lead_dim,
        batch_stride,
        attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_n,
// b_layout_n,
// head_dim,
// q_seq_len,
// k_seq_len,
// scale,
// k_lin_results_ptr,
// lead_dim,
// batch_stride,
// static_cast<half*>(matmul2_grads.data_ptr()),
// k_seq_len,
// k_seq_len*q_seq_len,
// beta,
// q_lin_grads_ptr,
// lead_dim,
// batch_stride,
// attn_batches);
// Matmul1 Dgrad2
// TODO (OK)
  gemm_switch_fp32accum(state,
        a_layout_n,
        b_layout_t,
@@ -408,10 +621,32 @@ std::vector<torch::Tensor> bwd_cuda(
        beta,
        k_lin_grads_ptr,
        lead_dim,
        batch_stride,
        k_lin_grads_ptr,
        lead_dim,
        batch_stride,
        attn_batches);
// gemm_switch_fp32accum( state,
// a_layout_n,
// b_layout_t,
// head_dim,
// k_seq_len,
// q_seq_len,
// scale,
// q_lin_results_ptr,
// lead_dim,
// batch_stride,
// static_cast<half*>(matmul2_grads.data_ptr()),
// k_seq_len,
// k_seq_len*q_seq_len,
// beta,
// k_lin_grads_ptr,
// lead_dim,
// batch_stride,
// attn_batches);
// Input Linear Dgrad
  // TODO (OK)
  THCublasCheck(rocblas_gemm_ex(handle,
        CUBLAS_OP_N,
        CUBLAS_OP_N,
        embed_dim,
@@ -419,22 +654,47 @@ std::vector<torch::Tensor> bwd_cuda(
        output_lin_dim,
        static_cast<const void*>(&alpha),
        static_cast<const void*>(input_weights.data_ptr()),
        rocblas_datatype_f16_r,
        embed_dim,
        static_cast<const void*>(input_lin_output_grads.data_ptr()),
        rocblas_datatype_f16_r,
        output_lin_dim,
        static_cast<const void*>(&beta),
        static_cast<void*>(input_grads.data_ptr()),
        rocblas_datatype_f16_r,
        embed_dim,
        static_cast<void*>(input_grads.data_ptr()),
        rocblas_datatype_f16_r,
        embed_dim,
        rocblas_datatype_f32_r,
        algo,
        solution_index,
        flags));
// THCublasCheck(cublasGemmEx(handle,
// CUBLAS_OP_N,
// CUBLAS_OP_N,
// embed_dim,
// batches,
// output_lin_dim,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(input_weights.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(input_lin_output_grads.data_ptr()),
// //static_cast<const void*>(q_lin_grads_ptr),
// CUDA_R_16F,
// output_lin_dim,
// static_cast<const void*>(&beta),
// static_cast<void*>(input_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// CUDA_R_32F,
// //CUBLAS_GEMM_ALGO10_TENSOR_OP));
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear Wgrad
  // TODO (OK)
  THCublasCheck(rocblas_gemm_ex(handle,
        CUBLAS_OP_N,
        CUBLAS_OP_T,
        embed_dim,
@@ -442,20 +702,45 @@ std::vector<torch::Tensor> bwd_cuda(
        batches,
        static_cast<const void*>(&alpha),
        static_cast<const void*>(inputs.data_ptr()),
        rocblas_datatype_f16_r,
        embed_dim,
        static_cast<const void*>(q_lin_grads_ptr),
        rocblas_datatype_f16_r,
        output_lin_dim,
        static_cast<const void*>(&beta),
        static_cast<void*>(input_weight_grads.data_ptr()),
        rocblas_datatype_f16_r,
        embed_dim,
        static_cast<void*>(input_weight_grads.data_ptr()),
        rocblas_datatype_f16_r,
        embed_dim,
        rocblas_datatype_f32_r,
        algo,
        solution_index,
        flags));
// THCublasCheck(cublasGemmEx(handle,
// CUBLAS_OP_N,
// CUBLAS_OP_T,
// embed_dim,
// output_lin_dim,
// batches,
// static_cast<const void*>(&alpha),
// static_cast<const void*>(inputs.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// static_cast<const void*>(q_lin_grads_ptr),
// CUDA_R_16F,
// output_lin_dim,
// static_cast<const void*>(&beta),
// static_cast<void*>(input_weight_grads.data_ptr()),
// CUDA_R_16F,
// embed_dim,
// CUDA_R_32F,
// CUBLAS_GEMM_DEFAULT_TENSOR_OP));
  auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
  // TODO (OK)
  // THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
  return {input_grads,
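One detail worth calling out from the backward pass above: the bias gradients are not computed with a GEMM at all, but by flattening every token position into the leading dimension and summing the gradient rows. A minimal libtorch sketch of that reduction (the shapes are illustrative, not taken from the kernel):

#include <torch/torch.h>

int main() {
  // Illustrative shapes: [seq_len, batch, embed_dim] output gradients.
  const int64_t seq_len = 8, batch = 4, embed_dim = 16;
  torch::Tensor output_grads = torch::randn({seq_len, batch, embed_dim});

  // Same reduction as in bwd_cuda: collapse all leading dims, sum over rows.
  torch::Tensor bias_grads = output_grads.view({-1, embed_dim}).sum(0, false);

  // Equivalent formulation: sum over every dimension except the feature dim.
  torch::Tensor check = output_grads.sum({0, 1});
  TORCH_CHECK(torch::allclose(bias_grads, check));
  return 0;
}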
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu
@@ -124,13 +124,13 @@ std::vector<torch::Tensor> fwd_cuda(
        q_lin_results_ptr,
        c_type,
        output_lin_dim,
        q_lin_results_ptr,
        d_type,
        output_lin_dim,
        compute_type,
        algo,
        solution_index,
        flags));
  // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
  gemm_switch_fp32accum(state,
@@ -150,9 +150,9 @@ std::vector<torch::Tensor> fwd_cuda(
        static_cast<half*>(softmax_results_ptr),
        k_seq_len,
        k_seq_len*q_seq_len,
        static_cast<half*>(softmax_results_ptr),
        k_seq_len,
        k_seq_len*q_seq_len,
        attn_batches);
  // Padded Softmax
@@ -215,9 +215,9 @@ std::vector<torch::Tensor> fwd_cuda(
        static_cast<half*>(matmul2_results.data_ptr()),
        head_dim*attn_batches,
        head_dim,
        static_cast<half*>(matmul2_results.data_ptr()),
        head_dim*attn_batches,
        head_dim,
        attn_batches);
  // Output Linear
@@ -238,13 +238,13 @@ std::vector<torch::Tensor> fwd_cuda(
        static_cast<void*>(output_lin_results.data_ptr()),
        c_type,
        embed_dim,
        static_cast<void*>(output_lin_results.data_ptr()),
        d_type,
        embed_dim,
        compute_type,
        algo,
        solution_index,
        flags));
  // End-of-block Dropout-Add
@@ -372,13 +372,13 @@ std::vector<torch::Tensor> bwd_cuda(
        static_cast<void*>(output_lin_grads.data_ptr()),
        c_type,
        embed_dim,
        static_cast<void*>(output_lin_grads.data_ptr()),
        d_type,
        embed_dim,
        compute_type,
        algo,
        solution_index,
        flags));
  // Output Linear Wgrad
  THCublasCheck(rocblas_gemm_ex(handle,
@@ -398,13 +398,13 @@ std::vector<torch::Tensor> bwd_cuda(
        static_cast<void*>(output_weight_grads.data_ptr()),
        c_type,
        embed_dim,
        static_cast<void*>(output_weight_grads.data_ptr()),
        d_type,
        embed_dim,
        compute_type,
        algo,
        solution_index,
        flags));
  // MatMul2 Dgrad1
  gemm_switch_fp32accum(state,
@@ -424,9 +424,9 @@ std::vector<torch::Tensor> bwd_cuda(
        static_cast<half*>(matmul2_grads.data_ptr()),
        k_seq_len,
        k_seq_len*q_seq_len,
        static_cast<half*>(matmul2_grads.data_ptr()),
        k_seq_len,
        k_seq_len*q_seq_len,
        attn_batches);
  // Matmul2 Dgrad2
@@ -447,9 +447,9 @@ std::vector<torch::Tensor> bwd_cuda(
        v_lin_grads_ptr,
        lead_dim,
        batch_stride,
        v_lin_grads_ptr,
        lead_dim,
        batch_stride,
        attn_batches);
  // Apply Dropout Mask and Scale by Dropout Probability
@@ -489,9 +489,9 @@ std::vector<torch::Tensor> bwd_cuda(
        q_lin_grads_ptr,
        lead_dim,
        batch_stride,
        q_lin_grads_ptr,
        lead_dim,
        batch_stride,
        attn_batches);
  // Matmul1 Dgrad2
@@ -512,7 +512,7 @@ std::vector<torch::Tensor> bwd_cuda(
        k_lin_grads_ptr,
        lead_dim,
        batch_stride,
        k_lin_grads_ptr,
        lead_dim,
        batch_stride,
        attn_batches);
@@ -536,13 +536,13 @@ std::vector<torch::Tensor> bwd_cuda(
        static_cast<void*>(input_lin_grads.data_ptr()),
        c_type,
        embed_dim,
        static_cast<void*>(input_lin_grads.data_ptr()),
        d_type,
        embed_dim,
        compute_type,
        algo,
        solution_index,
        flags));
  // Input Linear Wgrad
  THCublasCheck(rocblas_gemm_ex(handle,
@@ -563,13 +563,13 @@ std::vector<torch::Tensor> bwd_cuda(
        static_cast<void*>(input_weight_grads.data_ptr()),
        c_type,
        embed_dim,
        static_cast<void*>(input_weight_grads.data_ptr()),
        d_type,
        embed_dim,
        compute_type,
        algo,
        solution_index,
        flags));
  // Fused Layer Norm Bwd with Residual Add
  HostLayerNormGradient<half, float>(
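Both CUDA files lean on the contrib helper gemm_switch_fp32accum for the attention-score and context matmuls; judging from the commented-out calls it performs a strided-batched half-precision GEMM with FP32 accumulation, one batch entry per attention head. A hedged sketch of what such a call could look like on ROCm, using the public rocblas_gemm_strided_batched_ex API (the wrapper name, the shapes, and the assumption that the helper bottoms out in this API are mine, not the commit's):

// Hypothetical wrapper: batched Q*K^T with FP16 storage and FP32 accumulation,
// roughly the shape of the "MatMul1" call in fwd_cuda. Not code from this commit.
#include <rocblas.h>

void attn_scores_batched(rocblas_handle handle,
                         int k_seq_len, int q_seq_len, int head_dim,
                         float scale,
                         const void* k_ptr, int lead_dim, rocblas_stride batch_stride,
                         const void* q_ptr,
                         void* scores_ptr,      // [k_seq_len x q_seq_len] per batch
                         int attn_batches) {
  const float beta = 0.f;
  rocblas_gemm_strided_batched_ex(
      handle,
      rocblas_operation_transpose,      // 't' layout for K
      rocblas_operation_none,           // 'n' layout for Q
      k_seq_len, q_seq_len, head_dim,
      &scale,
      k_ptr, rocblas_datatype_f16_r, lead_dim, batch_stride,
      q_ptr, rocblas_datatype_f16_r, lead_dim, batch_stride,
      &beta,
      scores_ptr, rocblas_datatype_f16_r, k_seq_len,
      static_cast<rocblas_stride>(k_seq_len) * q_seq_len,
      scores_ptr, rocblas_datatype_f16_r, k_seq_len,   // D aliases C, as in the diff
      static_cast<rocblas_stride>(k_seq_len) * q_seq_len,
      attn_batches,
      rocblas_datatype_f32_r,           // accumulate in FP32
      rocblas_gemm_algo_standard,
      0 /* solution_index */, 0 /* flags */);
}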