OpenDAS / apex · Commits · 61416180

Commit 61416180 authored Oct 28, 2021 by hubertlu-tw

    Hipify self_multihead_attn_bias

    Fix some spacing

parent 8bdbb502

Showing 3 changed files with 384 additions and 99 deletions:

  apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cpp.cpp      +4    -4
  apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu      +334  -49
  apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu  +46   -46
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cpp.cpp

@@ -3,7 +3,7 @@
 namespace multihead_attn {
 namespace self_bias {
-namespace cublas_gemmex {
+namespace rocblas_gemm_ex {

 std::vector<torch::Tensor> fwd_cuda(
                                bool use_time_mask,

@@ -128,12 +128,12 @@ std::vector<torch::Tensor> bwd(
   );
 }

-} // end namespace cublas_gemmex
+} // end namespace rocblas_gemm_ex
 } // end namespace self
 } // end namespace multihead_attn

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward",  &multihead_attn::self_bias::cublas_gemmex::fwd,  "Self Multihead Attention with Bias -- Forward.");
-  m.def("backward", &multihead_attn::self_bias::cublas_gemmex::bwd,  "Self Multihead Attention with Bias -- Backward.");
+  m.def("forward",  &multihead_attn::self_bias::rocblas_gemm_ex::fwd,  "Self Multihead Attention with Bias -- Forward.");
+  m.def("backward", &multihead_attn::self_bias::rocblas_gemm_ex::bwd,  "Self Multihead Attention with Bias -- Backward.");
 }
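For context on what this file's +4/-4 amounts to: only the C++ namespace qualifier changes, while the names registered through PYBIND11_MODULE ("forward" and "backward") stay the same, so callers of the built extension are unaffected. The stand-alone sketch below reproduces the renamed namespace layout with stub bodies; the stub bodies, the <cstdio> include, and the main() driver are illustrative placeholders, not code from apex.

// Sketch only: namespace layout after the rename in this commit.
#include <cstdio>

namespace multihead_attn {
namespace self_bias {
namespace rocblas_gemm_ex {   // formerly cublas_gemmex

// Stub stand-ins for the real fwd/bwd wrappers defined in the apex sources.
void fwd() { std::puts("Self Multihead Attention with Bias -- Forward."); }
void bwd() { std::puts("Self Multihead Attention with Bias -- Backward."); }

} // end namespace rocblas_gemm_ex
} // end namespace self_bias
} // end namespace multihead_attn

int main() {
  // The fully qualified names below are exactly what the m.def(...) bindings
  // in self_multihead_attn_bias_cpp.cpp take the address of.
  multihead_attn::self_bias::rocblas_gemm_ex::fwd();
  multihead_attn::self_bias::rocblas_gemm_ex::bwd();
  return 0;
}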
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu

This diff is collapsed.
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu

Consistent with the commit note "Fix some spacing", the removed and added lines in every hunk of this file carry identical content and differ only in indentation, so each hunk's lines are shown once below.

@@ -124,13 +124,13 @@ std::vector<torch::Tensor> fwd_cuda(
                     q_lin_results_ptr,
                     c_type,
                     output_lin_dim,
                     q_lin_results_ptr,
                     d_type,
                     output_lin_dim,
                     compute_type,
                     algo,
                     solution_index,
                     flags));
   // MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
   gemm_switch_fp32accum(state,

@@ -150,9 +150,9 @@ std::vector<torch::Tensor> fwd_cuda(
                         static_cast<half*>(softmax_results_ptr),
                         k_seq_len,
                         k_seq_len*q_seq_len,
                         static_cast<half*>(softmax_results_ptr),
                         k_seq_len,
                         k_seq_len*q_seq_len,
                         attn_batches);
   // Padded Softmax

@@ -215,9 +215,9 @@ std::vector<torch::Tensor> fwd_cuda(
                         static_cast<half*>(matmul2_results.data_ptr()),
                         head_dim*attn_batches,
                         head_dim,
                         static_cast<half*>(matmul2_results.data_ptr()),
                         head_dim*attn_batches,
                         head_dim,
                         attn_batches);
   // Output Linear

@@ -238,13 +238,13 @@ std::vector<torch::Tensor> fwd_cuda(
                     static_cast<void*>(output_lin_results.data_ptr()),
                     c_type,
                     embed_dim,
                     static_cast<void*>(output_lin_results.data_ptr()),
                     d_type,
                     embed_dim,
                     compute_type,
                     algo,
                     solution_index,
                     flags));
   // End-of-block Dropout-Add

@@ -372,13 +372,13 @@ std::vector<torch::Tensor> bwd_cuda(
                     static_cast<void*>(output_lin_grads.data_ptr()),
                     c_type,
                     embed_dim,
                     static_cast<void*>(output_lin_grads.data_ptr()),
                     d_type,
                     embed_dim,
                     compute_type,
                     algo,
                     solution_index,
                     flags));
   // Output Linear Wgrad
   THCublasCheck(rocblas_gemm_ex(handle,

@@ -398,13 +398,13 @@ std::vector<torch::Tensor> bwd_cuda(
                     static_cast<void*>(output_weight_grads.data_ptr()),
                     c_type,
                     embed_dim,
                     static_cast<void*>(output_weight_grads.data_ptr()),
                     d_type,
                     embed_dim,
                     compute_type,
                     algo,
                     solution_index,
                     flags));
   // MatMul2 Dgrad1
   gemm_switch_fp32accum(state,

@@ -424,9 +424,9 @@ std::vector<torch::Tensor> bwd_cuda(
                         static_cast<half*>(matmul2_grads.data_ptr()),
                         k_seq_len,
                         k_seq_len*q_seq_len,
                         static_cast<half*>(matmul2_grads.data_ptr()),
                         k_seq_len,
                         k_seq_len*q_seq_len,
                         attn_batches);
   // Matmul2 Dgrad2

@@ -447,9 +447,9 @@ std::vector<torch::Tensor> bwd_cuda(
                         v_lin_grads_ptr,
                         lead_dim,
                         batch_stride,
                         v_lin_grads_ptr,
                         lead_dim,
                         batch_stride,
                         attn_batches);
   // Apply Dropout Mask and Scale by Dropout Probability

@@ -489,9 +489,9 @@ std::vector<torch::Tensor> bwd_cuda(
                         q_lin_grads_ptr,
                         lead_dim,
                         batch_stride,
                         q_lin_grads_ptr,
                         lead_dim,
                         batch_stride,
                         attn_batches);
   // Matmul1 Dgrad2

@@ -512,7 +512,7 @@ std::vector<torch::Tensor> bwd_cuda(
                         k_lin_grads_ptr,
                         lead_dim,
                         batch_stride,
                         k_lin_grads_ptr,
                         lead_dim,
                         batch_stride,
                         attn_batches);

@@ -536,13 +536,13 @@ std::vector<torch::Tensor> bwd_cuda(
                     static_cast<void*>(input_lin_grads.data_ptr()),
                     c_type,
                     embed_dim,
                     static_cast<void*>(input_lin_grads.data_ptr()),
                     d_type,
                     embed_dim,
                     compute_type,
                     algo,
                     solution_index,
                     flags));
   // Input Linear Wgrad
   THCublasCheck(rocblas_gemm_ex(handle,

@@ -563,13 +563,13 @@ std::vector<torch::Tensor> bwd_cuda(
                     static_cast<void*>(input_weight_grads.data_ptr()),
                     c_type,
                     embed_dim,
                     static_cast<void*>(input_weight_grads.data_ptr()),
                     d_type,
                     embed_dim,
                     compute_type,
                     algo,
                     solution_index,
                     flags));
   // Fused Layer Norm Bwd with Residual Add
   HostLayerNormGradient<half, float>(
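The re-indented lines in these hunks are the tails of rocblas_gemm_ex calls wrapped in THCublasCheck (the C/c_type/ldc, D/d_type/ldd, compute_type, algo, solution_index, flags arguments) and of gemm_switch_fp32accum calls (pointer / leading-dimension / batch-stride triples plus the batch count). As a reading aid, here is a hedged sketch of how such a tail maps onto the rocblas_gemm_ex parameters; the wrapper name, matrix shapes, and data types below are placeholders for illustration and are not taken from the apex code.

// Sketch: mapping the argument tails above onto rocblas_gemm_ex parameters.
// Placeholder shapes/types; check the installed rocBLAS headers for the
// authoritative prototype.
#include <rocblas/rocblas.h>   // exposed as <rocblas.h> on older ROCm installs

rocblas_status example_gemm(rocblas_handle handle,
                            const float* alpha, const float* beta,
                            const void* a,   // e.g. a weight matrix (device pointer)
                            const void* b,   // e.g. an activation matrix (device pointer)
                            void*       c,   // output buffer, reused as both C and D
                            int m, int n, int k,
                            int lda, int ldb, int ldc) {
  return rocblas_gemm_ex(handle,
                         rocblas_operation_none,           // trans_a
                         rocblas_operation_none,           // trans_b
                         m, n, k,
                         alpha,
                         a, rocblas_datatype_f16_r, lda,   // A, a_type, lda
                         b, rocblas_datatype_f16_r, ldb,   // B, b_type, ldb
                         beta,
                         c, rocblas_datatype_f16_r, ldc,   // C, c_type, ldc  ("c_type, embed_dim")
                         c, rocblas_datatype_f16_r, ldc,   // D, d_type, ldd  ("d_type, embed_dim")
                         rocblas_datatype_f32_r,           // compute_type
                         rocblas_gemm_algo_standard,       // algo
                         0,                                // solution_index
                         0);                               // flags
}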