OpenDAS / ColossalAI / Commits

Commit d8d07b0e
Authored May 13, 2022 by Sze-qq; committed by binmakeswell on May 17, 2022
Parent: fa43bb21

[NFC] polish colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp code style (#952)

Showing 1 changed file with 132 additions and 96 deletions:
colossalai/kernel/cuda_native/csrc/multihead_attention_1d.cpp (+132, -96)

As an NFC commit, the change reformats code without altering behavior.
...
@@ -10,8 +10,9 @@
#include "kernels.h"

template <typename T>
MultiHeadAttention<T>::MultiHeadAttention(int layer_id, int max_batch_tokens,
                                          int max_seq_len, int hidden_size,
                                          int num_heads,
                                          float attn_prob_dropout_ratio,
                                          float hidden_output_dropout_ratio,
                                          bool pre_or_postLayerNorm)
@@ -22,18 +23,22 @@ MultiHeadAttention<T>::MultiHeadAttention(int layer_id, int max_batch_tokens, in
      _heads(num_heads),
      _training(true),
      _pre_or_postLayerNorm(pre_or_postLayerNorm),
      _qkv_linear(
          typename FeedForward<T>::Config(3 * hidden_size, hidden_size)),
      _attn_out_linear(
          typename FeedForward<T>::Config(hidden_size, hidden_size)),
      _attn_ln(typename Normalize_Layer<T>::Config(hidden_size, false),
               _max_batch_tokens),
      _softmax(typename Softmax<T>::Config(num_heads)),
      _attn_prob_dropout(typename Dropout<T>::Config(attn_prob_dropout_ratio),
                         _max_batch_tokens * _heads * _max_seq_len),
      _attn_dropout(typename Dropout<T>::Config(hidden_output_dropout_ratio),
                    _max_batch_tokens * _hidden_size),
      _attn_scores(typename StridedBatchGemm<T>::Config(
          (T(1.0) / T(sqrt(_hidden_size / _heads))), T(0.0), CUBLAS_OP_T,
          CUBLAS_OP_N)),
      _attn_context(typename StridedBatchGemm<T>::Config(
          T(1.0), T(0.0), CUBLAS_OP_N, CUBLAS_OP_N)) {
  assert(_hidden_size % _heads == 0);
}
...
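Note on the constructor hunk above: the _attn_scores GEMM config bakes the scaled-dot-product factor alpha = 1 / sqrt(_hidden_size / _heads) directly into the batched matmul, so no separate scaling kernel runs. A minimal standalone sketch of that arithmetic (the sizes are illustrative, not taken from the repo):

#include <cassert>
#include <cmath>
#include <cstdio>

// Illustration of the alpha used by the _attn_scores GEMM:
// scores = alpha * Q * K^T with alpha = 1 / sqrt(head_dim).
int main() {
  int hidden_size = 1024, heads = 16;
  assert(hidden_size % heads == 0);  // same invariant the constructor asserts
  int head_dim = hidden_size / heads;
  float alpha = 1.0f / std::sqrt(static_cast<float>(head_dim));
  std::printf("head_dim=%d alpha=%f\n", head_dim, alpha);
  return 0;
}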
@@ -43,43 +48,52 @@ MultiHeadAttention<T>::~MultiHeadAttention() {
}

template <typename T>
void MultiHeadAttention<T>::attn_layer_fw(const T *input_ptr,
                                          const T *input_mask_ptr,
                                          T *output_ptr, T *buffer) {
  T *q_tf_ptr = _qkv_ptr;
  T *k_tf_ptr = q_tf_ptr + _batch_dim / pg_size;
  T *v_tf_ptr = k_tf_ptr + _batch_dim / pg_size;

  if (_pre_or_postLayerNorm) {
    _attn_ln.Forward(_gemmQKV_inp_ptr, input_ptr, _attn_nw_ptr, _attn_nb_ptr,
                     _batch_tokens, _stream);
  }
  const T *gemmQKV_inp_ptr =
      _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
  _qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size);
  _qkv_linear.Forward(_batch_tokens, gemmQKV_inp_ptr, _attn_qkvw_ptr, buffer,
                      _cublasHandle);

  launch_bias_add_transform_20314<T>(q_tf_ptr, buffer, _attn_qkvb_ptr,
                                     _batch_size, _seq_len, 3,
                                     _heads / pg_size, _hidden_size / _heads,
                                     _stream);

  // attention scores, q*k
  _attn_scores.Forward(_batch_heads, _soft_out_ptr, k_tf_ptr, q_tf_ptr,
                       _cublasHandle);

  // Softmax + Mask
  _softmax.reset_size(_heads / pg_size);
  _softmax.Forward(_soft_out_ptr, input_mask_ptr, _batch_size, _seq_len,
                   _seq_len, _stream, true);

  // attn prob dropout.
  _attn_prob_dropout.dropout(_ctx_bufB_ptr, _soft_out_ptr,
                             _batch_heads * _seq_len * _seq_len, _stream);

  // attention context, score * v
  _attn_context.Forward(_batch_heads, buffer, v_tf_ptr, _ctx_bufB_ptr,
                        _cublasHandle);

  // [b, nh, s, ad] -> [b, s, nh, ad]
  launch_transform4d_0213<T>(_attn_o_inp_ptr, buffer, _batch_size, _seq_len,
                             _hidden_size / pg_size, _heads / pg_size, 1,
                             _stream);

  _attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size);
  _attn_out_linear.Forward(_batch_tokens, _attn_o_inp_ptr, _attn_ow_ptr,
                           output_ptr, _cublasHandle);

  // allreduce
  if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) {
...
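Note on attn_layer_fw: every "/ pg_size" above reflects the 1D tensor parallelism this file implements. Each rank owns _heads / pg_size attention heads and the matching _hidden_size / pg_size slice of the QKV and output projections, so the output projection produces a partial sum that the allreduce below completes. A small sketch of the per-rank shard sizes, assuming this Megatron-style split:

#include <cstdio>

// Per-rank shard sizes under 1D tensor parallelism, mirroring the
// _heads / pg_size and _hidden_size / pg_size expressions in attn_layer_fw.
struct ShardSizes {
  int local_heads;    // attention heads owned by this rank
  int local_qkv_out;  // output rows of the fused QKV projection
  int local_out_in;   // input columns of the attention output projection
};

ShardSizes shard(int hidden_size, int heads, int pg_size) {
  return {heads / pg_size, 3 * hidden_size / pg_size, hidden_size / pg_size};
}

int main() {
  ShardSizes s = shard(/*hidden_size=*/1024, /*heads=*/16, /*pg_size=*/4);
  std::printf("local_heads=%d qkv_rows=%d out_in_cols=%d\n", s.local_heads,
              s.local_qkv_out, s.local_out_in);
  return 0;  // each rank holds a partial output; the allreduce sums them
}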
@@ -88,24 +102,27 @@ void MultiHeadAttention<T>::attn_layer_fw(const T *input_ptr, const T *input_mas
    if (typeid(T) != typeid(float)) {
      data_type = torch::kHalf;
    }
    auto output_tensor = torch::from_blob(
        output_ptr, {int(_batch_size), int(_seq_len), int(_hidden_size)},
        torch::TensorOptions(torch::kCUDA).dtype(data_type));
    std::vector<torch::Tensor> allreduce_tensors = {output_tensor};
    auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions());
    work->wait();
  }

  _attn_dropout.bias_dropout_residual(output_ptr, output_ptr, input_ptr,
                                      _attn_ob_ptr, _batch_tokens,
                                      _hidden_size, _stream);
  if (!_pre_or_postLayerNorm) {
    // in-place ln since ln-input will not be used in post-ln mode
    _attn_ln.Forward(output_ptr, output_ptr, _attn_nw_ptr, _attn_nb_ptr,
                     _batch_tokens, _stream);
  }
}

template <typename T>
void MultiHeadAttention<T>::Forward(const T *input_ptr,
                                    const T *input_mask_ptr, T *out_ptr) {
  _stream = Context::Instance().get_stream();
  _cublasHandle = Context::Instance().get_cublashandle();
  T *attn_buffer = _shared_mem_ptr;  // 3 * _batch_dim
...
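Note on the hunk above: _pre_or_postLayerNorm selects where layer norm runs. Pre-LN normalizes the input before the QKV projection (first hunk of attn_layer_fw), while post-LN applies an in-place norm after the bias-dropout-residual add, since the un-normalized ln-input is no longer needed. A schematic sketch of the two orderings, with plain callables standing in for the fused kernels:

#include <cstdio>
#include <functional>

// Control flow only: ln/attn/residual stand in for _attn_ln.Forward, the
// attention core, and _attn_dropout.bias_dropout_residual respectively.
using Fn = std::function<void()>;

void attention_block(bool pre_ln, const Fn &ln, const Fn &attn,
                     const Fn &residual) {
  if (pre_ln) ln();   // pre-LN: normalize before the QKV projection
  attn();             // QKV -> scores -> softmax -> context -> out-proj
  residual();         // bias + dropout + residual add
  if (!pre_ln) ln();  // post-LN: in-place norm after the residual
}

int main() {
  attention_block(/*pre_ln=*/true, [] { std::puts("ln"); },
                  [] { std::puts("attn"); }, [] { std::puts("residual"); });
  return 0;
}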
@@ -114,8 +131,11 @@ void MultiHeadAttention<T>::Forward(const T *input_ptr, const T *input_mask_ptr,
}

template <typename T>
void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr,
                                          const T *input_mask_ptr,
                                          const T *output_ptr,
                                          const T *grad_output_ptr,
                                          T *grad_input_ptr, T *buffer) {
  cudaStream_t streams[2] = {_stream, _stream};

  const T *q_tf_ptr = _qkv_ptr;
...
@@ -137,45 +157,57 @@ void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr, const T *input_mas
  // batch_size * head_num * seq_len * seq_len);

  if (_pre_or_postLayerNorm) {
    _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr,
                                          grad_output_ptr, _batch_tokens,
                                          _hidden_size, _stream);
  } else {
    _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_residual_ptr,
                      grad_output_ptr, nullptr, output_ptr, _attn_nw_ptr,
                      _attn_nb_ptr, _batch_tokens, streams);
    _attn_dropout.d_bias_dropout_residual(grad_input_ptr, _grad_attn_ob_ptr,
                                          grad_residual_ptr, _batch_tokens,
                                          _hidden_size, _stream);
  }

  // bw of output project
  _attn_out_linear.reset_size(_hidden_size, _hidden_size / pg_size);
  _attn_out_linear.Backward(_batch_tokens, grad_input_ptr, _attn_o_inp_ptr,
                            _attn_ow_ptr, _grad_attn_ow_ptr, _grad_attn_ob_ptr,
                            _cublasHandle, _stream, grad_input_buf_ptr,
                            nullptr, false);
  launch_transform_0213<T>(grad_input_ptr, grad_input_buf_ptr, _batch_size,
                           _seq_len, _hidden_size / pg_size, _heads / pg_size,
                           _stream);

  // bw of score * v
  _attn_context.Backward(_batch_heads, grad_input_ptr, v_tf_ptr, _ctx_bufB_ptr,
                         _cublasHandle,
                         grad_qkv_5d_ptr + 2 * _batch_dim / pg_size,
                         grad_softmax_ptr);

  _attn_prob_dropout.d_dropout(grad_softmax_ptr,
                               _batch_heads * _seq_len * _seq_len, _stream);

  _softmax.reset_size(_heads / pg_size);
  _softmax.Backward(grad_softmax_ptr, _soft_out_ptr, _batch_size, _seq_len,
                    _seq_len, _stream);

  // bw of q * k
  _attn_scores.Backward(_batch_heads, grad_softmax_ptr, k_tf_ptr, q_tf_ptr,
                        _cublasHandle, grad_qkv_5d_ptr + _batch_dim / pg_size,
                        grad_qkv_5d_ptr);

  // [3, b, nh, s, ad] -> [b, s, 3, h]
  launch_transform4d_0213<T>(grad_qkv_4d_ptr, grad_qkv_5d_ptr, _batch_size,
                             _seq_len, _hidden_size / pg_size,
                             _heads / pg_size, 3, _stream);

  const T *gemmQKV_inp_ptr =
      _pre_or_postLayerNorm ? _gemmQKV_inp_ptr : input_ptr;
  _qkv_linear.reset_size(3 * _hidden_size / pg_size, _hidden_size);
  _qkv_linear.Backward(_batch_tokens, grad_qkv_4d_ptr, gemmQKV_inp_ptr,
                       _attn_qkvw_ptr, _grad_attn_qkvw_ptr,
                       _grad_attn_qkvb_ptr, _cublasHandle, _stream,
                       grad_input_buf_ptr, nullptr, true);

  // allreduce
  if (pg == c10::detail::UniqueVoidPtr() || pg->getSize() == 1) {
...
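Note on attn_layer_bw: the q/k/v gradients land in one fused 5-D buffer with the same stride the forward pointers use, dQ at grad_qkv_5d_ptr, dK at + _batch_dim / pg_size, and dV at + 2 * _batch_dim / pg_size. A minimal sketch of that offset arithmetic:

#include <cstdio>

// Offsets into the fused grad-QKV buffer, matching attn_layer_bw's pointer
// arithmetic: dQ at 0, dK at batch_dim/pg_size, dV at 2*batch_dim/pg_size.
struct QkvGradOffsets {
  long dq, dk, dv;
};

QkvGradOffsets qkv_grad_offsets(long batch_dim, int pg_size) {
  long shard = batch_dim / pg_size;
  return {0, shard, 2 * shard};
}

int main() {
  QkvGradOffsets o = qkv_grad_offsets(/*batch_dim=*/8L * 512 * 1024,
                                      /*pg_size=*/4);
  std::printf("dq=%ld dk=%ld dv=%ld\n", o.dq, o.dk, o.dv);
  return 0;
}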
@@ -185,7 +217,8 @@ void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr, const T *input_mas
      data_type = torch::kHalf;
    }
    auto grad_input_tensor = torch::from_blob(
        grad_input_buf_ptr,
        {int(_batch_size), int(_seq_len), int(_hidden_size)},
        torch::TensorOptions(torch::kCUDA).dtype(data_type));
    std::vector<torch::Tensor> allreduce_tensors = {grad_input_tensor};
    auto work = pg->allreduce(allreduce_tensors, c10d::AllreduceOptions());
...
@@ -193,19 +226,21 @@ void MultiHeadAttention<T>::attn_layer_bw(const T *input_ptr, const T *input_mas
  }

  if (_pre_or_postLayerNorm) {
    _attn_ln.Backward(_grad_attn_nw_ptr, _grad_attn_nb_ptr, grad_input_ptr,
                      grad_input_buf_ptr, grad_output_ptr, gemmQKV_inp_ptr,
                      _attn_nw_ptr, _attn_nb_ptr, _batch_tokens, streams);
  } else {
    // FIXME later
    launch_fused_add2<T>(grad_input_ptr, grad_input_buf_ptr, grad_residual_ptr,
                         _batch_size, _seq_len, _hidden_size, _stream);
  }
}

template <typename T>
void MultiHeadAttention<T>::Backward(const T *grad_output_ptr,
                                     const T *input_ptr, const T *output_ptr,
                                     const T *input_mask_ptr,
                                     T *grad_input_ptr) {
  _stream = Context::Instance().get_stream();
  _cublasHandle = Context::Instance().get_cublashandle();
  T *buffer = _shared_mem_ptr;
...
@@ -215,7 +250,8 @@ void MultiHeadAttention<T>::Backward(const T *grad_output_ptr, const T *input_pt
     4 * _batch_dim + max(3 * _batch_dim,
     _batch_size * _head_num * _seq_len * _seq_len);
  */
  attn_layer_bw(input_ptr, input_mask_ptr, output_ptr, grad_output_ptr,
                grad_input_ptr, buffer);
}

template <typename T>
...
@@ -233,7 +269,8 @@ template class MultiHeadAttention<__half>;
// x is torch::Tensor
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)
...
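Note on the macros above: CHECK_INPUT chains both asserts so every tensor argument is verified CUDA-resident and contiguous before its data_ptr() is reinterpreted as a raw pointer. A minimal usage sketch with a hypothetical my_op (not part of this file):

#include <torch/extension.h>

#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

// Hypothetical op: validate, then hand the raw pointer to a kernel launch.
void my_op(const torch::Tensor &input) {
  CHECK_INPUT(input);
  const float *ptr = input.data_ptr<float>();
  (void)ptr;  // ... launch a kernel with ptr ...
}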
@@ -241,15 +278,17 @@ template class MultiHeadAttention<__half>;
static std::unordered_map<int, std::shared_ptr<void>> s_multihead_attention;

template <typename T>
int create_multihead_attention(int layer_id, int max_batch_tokens,
                               int max_seq_len, int hidden_dim, int num_heads,
                               float attn_prob_dropout_ratio,
                               float hidden_dropout_ratio,
                               bool pre_or_postLayerNorm,
                               c10::intrusive_ptr<c10d::ProcessGroup> pg_) {
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  Context::Instance().set_stream(stream);
  auto layer = std::make_shared<MultiHeadAttention<T>>(
      layer_id, max_batch_tokens, max_seq_len, hidden_dim, num_heads,
      attn_prob_dropout_ratio, hidden_dropout_ratio, pre_or_postLayerNorm);

  layer->SetPG(pg_);
...
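Note on create_multihead_attention: layers live type-erased in s_multihead_attention as std::shared_ptr<void> keyed by layer_id, and the fw/bw entry points later recover the concrete instantiation with std::static_pointer_cast. A minimal sketch of the same registry pattern (names here are illustrative):

#include <memory>
#include <unordered_map>

// Type-erased registry: store shared_ptr<void>, cast back to the concrete
// template instantiation on lookup, as s_multihead_attention does.
template <typename T>
struct Layer {
  T value;
};

static std::unordered_map<int, std::shared_ptr<void>> registry;

template <typename T>
void put(int id, T v) {
  registry[id] = std::make_shared<Layer<T>>(Layer<T>{v});
}

template <typename T>
std::shared_ptr<Layer<T>> get(int id) {
  // Only safe if T matches the type used at insertion time.
  return std::static_pointer_cast<Layer<T>>(registry[id]);
}

int main() {
  put<int>(0, 42);
  return get<int>(0)->value == 42 ? 0 : 1;
}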
@@ -261,15 +300,12 @@ int create_multihead_attention(int layer_id, int max_batch_tokens, int max_seq_l
}

template <typename T>
std::vector<torch::Tensor> multihead_attention_fw(
    int layer_id, const torch::Tensor &input, const torch::Tensor &input_mask,
    const torch::Tensor &in_proj_weight, const torch::Tensor &in_proj_bias,
    const torch::Tensor &out_proj_weight, const torch::Tensor &out_proj_bias,
    const torch::Tensor &norm_weight, const torch::Tensor &norm_bias,
    bool training_mode, bool prelayernorm) {
  CHECK_INPUT(input);
  CHECK_INPUT(input_mask);
...
@@ -280,7 +316,8 @@ std::vector<torch::Tensor> multihead_attention_fw(int layer_id, const torch::Ten
  T *out_ptr = (T *)output.data_ptr();

  std::shared_ptr<MultiHeadAttention<T>> layer =
      std::static_pointer_cast<MultiHeadAttention<T>>(
          s_multihead_attention[layer_id]);
  layer->set_cur_batch_shape(input.size(0), input.size(1));
  layer->SetTrainingMode(training_mode);
...
@@ -297,17 +334,13 @@ std::vector<torch::Tensor> multihead_attention_fw(int layer_id, const torch::Ten
}

template <typename T>
std::vector<torch::Tensor> multihead_attention_bw(
    int layer_id, const torch::Tensor &grad_dec_output,
    const torch::Tensor &output, const torch::Tensor &input,
    const torch::Tensor &input_mask, const torch::Tensor &in_proj_weight,
    const torch::Tensor &in_proj_bias, const torch::Tensor &out_proj_weight,
    const torch::Tensor &out_proj_bias, const torch::Tensor &norm_weight,
    const torch::Tensor &norm_bias) {
  auto g_output = grad_dec_output.contiguous();
  CHECK_INPUT(g_output);
  CHECK_INPUT(output);
...
@@ -332,7 +365,8 @@ std::vector<torch::Tensor> multihead_attention_bw(int layer_id,
  T *grad_input_ptr = (T *)grad_input.data_ptr();

  std::shared_ptr<MultiHeadAttention<T>> layer =
      std::static_pointer_cast<MultiHeadAttention<T>>(
          s_multihead_attention[layer_id]);
  layer->set_cur_batch_shape(g_output.size(0), g_output.size(1));

  layer->_grad_attn_qkvw_ptr = (T *)grad_in_proj_weight.data_ptr();
...
@@ -342,10 +376,12 @@ std::vector<torch::Tensor> multihead_attention_bw(int layer_id,
  layer->_grad_attn_nw_ptr = (T *)grad_norm_weight.data_ptr();
  layer->_grad_attn_nb_ptr = (T *)grad_norm_bias.data_ptr();

  layer->Backward(grad_dec_output_ptr, input_ptr, output_ptr, input_mask_ptr,
                  grad_input_ptr);

  return {grad_input,           grad_in_proj_weight, grad_in_proj_bias,
          grad_out_proj_weight, grad_out_proj_bias,  grad_norm_weight,
          grad_norm_bias};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
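The module body is truncated in this view. For orientation only, a registration of roughly this shape would bind the fp32/fp16 instantiations of the three entry points defined above; the exported names here are an assumption, not the file's actual strings:

// Hedged sketch: the real PYBIND11_MODULE body is not shown in this diff.
// Exported names are assumptions; the bound functions are defined above.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multihead_attention_fw_fp32", &multihead_attention_fw<float>,
        "Multi-head attention forward with fp32 (CUDA)");
  m.def("multihead_attention_fw_fp16", &multihead_attention_fw<__half>,
        "Multi-head attention forward with fp16 (CUDA)");
  m.def("multihead_attention_bw_fp32", &multihead_attention_bw<float>,
        "Multi-head attention backward with fp32 (CUDA)");
  m.def("multihead_attention_bw_fp16", &multihead_attention_bw<__half>,
        "Multi-head attention backward with fp16 (CUDA)");
  m.def("create_multihead_attention_fp32", &create_multihead_attention<float>,
        "Create multi-head attention with fp32 (CUDA)");
  m.def("create_multihead_attention_fp16", &create_multihead_attention<__half>,
        "Create multi-head attention with fp16 (CUDA)");
}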