Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
5754fa7a
Unverified
Commit
5754fa7a
authored
May 29, 2020
by
Kevin Stephano
Committed by
GitHub
May 29, 2020
Browse files
Fixes to Multihead Attention with LayerNorm and Dropout-Add (#860)
parent
6c2babf9
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
187 additions
and
70 deletions
+187
-70
apex/contrib/csrc/multihead_attn/dropout.h
apex/contrib/csrc/multihead_attn/dropout.h
+13
-13
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu
...contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu
+6
-6
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu
...src/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu
+20
-19
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu
+6
-8
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu
.../csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu
+20
-20
apex/contrib/examples/multihead_attn/func_test_multihead_attn.py
...ntrib/examples/multihead_attn/func_test_multihead_attn.py
+108
-0
apex/contrib/multihead_attn/encdec_multihead_attn.py
apex/contrib/multihead_attn/encdec_multihead_attn.py
+7
-2
apex/contrib/multihead_attn/encdec_multihead_attn_func.py
apex/contrib/multihead_attn/encdec_multihead_attn_func.py
+1
-1
apex/contrib/multihead_attn/self_multihead_attn.py
apex/contrib/multihead_attn/self_multihead_attn.py
+6
-1
No files found.
apex/contrib/csrc/multihead_attn/dropout.h
View file @
5754fa7a
...
...
@@ -42,10 +42,11 @@ __global__ void apex_fused_dropout_kernel(scalar_t const *inputs,
linearIndex
+=
gridDim
.
x
*
blockDim
.
x
*
UNROLL
)
{
float4
rand
=
curand_uniform4
(
&
state
);
scalar_t
src
[
UNROLL
];
rand
.
x
=
rand
.
x
<
p
;
rand
.
y
=
rand
.
y
<
p
;
rand
.
z
=
rand
.
z
<
p
;
rand
.
w
=
rand
.
w
<
p
;
rand
.
x
=
rand
.
x
<=
p
;
rand
.
y
=
rand
.
y
<=
p
;
rand
.
z
=
rand
.
z
<=
p
;
rand
.
w
=
rand
.
w
<=
p
;
for
(
int
ii
=
0
;
ii
<
UNROLL
;
ii
++
)
{
IndexType
li
=
linearIndex
+
blockDim
.
x
*
gridDim
.
x
*
ii
;
if
(
li
<
totalElements
)
{
...
...
@@ -55,7 +56,7 @@ __global__ void apex_fused_dropout_kernel(scalar_t const *inputs,
for
(
int
ii
=
0
;
ii
<
UNROLL
;
ii
++
)
{
IndexType
li
=
linearIndex
+
blockDim
.
x
*
gridDim
.
x
*
ii
;
if
(
li
<
totalElements
)
{
outputs
[
li
]
=
src
[
ii
]
*
static_cast
<
scalar_t
>
(
(
&
rand
.
x
)[
ii
]
*
pinv
)
;
outputs
[
li
]
=
src
[
ii
]
*
(
&
rand
.
x
)[
ii
]
*
pinv
;
mask
[
li
]
=
(
uint8_t
)(
&
rand
.
x
)[
ii
];
}
}
...
...
@@ -94,10 +95,10 @@ __global__ void apex_dropout_add_kernel(scalar_t const *inputs,
float4
rand
=
curand_uniform4
(
&
state
);
scalar_t
src
[
UNROLL
];
scalar_t
add_src
[
UNROLL
];
rand
.
x
=
rand
.
x
<
p
;
rand
.
y
=
rand
.
y
<
p
;
rand
.
z
=
rand
.
z
<
p
;
rand
.
w
=
rand
.
w
<
p
;
rand
.
x
=
rand
.
x
<
=
p
;
rand
.
y
=
rand
.
y
<
=
p
;
rand
.
z
=
rand
.
z
<
=
p
;
rand
.
w
=
rand
.
w
<
=
p
;
for
(
int
ii
=
0
;
ii
<
UNROLL
;
ii
++
)
{
IndexType
li
=
linearIndex
+
blockDim
.
x
*
gridDim
.
x
*
ii
;
if
(
li
<
totalElements
)
{
...
...
@@ -108,9 +109,8 @@ __global__ void apex_dropout_add_kernel(scalar_t const *inputs,
for
(
int
ii
=
0
;
ii
<
UNROLL
;
ii
++
)
{
IndexType
li
=
linearIndex
+
blockDim
.
x
*
gridDim
.
x
*
ii
;
if
(
li
<
totalElements
)
{
accscalar_t
int1
=
static_cast
<
accscalar_t
>
((
&
rand
.
x
)[
ii
])
*
static_cast
<
accscalar_t
>
(
src
[
ii
]);
accscalar_t
int2
=
int1
*
static_cast
<
accscalar_t
>
(
pinv
);
outputs
[
li
]
=
static_cast
<
scalar_t
>
(
static_cast
<
accscalar_t
>
(
add_src
[
ii
])
+
int2
);
accscalar_t
int1
=
src
[
ii
]
*
(
&
rand
.
x
)[
ii
]
*
pinv
;
outputs
[
li
]
=
static_cast
<
scalar_t
>
(
static_cast
<
accscalar_t
>
(
add_src
[
ii
])
+
int1
);
mask
[
li
]
=
(
uint8_t
)(
&
rand
.
x
)[
ii
];
}
}
...
...
@@ -182,7 +182,7 @@ __global__ void apex_masked_scale_kernel(scalar_t const *inputs,
for
(
int
ii
=
0
;
ii
<
UNROLL
;
ii
++
)
{
IndexType
li
=
linearIndex
+
blockDim
.
x
*
gridDim
.
x
*
ii
;
if
(
li
<
totalElements
)
{
outputs
[
li
]
=
static_cast
<
scalar_t
>
(
src
[
ii
]
*
static_cast
<
scalar_t
>
(
scale
))
*
msk
[
ii
];
outputs
[
li
]
=
static_cast
<
acc
scalar_t
>
(
src
[
ii
]
)
*
scale
*
static_cast
<
acc
scalar_t
>
(
msk
[
ii
]
)
;
}
}
}
...
...
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu
View file @
5754fa7a
...
...
@@ -182,9 +182,9 @@ std::vector<torch::Tensor> fwd_cuda(
assert
(
softmax_success
);
if
(
is_training
)
{
apex_fused_dropout_cuda
<
h
alf
,
float
,
uint32_t
>
(
static_cast
<
h
alf
const
*>
(
softmax_results
.
data_ptr
()),
static_cast
<
h
alf
*>
(
dropout_results
.
data_ptr
()),
apex_fused_dropout_cuda
<
at
::
H
alf
,
float
,
uint32_t
>
(
static_cast
<
at
::
H
alf
const
*>
(
softmax_results
.
data_ptr
()),
static_cast
<
at
::
H
alf
*>
(
dropout_results
.
data_ptr
()),
static_cast
<
uint8_t
*>
(
dropout_mask
.
data_ptr
()),
dropout_elems
,
(
1.0
f
-
dropout_prob
));
...
...
@@ -397,9 +397,9 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches
);
// Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda
<
h
alf
,
float
,
uint32_t
>
(
static_cast
<
h
alf
const
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
h
alf
*>
(
matmul2_grads
.
data_ptr
()),
apex_masked_scale_cuda
<
at
::
H
alf
,
float
,
uint32_t
>
(
static_cast
<
at
::
H
alf
const
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
at
::
H
alf
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
uint8_t
const
*>
(
dropout_mask
.
data_ptr
()),
dropout_elems
,
(
1.0
/
(
1.0
-
dropout_prob
)));
...
...
apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu
View file @
5754fa7a
...
...
@@ -204,9 +204,9 @@ std::vector<torch::Tensor> fwd_cuda(
assert
(
softmax_success
);
if
(
is_training
)
{
apex_fused_dropout_cuda
<
h
alf
,
float
,
uint32_t
>
(
static_cast
<
h
alf
const
*>
(
softmax_results
.
data_ptr
()),
static_cast
<
h
alf
*>
(
dropout_results
.
data_ptr
()),
apex_fused_dropout_cuda
<
at
::
H
alf
,
float
,
uint32_t
>
(
static_cast
<
at
::
H
alf
const
*>
(
softmax_results
.
data_ptr
()),
static_cast
<
at
::
H
alf
*>
(
dropout_results
.
data_ptr
()),
static_cast
<
uint8_t
*>
(
dropout_mask
.
data_ptr
()),
dropout_elems
,
(
1.0
f
-
dropout_prob
));
...
...
@@ -257,18 +257,18 @@ std::vector<torch::Tensor> fwd_cuda(
// End-of-block Dropout-Add
if
(
is_training
)
{
apex_dropout_add_cuda
<
h
alf
,
float
,
uint32_t
>
(
static_cast
<
h
alf
const
*>
(
output_lin_results
.
data_ptr
()),
static_cast
<
h
alf
const
*>
(
inputs_q
.
data_ptr
()),
static_cast
<
h
alf
*>
(
outputs
.
data_ptr
()),
apex_dropout_add_cuda
<
at
::
H
alf
,
float
,
uint32_t
>
(
static_cast
<
at
::
H
alf
const
*>
(
output_lin_results
.
data_ptr
()),
static_cast
<
at
::
H
alf
const
*>
(
inputs_q
.
data_ptr
()),
static_cast
<
at
::
H
alf
*>
(
outputs
.
data_ptr
()),
static_cast
<
uint8_t
*>
(
dropout_add_mask
.
data_ptr
()),
total_tokens_q
,
(
1.0
f
-
dropout_prob
));
}
else
{
apex_add_cuda
<
h
alf
,
float
,
uint32_t
>
(
static_cast
<
h
alf
const
*>
(
output_lin_results
.
data_ptr
()),
static_cast
<
h
alf
const
*>
(
inputs_q
.
data_ptr
()),
static_cast
<
h
alf
*>
(
outputs
.
data_ptr
()),
apex_add_cuda
<
at
::
H
alf
,
float
,
uint32_t
>
(
static_cast
<
at
::
H
alf
const
*>
(
output_lin_results
.
data_ptr
()),
static_cast
<
at
::
H
alf
const
*>
(
inputs_q
.
data_ptr
()),
static_cast
<
at
::
H
alf
*>
(
outputs
.
data_ptr
()),
total_tokens_q
);
}
...
...
@@ -347,6 +347,7 @@ std::vector<torch::Tensor> bwd_cuda(
torch
::
Tensor
input_weight_kv_grads
=
torch
::
empty_like
(
input_weights_kv
);
torch
::
Tensor
output_weight_grads
=
torch
::
empty_like
(
output_weights
);
// Intermediate Tensor Allocations
at
::
Tensor
dropout_add_grads
=
torch
::
empty_like
(
output_grads
);
at
::
Tensor
output_lin_grads
=
torch
::
empty_like
(
matmul2_results
);
at
::
Tensor
matmul2_grads
=
torch
::
empty_like
(
dropout_results
);
at
::
Tensor
input_lin_q_output_grads
=
torch
::
empty_like
(
input_lin_q_results
);
...
...
@@ -369,9 +370,9 @@ std::vector<torch::Tensor> bwd_cuda(
THCublasCheck
(
cublasSetMathMode
(
handle
,
CUBLAS_TENSOR_OP_MATH
));
// Dropout Add Backward
apex_masked_scale_cuda
<
h
alf
,
float
,
uint32_t
>
(
static_cast
<
h
alf
const
*>
(
output_grads
.
data_ptr
()),
static_cast
<
h
alf
*>
(
output
_grads
.
data_ptr
()),
apex_masked_scale_cuda
<
at
::
H
alf
,
float
,
uint32_t
>
(
static_cast
<
at
::
H
alf
const
*>
(
output_grads
.
data_ptr
()),
static_cast
<
at
::
H
alf
*>
(
dropout_add
_grads
.
data_ptr
()),
static_cast
<
uint8_t
const
*>
(
dropout_add_mask
.
data_ptr
()),
total_tokens_q
,
(
1.0
/
(
1.0
-
dropout_prob
)));
...
...
@@ -387,7 +388,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast
<
const
void
*>
(
output_weights
.
data_ptr
()),
CUDA_R_16F
,
embed_dim
,
static_cast
<
const
void
*>
(
output
_grads
.
data_ptr
()),
static_cast
<
const
void
*>
(
dropout_add
_grads
.
data_ptr
()),
CUDA_R_16F
,
embed_dim
,
static_cast
<
const
void
*>
(
&
beta
),
...
...
@@ -408,7 +409,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast
<
const
void
*>
(
matmul2_results
.
data_ptr
()),
CUDA_R_16F
,
embed_dim
,
static_cast
<
const
void
*>
(
output
_grads
.
data_ptr
()),
static_cast
<
const
void
*>
(
dropout_add
_grads
.
data_ptr
()),
CUDA_R_16F
,
embed_dim
,
static_cast
<
const
void
*>
(
&
beta
),
...
...
@@ -459,9 +460,9 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches
);
// Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda
<
h
alf
,
float
,
uint32_t
>
(
static_cast
<
h
alf
const
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
h
alf
*>
(
matmul2_grads
.
data_ptr
()),
apex_masked_scale_cuda
<
at
::
H
alf
,
float
,
uint32_t
>
(
static_cast
<
at
::
H
alf
const
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
at
::
H
alf
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
uint8_t
const
*>
(
dropout_mask
.
data_ptr
()),
dropout_elems
,
(
1.0
/
(
1.0
-
dropout_prob
)));
...
...
apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu
View file @
5754fa7a
...
...
@@ -153,9 +153,9 @@ std::vector<torch::Tensor> fwd_cuda(
assert
(
softmax_success
);
if
(
is_training
)
{
apex_fused_dropout_cuda
<
h
alf
,
float
,
uint32_t
>
(
static_cast
<
h
alf
const
*>
(
softmax_results
.
data_ptr
()),
static_cast
<
h
alf
*>
(
dropout_results
.
data_ptr
()),
apex_fused_dropout_cuda
<
at
::
H
alf
,
float
,
uint32_t
>
(
static_cast
<
at
::
H
alf
const
*>
(
softmax_results
.
data_ptr
()),
static_cast
<
at
::
H
alf
*>
(
dropout_results
.
data_ptr
()),
static_cast
<
uint8_t
*>
(
dropout_mask
.
data_ptr
()),
dropout_elems
,
(
1.0
f
-
dropout_prob
));
...
...
@@ -200,7 +200,6 @@ std::vector<torch::Tensor> fwd_cuda(
CUDA_R_16F
,
embed_dim
,
CUDA_R_32F
,
//CUBLAS_GEMM_ALGO1_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP
));
THCublasCheck
(
cublasSetMathMode
(
handle
,
CUBLAS_DEFAULT_MATH
));
...
...
@@ -357,9 +356,9 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches
);
// Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda
<
h
alf
,
float
,
uint32_t
>
(
static_cast
<
h
alf
const
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
h
alf
*>
(
matmul2_grads
.
data_ptr
()),
apex_masked_scale_cuda
<
at
::
H
alf
,
float
,
uint32_t
>
(
static_cast
<
at
::
H
alf
const
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
at
::
H
alf
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
uint8_t
const
*>
(
dropout_mask
.
data_ptr
()),
dropout_elems
,
(
1.0
/
(
1.0
-
dropout_prob
)));
...
...
@@ -434,7 +433,6 @@ std::vector<torch::Tensor> bwd_cuda(
CUDA_R_16F
,
embed_dim
,
CUDA_R_32F
,
//CUBLAS_GEMM_ALGO10_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP
));
// Input Linear Wgrad
...
...
apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu
View file @
5754fa7a
...
...
@@ -176,9 +176,9 @@ std::vector<torch::Tensor> fwd_cuda(
assert
(
softmax_success
);
if
(
is_training
)
{
apex_fused_dropout_cuda
<
h
alf
,
float
,
uint32_t
>
(
static_cast
<
h
alf
const
*>
(
softmax_results
.
data_ptr
()),
static_cast
<
h
alf
*>
(
dropout_results
.
data_ptr
()),
apex_fused_dropout_cuda
<
at
::
H
alf
,
float
,
uint32_t
>
(
static_cast
<
at
::
H
alf
const
*>
(
softmax_results
.
data_ptr
()),
static_cast
<
at
::
H
alf
*>
(
dropout_results
.
data_ptr
()),
static_cast
<
uint8_t
*>
(
dropout_mask
.
data_ptr
()),
dropout_elems
,
(
1.0
f
-
dropout_prob
));
...
...
@@ -224,23 +224,22 @@ std::vector<torch::Tensor> fwd_cuda(
CUDA_R_16F
,
embed_dim
,
CUDA_R_32F
,
//CUBLAS_GEMM_ALGO1_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP
));
// End-of-block Dropout-Add
if
(
is_training
)
{
apex_dropout_add_cuda
<
h
alf
,
float
,
uint32_t
>
(
static_cast
<
h
alf
const
*>
(
output_lin_results
.
data_ptr
()),
static_cast
<
h
alf
const
*>
(
inputs
.
data_ptr
()),
static_cast
<
h
alf
*>
(
outputs
.
data_ptr
()),
apex_dropout_add_cuda
<
at
::
H
alf
,
float
,
uint32_t
>
(
static_cast
<
at
::
H
alf
const
*>
(
output_lin_results
.
data_ptr
()),
static_cast
<
at
::
H
alf
const
*>
(
inputs
.
data_ptr
()),
static_cast
<
at
::
H
alf
*>
(
outputs
.
data_ptr
()),
static_cast
<
uint8_t
*>
(
dropout_add_mask
.
data_ptr
()),
total_tokens
,
(
1.0
f
-
dropout_prob
));
}
else
{
apex_add_cuda
<
h
alf
,
float
,
uint32_t
>
(
static_cast
<
h
alf
const
*>
(
output_lin_results
.
data_ptr
()),
static_cast
<
h
alf
const
*>
(
inputs
.
data_ptr
()),
static_cast
<
h
alf
*>
(
outputs
.
data_ptr
()),
apex_add_cuda
<
at
::
H
alf
,
float
,
uint32_t
>
(
static_cast
<
at
::
H
alf
const
*>
(
output_lin_results
.
data_ptr
()),
static_cast
<
at
::
H
alf
const
*>
(
inputs
.
data_ptr
()),
static_cast
<
at
::
H
alf
*>
(
outputs
.
data_ptr
()),
total_tokens
);
}
...
...
@@ -309,6 +308,7 @@ std::vector<torch::Tensor> bwd_cuda(
torch
::
Tensor
input_weight_grads
=
torch
::
empty_like
(
input_weights
);
torch
::
Tensor
output_weight_grads
=
torch
::
empty_like
(
output_weights
);
// Intermediate Tensor Allocations
torch
::
Tensor
dropout_add_grads
=
torch
::
empty_like
(
output_grads
);
torch
::
Tensor
output_lin_grads
=
torch
::
empty_like
(
matmul2_results
);
torch
::
Tensor
matmul2_grads
=
torch
::
empty_like
(
dropout_results
);
torch
::
Tensor
input_lin_output_grads
=
torch
::
empty_like
(
input_lin_results
);
...
...
@@ -330,9 +330,9 @@ std::vector<torch::Tensor> bwd_cuda(
THCublasCheck
(
cublasSetMathMode
(
handle
,
CUBLAS_TENSOR_OP_MATH
));
// Dropout Add Backward
apex_masked_scale_cuda
<
h
alf
,
float
,
uint32_t
>
(
static_cast
<
h
alf
const
*>
(
output_grads
.
data_ptr
()),
static_cast
<
h
alf
*>
(
output
_grads
.
data_ptr
()),
apex_masked_scale_cuda
<
at
::
H
alf
,
float
,
uint32_t
>
(
static_cast
<
at
::
H
alf
const
*>
(
output_grads
.
data_ptr
()),
static_cast
<
at
::
H
alf
*>
(
dropout_add
_grads
.
data_ptr
()),
static_cast
<
uint8_t
const
*>
(
dropout_add_mask
.
data_ptr
()),
total_tokens
,
(
1.0
/
(
1.0
-
dropout_prob
)));
...
...
@@ -348,7 +348,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast
<
const
void
*>
(
output_weights
.
data_ptr
()),
CUDA_R_16F
,
embed_dim
,
static_cast
<
const
void
*>
(
output
_grads
.
data_ptr
()),
static_cast
<
const
void
*>
(
dropout_add
_grads
.
data_ptr
()),
CUDA_R_16F
,
embed_dim
,
static_cast
<
const
void
*>
(
&
beta
),
...
...
@@ -369,7 +369,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast
<
const
void
*>
(
matmul2_results
.
data_ptr
()),
CUDA_R_16F
,
embed_dim
,
static_cast
<
const
void
*>
(
output
_grads
.
data_ptr
()),
static_cast
<
const
void
*>
(
dropout_add
_grads
.
data_ptr
()),
CUDA_R_16F
,
embed_dim
,
static_cast
<
const
void
*>
(
&
beta
),
...
...
@@ -420,9 +420,9 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches
);
// Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda
<
h
alf
,
float
,
uint32_t
>
(
static_cast
<
h
alf
const
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
h
alf
*>
(
matmul2_grads
.
data_ptr
()),
apex_masked_scale_cuda
<
at
::
H
alf
,
float
,
uint32_t
>
(
static_cast
<
at
::
H
alf
const
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
at
::
H
alf
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
uint8_t
const
*>
(
dropout_mask
.
data_ptr
()),
dropout_elems
,
(
1.0
/
(
1.0
-
dropout_prob
)));
...
...
apex/contrib/examples/multihead_attn/func_test_multihead_attn.py
0 → 100644
View file @
5754fa7a
import
torch
import
torch.nn.functional
as
F
import
argparse
from
apex.contrib.multihead_attn
import
SelfMultiheadAttn
from
apex.contrib.multihead_attn
import
EncdecMultiheadAttn
# Standalone functional test: for each seed, run the 'default' (reference) and
# 'fast' implementations of the apex multihead attention layer on identically
# seeded weights/inputs and print how far their forward outputs and input
# gradients diverge.

parser = argparse.ArgumentParser(description='Multihead Attention Standalone Test')
parser.add_argument('--seq-length', default=64, type=int, help='Sequence Length of Input')
parser.add_argument('--num-seqs-start', default=5, type=int, help='Start Range of Number of Sequences')
parser.add_argument('--num-seqs-stop', default=80, type=int, help='Stop Range of Number of Sequences')
parser.add_argument('--num-seqs-inc', default=5, type=int, help='Range Increment of Number of Sequences')
parser.add_argument('--trials', default=20, type=int, help='Number of Trials to Execute')
parser.add_argument('--warmup-trials', default=5, type=int, help='Warmup Trials to discard')
parser.add_argument('--layers', default=18, type=int, help='Attention Layers to Execute to Gain CPU/GPU Time Overlap')
# The seed range is inclusive on both ends (see the range() call in the main loop).
parser.add_argument('--seed-start', default=1, type=int, help='First Random Seed to Test (inclusive)')
parser.add_argument('--seed-end', default=100, type=int, help='Last Random Seed to Test (inclusive)')
parser.add_argument('--hidden-dim', default=1024, type=int, help='Multihead Attention hidden dimension')
parser.add_argument('--heads', default=16, type=int, help='Number of Multihead Attention heads')
parser.add_argument('--encdec-attn', action='store_true', help='Use Encoder-Decoder Attention instead of Self Attention.')
parser.add_argument('--norm-add', action='store_true', help='Include Layer Norm and Dropout-Add in Multihead Attention block.')
parser.add_argument('--ref', action='store_true', help='Reference implementation in python pytorch.')
parser.add_argument('--native', action='store_true', help='torch.nn.MultiheadAttention Version.')
parser.add_argument('--fwd', action='store_true', help='Only execute Fwd Pass.')
parser.add_argument('--eval', action='store_true', help='Inference only, no backward pass.')
# NOTE(review): --num-seqs-stop, --num-seqs-inc, --trials, --warmup-trials,
# --layers, --ref, --native and --fwd are accepted but never read below, and
# --eval only toggles is_training (backward() is still called). They are kept
# so existing command lines continue to parse.

args = parser.parse_args()

assert args.seq_length % 64 == 0, "Sequence Length should be a multiple of 64!"

if not torch.cuda.is_available():
    raise NotImplementedError('Running on CPU is not supported')
torch.cuda.set_device(0)

dropout_prob = 0.1


def prepare(impl, seed):
    """Seed the RNGs, then build the attention layer and its random inputs.

    The order of RNG-consuming calls (layer construction, reset_parameters,
    query input, key/value input) is the same for every `impl`, so two calls
    with the same seed are expected to produce identical weights and inputs.

    Returns (layer, inputs, inputs_kv); inputs_kv is None for self attention.
    """
    torch.manual_seed(seed)
    # CUDA availability was already enforced above, so seed it unconditionally.
    torch.cuda.manual_seed_all(seed)

    layer_cls = EncdecMultiheadAttn if args.encdec_attn else SelfMultiheadAttn
    layer = layer_cls(args.hidden_dim, args.heads, dropout=dropout_prob, bias=False,
                      include_norm_add=args.norm_add, impl=impl)
    layer.cuda()
    layer.half()
    layer.reset_parameters()

    inputs = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim,
                         dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
    inputs_kv = None
    if args.encdec_attn:
        inputs_kv = torch.randn(args.seq_length, args.num_seqs_start, args.hidden_dim,
                                dtype=torch.float16, device=torch.device("cuda")).requires_grad_(True)
    return layer, inputs, inputs_kv


def forward_backward(layer, inputs, inputs_kv):
    """Draw a random output gradient, run forward then backward; return the outputs."""
    grads = torch.randn_like(inputs)
    outputs, _ = layer.forward(inputs,
                               inputs_kv,
                               inputs_kv,
                               key_padding_mask=None,
                               need_weights=False,
                               attn_mask=None,
                               is_training=(not args.eval))
    outputs.backward(grads)
    return outputs


for seed in range(args.seed_start, args.seed_end + 1):
    # Reference ('default') implementation.
    ref_layer, ref_inputs, ref_inputs_kv = prepare('default', seed)
    ref_outputs = forward_backward(ref_layer, ref_inputs, ref_inputs_kv)

    # Implementation under test ('fast'), re-seeded identically.
    tst_layer, tst_inputs, tst_inputs_kv = prepare('fast', seed)

    assert torch.equal(ref_inputs, tst_inputs), "ERROR: Inputs are different!"
    if args.encdec_attn:
        assert torch.equal(ref_inputs_kv, tst_inputs_kv), "ERROR: Key/Value inputs are different!"

    tst_outputs = forward_backward(tst_layer, tst_inputs, tst_inputs_kv)

    # Exact (bitwise-value) comparison plus mismatch count and summed |diff|.
    fwd_close = torch.equal(ref_outputs, tst_outputs)
    bwd_close = torch.equal(ref_inputs.grad, tst_inputs.grad)

    diff_fwd = ref_outputs - tst_outputs
    diff_cnt_fwd = diff_fwd.ne(0.0).sum()
    diff_accum_fwd = diff_fwd.abs().sum()

    diff_bwd = ref_inputs.grad - tst_inputs.grad
    diff_cnt_bwd = diff_bwd.ne(0.0).sum()
    diff_accum_bwd = diff_bwd.abs().sum()

    print(">>> Seed: ", seed, fwd_close, diff_cnt_fwd.item(), diff_accum_fwd.item(),
          bwd_close, diff_cnt_bwd.item(), diff_accum_bwd.item())
apex/contrib/multihead_attn/encdec_multihead_attn.py
View file @
5754fa7a
...
...
@@ -6,7 +6,12 @@ import torch.nn.functional as F
from
.encdec_multihead_attn_func
import
encdec_attn_func
from
.fast_encdec_multihead_attn_func
import
fast_encdec_attn_func
from
.fast_encdec_multihead_attn_norm_add_func
import
fast_encdec_attn_norm_add_func
from
apex.normalization.fused_layer_norm
import
FusedLayerNorm
if
hasattr
(
torch
.
_C
,
'_jit_set_profiling_executor'
)
:
torch
.
_C
.
_jit_set_profiling_executor
(
False
)
if
hasattr
(
torch
.
_C
,
'_jit_set_profiling_mode'
)
:
torch
.
_C
.
_jit_set_profiling_mode
(
False
)
@
torch
.
jit
.
script
def
jit_dropout_add
(
x
,
residual
,
prob
,
is_training
):
...
...
@@ -57,9 +62,9 @@ class EncdecMultiheadAttn(nn.Module):
self
.
register_parameter
(
'lyr_norm_beta_weights'
,
None
)
self
.
lyr_nrm_gamma_weights
=
None
self
.
lyr_nrm_beta_weights
=
None
self
.
lyr_nrm
=
torch
.
nn
.
LayerNorm
(
embed_dim
)
self
.
lyr_nrm
=
Fused
LayerNorm
(
embed_dim
)
self
.
reset_parameters
()
if
self
.
include_norm_add
:
if
impl
==
'fast'
:
self
.
attn_func
=
fast_encdec_attn_norm_add_func
elif
impl
==
'default'
:
self
.
attn_func
=
encdec_attn_func
...
...
apex/contrib/multihead_attn/encdec_multihead_attn_func.py
View file @
5754fa7a
...
...
@@ -203,7 +203,7 @@ class EncdecAttnFunc(torch.autograd.Function):
values_grads
=
torch
.
bmm
(
dropout_results
.
transpose
(
1
,
2
),
output_lin_grads
,
out
=
values_grads
.
transpose
(
0
,
1
))
# Mask and Scaling for Dropout (not a publicly documented op)
dropout_grads
=
torch
.
_masked_scale
(
matmul2_dgrad1
,
dropout_mask
,
dropout_prob_t
[
0
])
dropout_grads
=
torch
.
_masked_scale
(
matmul2_dgrad1
,
dropout_mask
,
1.0
/
(
1.0
-
dropout_prob_t
[
0
])
)
# Softmax Grad (not a publicly documented op)
softmax_grads
=
torch
.
_softmax_backward_data
(
dropout_grads
,
softmax_results
,
-
1
,
softmax_results
)
...
...
apex/contrib/multihead_attn/self_multihead_attn.py
View file @
5754fa7a
...
...
@@ -6,7 +6,12 @@ import torch.nn.functional as F
from
.self_multihead_attn_func
import
self_attn_func
from
.fast_self_multihead_attn_func
import
fast_self_attn_func
from
.fast_self_multihead_attn_norm_add_func
import
fast_self_attn_norm_add_func
from
apex.normalization.fused_layer_norm
import
FusedLayerNorm
if
hasattr
(
torch
.
_C
,
'_jit_set_profiling_executor'
)
:
torch
.
_C
.
_jit_set_profiling_executor
(
False
)
if
hasattr
(
torch
.
_C
,
'_jit_set_profiling_mode'
)
:
torch
.
_C
.
_jit_set_profiling_mode
(
False
)
@
torch
.
jit
.
script
def
jit_dropout_add
(
x
,
residual
,
prob
,
is_training
):
...
...
@@ -75,7 +80,7 @@ class SelfMultiheadAttn(nn.Module):
self
.
register_parameter
(
'lyr_norm_beta_weights'
,
None
)
self
.
lyr_nrm_gamma_weights
=
None
self
.
lyr_nrm_beta_weights
=
None
self
.
lyr_nrm
=
torch
.
nn
.
LayerNorm
(
embed_dim
)
self
.
lyr_nrm
=
Fused
LayerNorm
(
embed_dim
)
self
.
reset_parameters
()
if
self
.
include_norm_add
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment