Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
3fe10b55
Unverified
Commit
3fe10b55
authored
Dec 03, 2020
by
Burc Eryilmaz
Committed by
GitHub
Dec 03, 2020
Browse files
Seryilmaz/fused dropout softmax (#985)
* fuse dropout into softmax in fprop for additive mask case
parent
6c186b3b
Changes
9
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
1019 additions
and
140 deletions
+1019
-140
apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu
...rc/multihead_attn/additive_masked_softmax_dropout_cuda.cu
+2
-2
apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu
...ontrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu
+4
-4
apex/contrib/csrc/multihead_attn/philox.h
apex/contrib/csrc/multihead_attn/philox.h
+90
-0
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask.cpp
...multihead_attn/self_multihead_attn_bias_additive_mask.cpp
+8
-5
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu
...ihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu
+30
-27
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu
...trib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu
+2
-2
apex/contrib/csrc/multihead_attn/softmax.h
apex/contrib/csrc/multihead_attn/softmax.h
+720
-68
apex/contrib/multihead_attn/fast_self_multihead_attn_func.py
apex/contrib/multihead_attn/fast_self_multihead_attn_func.py
+86
-32
apex/contrib/test/multihead_attn/test_fast_self_multihead_attn_bias.py
...test/multihead_attn/test_fast_self_multihead_attn_bias.py
+77
-0
No files found.
apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu
View file @
3fe10b55
...
@@ -113,7 +113,7 @@ torch::Tensor bwd_cuda(
...
@@ -113,7 +113,7 @@ torch::Tensor bwd_cuda(
// Apply Dropout Mask and Scale by Dropout Probability
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
// Softmax Grad
dispatch_masked_scale_softmax_backward
<
half
,
half
,
float
,
false
>
(
dispatch_masked_scale_softmax_backward
_stream
<
half
,
half
,
float
,
false
>
(
static_cast
<
half
*>
(
output_grads
.
data_ptr
()),
static_cast
<
half
*>
(
output_grads
.
data_ptr
()),
static_cast
<
half
*>
(
output_grads
.
data_ptr
()),
static_cast
<
half
*>
(
output_grads
.
data_ptr
()),
reinterpret_cast
<
half
const
*>
(
softmax_results
.
data_ptr
()),
reinterpret_cast
<
half
const
*>
(
softmax_results
.
data_ptr
()),
...
@@ -121,7 +121,7 @@ torch::Tensor bwd_cuda(
...
@@ -121,7 +121,7 @@ torch::Tensor bwd_cuda(
1.0
/
(
1.0
-
dropout_prob
),
1.0
/
(
1.0
-
dropout_prob
),
k_seq_len
,
k_seq_len
,
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
);
attn_batches
*
q_seq_len
,
stream
);
//backward pass is completely in-place
//backward pass is completely in-place
return
output_grads
;
return
output_grads
;
}
}
...
...
apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu
View file @
3fe10b55
...
@@ -115,7 +115,7 @@ torch::Tensor bwd_cuda(
...
@@ -115,7 +115,7 @@ torch::Tensor bwd_cuda(
// Apply Dropout Mask and Scale by Dropout Probability
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
// Softmax Grad
if
(
padding_mask
==
nullptr
)
{
if
(
padding_mask
==
nullptr
)
{
dispatch_masked_scale_softmax_backward
<
half
,
half
,
float
,
false
>
(
dispatch_masked_scale_softmax_backward
_stream
<
half
,
half
,
float
,
false
>
(
static_cast
<
half
*>
(
output_grads
.
data_ptr
()),
static_cast
<
half
*>
(
output_grads
.
data_ptr
()),
static_cast
<
half
*>
(
output_grads
.
data_ptr
()),
static_cast
<
half
*>
(
output_grads
.
data_ptr
()),
reinterpret_cast
<
half
const
*>
(
softmax_results
.
data_ptr
()),
reinterpret_cast
<
half
const
*>
(
softmax_results
.
data_ptr
()),
...
@@ -123,9 +123,9 @@ torch::Tensor bwd_cuda(
...
@@ -123,9 +123,9 @@ torch::Tensor bwd_cuda(
1.0
/
(
1.0
-
dropout_prob
),
1.0
/
(
1.0
-
dropout_prob
),
k_seq_len
,
k_seq_len
,
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
);
attn_batches
*
q_seq_len
,
stream
);
}
else
{
}
else
{
dispatch_masked_scale_softmax_backward_masked_out
<
half
,
half
,
float
,
false
>
(
dispatch_masked_scale_softmax_backward_masked_out
_stream
<
half
,
half
,
float
,
false
>
(
static_cast
<
half
*>
(
output_grads
.
data_ptr
()),
static_cast
<
half
*>
(
output_grads
.
data_ptr
()),
static_cast
<
half
*>
(
output_grads
.
data_ptr
()),
static_cast
<
half
*>
(
output_grads
.
data_ptr
()),
reinterpret_cast
<
half
const
*>
(
softmax_results
.
data_ptr
()),
reinterpret_cast
<
half
const
*>
(
softmax_results
.
data_ptr
()),
...
@@ -135,7 +135,7 @@ torch::Tensor bwd_cuda(
...
@@ -135,7 +135,7 @@ torch::Tensor bwd_cuda(
k_seq_len
,
k_seq_len
,
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
,
attn_batches
*
q_seq_len
,
heads
);
heads
,
stream
);
}
}
//backward pass is completely in-place
//backward pass is completely in-place
...
...
apex/contrib/csrc/multihead_attn/philox.h
0 → 100644
View file @
3fe10b55
#pragma once
//Philox CUDA.
class
Philox
{
public:
__device__
inline
Philox
(
unsigned
long
long
seed
,
unsigned
long
long
subsequence
,
unsigned
long
long
offset
)
{
key
.
x
=
(
unsigned
int
)
seed
;
key
.
y
=
(
unsigned
int
)(
seed
>>
32
);
counter
=
make_uint4
(
0
,
0
,
0
,
0
);
counter
.
z
=
(
unsigned
int
)(
subsequence
);
counter
.
w
=
(
unsigned
int
)(
subsequence
>>
32
);
STATE
=
0
;
incr_n
(
offset
/
4
);
}
__device__
inline
uint4
operator
()()
{
if
(
STATE
==
0
)
{
uint4
counter_
=
counter
;
uint2
key_
=
key
;
//7-round philox
for
(
int
i
=
0
;
i
<
6
;
i
++
)
{
counter_
=
single_round
(
counter_
,
key_
);
key_
.
x
+=
(
kPhilox10A
);
key_
.
y
+=
(
kPhilox10B
);
}
output
=
single_round
(
counter_
,
key_
);
incr
();
}
//return a float4 directly
//unsigned long ret;
//switch(STATE) {
// case 0: ret = output.x; break;
// case 1: ret = output.y; break;
// case 2: ret = output.z; break;
// case 3: ret = output.w; break;
//}
//STATE = (STATE + 1) % 4;
return
output
;
}
private:
uint4
counter
;
uint4
output
;
uint2
key
;
unsigned
int
STATE
;
__device__
inline
void
incr_n
(
unsigned
long
long
n
)
{
unsigned
int
nlo
=
(
unsigned
int
)(
n
);
unsigned
int
nhi
=
(
unsigned
int
)(
n
>>
32
);
counter
.
x
+=
nlo
;
if
(
counter
.
x
<
nlo
)
nhi
++
;
counter
.
y
+=
nhi
;
if
(
nhi
<=
counter
.
y
)
return
;
if
(
++
counter
.
z
)
return
;
++
counter
.
w
;
}
__device__
inline
void
incr
()
{
if
(
++
counter
.
x
)
return
;
if
(
++
counter
.
y
)
return
;
if
(
++
counter
.
z
)
return
;
++
counter
.
w
;
}
__device__
unsigned
int
mulhilo32
(
unsigned
int
a
,
unsigned
int
b
,
unsigned
int
*
result_high
)
{
*
result_high
=
__umulhi
(
a
,
b
);
return
a
*
b
;
}
__device__
inline
uint4
single_round
(
uint4
ctr
,
uint2
key
)
{
unsigned
int
hi0
;
unsigned
int
hi1
;
unsigned
int
lo0
=
mulhilo32
(
kPhiloxSA
,
ctr
.
x
,
&
hi0
);
unsigned
int
lo1
=
mulhilo32
(
kPhiloxSB
,
ctr
.
z
,
&
hi1
);
uint4
ret
=
{
hi1
^
ctr
.
y
^
key
.
x
,
lo1
,
hi0
^
ctr
.
w
^
key
.
y
,
lo0
};
return
ret
;
}
static
const
unsigned
long
kPhilox10A
=
0x9E3779B9
;
static
const
unsigned
long
kPhilox10B
=
0xBB67AE85
;
static
const
unsigned
long
kPhiloxSA
=
0xD2511F53
;
static
const
unsigned
long
kPhiloxSB
=
0xCD9E8D57
;
};
// Inverse of 2^32.
#define M_RAN_INVM32 2.3283064e-10f
__device__
__inline__
float4
uniform4
(
uint4
x
)
{
return
make_float4
(
x
.
x
*
M_RAN_INVM32
,
x
.
y
*
M_RAN_INVM32
,
x
.
z
*
M_RAN_INVM32
,
x
.
w
*
M_RAN_INVM32
);
}
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask.cpp
View file @
3fe10b55
...
@@ -24,7 +24,9 @@ std::vector<torch::Tensor> bwd_cuda(
...
@@ -24,7 +24,9 @@ std::vector<torch::Tensor> bwd_cuda(
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
matmul2_results
,
torch
::
Tensor
const
&
matmul2_results
,
torch
::
Tensor
const
&
dropout_results
,
torch
::
Tensor
const
&
dropout_results
,
torch
::
Tensor
const
&
softmax_results
,
// torch::Tensor const& softmax_results,
torch
::
Tensor
const
&
bmm1_results
,
torch
::
Tensor
const
&
pad_mask
,
torch
::
Tensor
const
&
input_lin_results
,
torch
::
Tensor
const
&
input_lin_results
,
torch
::
Tensor
const
&
inputs
,
torch
::
Tensor
const
&
inputs
,
torch
::
Tensor
const
&
input_weights
,
torch
::
Tensor
const
&
input_weights
,
...
@@ -60,6 +62,7 @@ std::vector<torch::Tensor> fwd(
...
@@ -60,6 +62,7 @@ std::vector<torch::Tensor> fwd(
AT_ASSERTM
(
inputs
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
inputs
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
input_weights
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
input_weights
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
output_weights
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
output_weights
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
use_mask
,
"no mask is not supported"
);
if
(
use_mask
)
{
if
(
use_mask
)
{
AT_ASSERTM
(
pad_mask
.
dim
()
==
2
,
"expected 2D tensor"
);
AT_ASSERTM
(
pad_mask
.
dim
()
==
2
,
"expected 2D tensor"
);
...
@@ -85,7 +88,8 @@ std::vector<torch::Tensor> bwd(
...
@@ -85,7 +88,8 @@ std::vector<torch::Tensor> bwd(
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
matmul2_results
,
torch
::
Tensor
const
&
matmul2_results
,
torch
::
Tensor
const
&
dropout_results
,
torch
::
Tensor
const
&
dropout_results
,
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
bmm1_results
,
torch
::
Tensor
const
&
pad_mask
,
torch
::
Tensor
const
&
input_lin_results
,
torch
::
Tensor
const
&
input_lin_results
,
torch
::
Tensor
const
&
inputs
,
torch
::
Tensor
const
&
inputs
,
torch
::
Tensor
const
&
input_weights
,
torch
::
Tensor
const
&
input_weights
,
...
@@ -97,7 +101,6 @@ std::vector<torch::Tensor> bwd(
...
@@ -97,7 +101,6 @@ std::vector<torch::Tensor> bwd(
AT_ASSERTM
(
output_grads
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
output_grads
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
matmul2_results
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
matmul2_results
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
dropout_results
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
dropout_results
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
softmax_results
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
input_lin_results
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
input_lin_results
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
inputs
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
inputs
.
dim
()
==
3
,
"expected 3D tensor"
);
AT_ASSERTM
(
input_weights
.
dim
()
==
2
,
"expected 2D tensor"
);
AT_ASSERTM
(
input_weights
.
dim
()
==
2
,
"expected 2D tensor"
);
...
@@ -107,7 +110,6 @@ std::vector<torch::Tensor> bwd(
...
@@ -107,7 +110,6 @@ std::vector<torch::Tensor> bwd(
AT_ASSERTM
(
output_grads
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
output_grads
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
matmul2_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
matmul2_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
dropout_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
dropout_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
softmax_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
input_lin_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
input_lin_results
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
inputs
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
inputs
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
input_weights
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
AT_ASSERTM
(
input_weights
.
type
().
scalarType
()
==
at
::
ScalarType
::
Half
,
"Only HALF is supported"
);
...
@@ -119,7 +121,8 @@ std::vector<torch::Tensor> bwd(
...
@@ -119,7 +121,8 @@ std::vector<torch::Tensor> bwd(
output_grads
,
output_grads
,
matmul2_results
,
matmul2_results
,
dropout_results
,
dropout_results
,
softmax_results
,
bmm1_results
,
pad_mask
,
input_lin_results
,
input_lin_results
,
inputs
,
inputs
,
input_weights
,
input_weights
,
...
...
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu
View file @
3fe10b55
...
@@ -63,7 +63,7 @@ std::vector<torch::Tensor> fwd_cuda(
...
@@ -63,7 +63,7 @@ std::vector<torch::Tensor> fwd_cuda(
auto
mask_options
=
act_options
.
dtype
(
torch
::
kUInt8
);
auto
mask_options
=
act_options
.
dtype
(
torch
::
kUInt8
);
torch
::
Tensor
input_lin_results
=
torch
::
empty
({
q_seq_len
,
sequences
,
output_lin_dim
},
act_options
);
torch
::
Tensor
input_lin_results
=
torch
::
empty
({
q_seq_len
,
sequences
,
output_lin_dim
},
act_options
);
torch
::
Tensor
softmax
_results
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
act_options
);
torch
::
Tensor
bmm1
_results
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
act_options
);
torch
::
Tensor
dropout_results
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
act_options
);
torch
::
Tensor
dropout_results
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
act_options
);
torch
::
Tensor
dropout_mask
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
mask_options
);
torch
::
Tensor
dropout_mask
=
torch
::
empty
({
attn_batches
,
q_seq_len
,
k_seq_len
},
mask_options
);
torch
::
Tensor
matmul2_results
=
torch
::
empty
({
q_seq_len
,
attn_batches
,
head_dim
},
act_options
);
torch
::
Tensor
matmul2_results
=
torch
::
empty
({
q_seq_len
,
attn_batches
,
head_dim
},
act_options
);
...
@@ -75,7 +75,8 @@ std::vector<torch::Tensor> fwd_cuda(
...
@@ -75,7 +75,8 @@ std::vector<torch::Tensor> fwd_cuda(
void
*
v_lin_results_ptr
=
static_cast
<
void
*>
(
static_cast
<
half
*>
(
input_lin_results
.
data_ptr
())
+
2
*
head_dim
);
void
*
v_lin_results_ptr
=
static_cast
<
void
*>
(
static_cast
<
half
*>
(
input_lin_results
.
data_ptr
())
+
2
*
head_dim
);
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
// Softmax Intermediate Result Ptr (used by Matmul1 -> Softmax)
void
*
softmax_results_ptr
=
static_cast
<
void
*>
(
softmax_results
.
data_ptr
());
void
*
bmm1_results_ptr
=
static_cast
<
void
*>
(
bmm1_results
.
data_ptr
());
void
*
dropout_results_ptr
=
static_cast
<
void
*>
(
dropout_results
.
data_ptr
());
char
a_layout_t
{
't'
};
char
a_layout_t
{
't'
};
char
a_layout_n
{
'n'
};
char
a_layout_n
{
'n'
};
...
@@ -119,23 +120,29 @@ std::vector<torch::Tensor> fwd_cuda(
...
@@ -119,23 +120,29 @@ std::vector<torch::Tensor> fwd_cuda(
lead_dim
,
lead_dim
,
batch_stride
,
batch_stride
,
beta_zero
,
beta_zero
,
static_cast
<
half
*>
(
softmax
_results_ptr
),
static_cast
<
half
*>
(
bmm1
_results_ptr
),
k_seq_len
,
k_seq_len
,
k_seq_len
*
q_seq_len
,
k_seq_len
*
q_seq_len
,
attn_batches
);
attn_batches
);
// Padded Softmax
// Padded Softmax
bool
softmax_success
=
false
;
bool
softmax_success
=
false
;
if
(
pad_mask
==
nullptr
)
{
if
(
is_training
)
{
softmax_success
=
dispatch_softmax
<
half
,
half
,
float
>
(
softmax_success
=
dispatch_additive_masked_softmax_dropout
<
half
,
half
,
float
>
(
reinterpret_cast
<
half
*>
(
softmax_results_ptr
),
reinterpret_cast
<
half
*>
(
dropout_results_ptr
),
reinterpret_cast
<
const
half
*>
(
softmax_results_ptr
),
(
is_training
)
?
reinterpret_cast
<
uint8_t
*>
(
dropout_mask
.
data_ptr
<
uint8_t
>
())
:
nullptr
,
k_seq_len
,
reinterpret_cast
<
const
half
*>
(
bmm1_results_ptr
),
k_seq_len
,
pad_mask
,
attn_batches
*
q_seq_len
);
attn_batches
*
q_seq_len
*
q_seq_len
,
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
,
attn_batches
*
q_seq_len
/
sequences
,
1.0
f
-
dropout_prob
,
stream
);
}
else
{
}
else
{
softmax_success
=
dispatch_additive_masked_softmax
<
half
,
half
,
float
>
(
softmax_success
=
dispatch_additive_masked_softmax
<
half
,
half
,
float
>
(
reinterpret_cast
<
half
*>
(
softmax
_results_ptr
),
reinterpret_cast
<
half
*>
(
dropout
_results_ptr
),
//this is actually softmax results, but making it consistent for the next function
reinterpret_cast
<
const
half
*>
(
softmax
_results_ptr
),
reinterpret_cast
<
const
half
*>
(
bmm1
_results_ptr
),
pad_mask
,
pad_mask
,
k_seq_len
,
k_seq_len
,
k_seq_len
,
k_seq_len
,
...
@@ -143,14 +150,6 @@ std::vector<torch::Tensor> fwd_cuda(
...
@@ -143,14 +150,6 @@ std::vector<torch::Tensor> fwd_cuda(
attn_batches
*
q_seq_len
/
sequences
);
attn_batches
*
q_seq_len
/
sequences
);
}
}
if
(
is_training
)
{
//use at:: function so that C++ version generates the same random mask as python version
auto
dropout_tuple
=
at
::
_fused_dropout
(
softmax_results
,
1.0
f
-
dropout_prob
);
dropout_results
=
std
::
get
<
0
>
(
dropout_tuple
);
dropout_mask
=
std
::
get
<
1
>
(
dropout_tuple
);
}
// Matmul2
// Matmul2
gemm_switch_fp32accum
(
state
,
gemm_switch_fp32accum
(
state
,
a_layout_n
,
a_layout_n
,
...
@@ -162,7 +161,7 @@ std::vector<torch::Tensor> fwd_cuda(
...
@@ -162,7 +161,7 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast
<
const
half
*>
(
v_lin_results_ptr
),
static_cast
<
const
half
*>
(
v_lin_results_ptr
),
lead_dim
,
lead_dim
,
batch_stride
,
batch_stride
,
(
is_training
)
?
static_cast
<
const
half
*>
(
dropout_results
.
data_ptr
())
:
static_cast
<
const
half
*>
(
softmax_results
.
data_ptr
())
,
static_cast
<
const
half
*>
(
dropout_results
.
data_ptr
()),
k_seq_len
,
k_seq_len
,
k_seq_len
*
q_seq_len
,
k_seq_len
*
q_seq_len
,
beta_zero
,
beta_zero
,
...
@@ -199,7 +198,7 @@ std::vector<torch::Tensor> fwd_cuda(
...
@@ -199,7 +198,7 @@ std::vector<torch::Tensor> fwd_cuda(
return
{
return
{
input_lin_results
,
input_lin_results
,
softmax
_results
,
bmm1
_results
,
dropout_results
,
dropout_results
,
dropout_mask
,
dropout_mask
,
matmul2_results
,
matmul2_results
,
...
@@ -212,7 +211,8 @@ std::vector<torch::Tensor> bwd_cuda(
...
@@ -212,7 +211,8 @@ std::vector<torch::Tensor> bwd_cuda(
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
output_grads
,
torch
::
Tensor
const
&
matmul2_results
,
torch
::
Tensor
const
&
matmul2_results
,
torch
::
Tensor
const
&
dropout_results
,
torch
::
Tensor
const
&
dropout_results
,
torch
::
Tensor
const
&
softmax_results
,
torch
::
Tensor
const
&
bmm1_results
,
torch
::
Tensor
const
&
pad_mask
,
torch
::
Tensor
const
&
input_lin_results
,
torch
::
Tensor
const
&
input_lin_results
,
torch
::
Tensor
const
&
inputs
,
torch
::
Tensor
const
&
inputs
,
torch
::
Tensor
const
&
input_weights
,
torch
::
Tensor
const
&
input_weights
,
...
@@ -350,15 +350,18 @@ std::vector<torch::Tensor> bwd_cuda(
...
@@ -350,15 +350,18 @@ std::vector<torch::Tensor> bwd_cuda(
// Apply Dropout Mask and Scale by Dropout Probability
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
// Softmax Grad
dispatch_masked_scale_softmax_backward
<
half
,
half
,
float
,
false
>
(
dispatch_masked_scale_softmax_backward
_recompute
<
half
,
half
,
float
,
false
>
(
static_cast
<
half
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
half
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
half
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
half
*
const
>
(
matmul2_grads
.
data_ptr
()),
reinterpret_cast
<
half
const
*>
(
softmax_results
.
data_ptr
()),
reinterpret_cast
<
half
const
*>
(
bmm1_results
.
data_ptr
()),
reinterpret_cast
<
half
const
*>
(
pad_mask
.
data_ptr
()),
static_cast
<
uint8_t
const
*>
(
dropout_mask
.
data_ptr
()),
static_cast
<
uint8_t
const
*>
(
dropout_mask
.
data_ptr
()),
1.0
/
(
1.0
-
dropout_prob
),
1.0
/
(
1.0
-
dropout_prob
),
k_seq_len
,
k_seq_len
,
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
);
attn_batches
*
q_seq_len
/
sequences
,
attn_batches
*
q_seq_len
,
stream
);
// Matmul1 Dgrad1
// Matmul1 Dgrad1
gemm_switch_fp32accum
(
state
,
gemm_switch_fp32accum
(
state
,
...
...
apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu
View file @
3fe10b55
...
@@ -361,7 +361,7 @@ std::vector<torch::Tensor> bwd_cuda(
...
@@ -361,7 +361,7 @@ std::vector<torch::Tensor> bwd_cuda(
// Apply Dropout Mask and Scale by Dropout Probability
// Apply Dropout Mask and Scale by Dropout Probability
// Softmax Grad
// Softmax Grad
dispatch_masked_scale_softmax_backward
<
half
,
half
,
float
,
false
>
(
dispatch_masked_scale_softmax_backward
_stream
<
half
,
half
,
float
,
false
>
(
static_cast
<
half
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
half
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
half
*>
(
matmul2_grads
.
data_ptr
()),
static_cast
<
half
*>
(
matmul2_grads
.
data_ptr
()),
reinterpret_cast
<
half
const
*>
(
softmax_results
.
data_ptr
()),
reinterpret_cast
<
half
const
*>
(
softmax_results
.
data_ptr
()),
...
@@ -369,7 +369,7 @@ std::vector<torch::Tensor> bwd_cuda(
...
@@ -369,7 +369,7 @@ std::vector<torch::Tensor> bwd_cuda(
1.0
/
(
1.0
-
dropout_prob
),
1.0
/
(
1.0
-
dropout_prob
),
k_seq_len
,
k_seq_len
,
k_seq_len
,
k_seq_len
,
attn_batches
*
q_seq_len
);
attn_batches
*
q_seq_len
,
stream
);
// Matmul1 Dgrad1
// Matmul1 Dgrad1
gemm_switch_fp32accum
(
state
,
gemm_switch_fp32accum
(
state
,
...
...
apex/contrib/csrc/multihead_attn/softmax.h
View file @
3fe10b55
This diff is collapsed.
Click to expand it.
apex/contrib/multihead_attn/fast_self_multihead_attn_func.py
View file @
3fe10b55
...
@@ -11,6 +11,7 @@ class FastSelfAttnFunc(torch.autograd.Function) :
...
@@ -11,6 +11,7 @@ class FastSelfAttnFunc(torch.autograd.Function) :
dropout_prob_t
=
torch
.
tensor
([
dropout_prob
])
dropout_prob_t
=
torch
.
tensor
([
dropout_prob
])
null_tensor
=
torch
.
tensor
([])
null_tensor
=
torch
.
tensor
([])
use_mask
=
(
pad_mask
is
not
None
)
use_mask
=
(
pad_mask
is
not
None
)
mask_additive_t
=
torch
.
tensor
([
mask_additive
])
if
use_biases_t
[
0
]:
if
use_biases_t
[
0
]:
if
not
mask_additive
:
if
not
mask_additive
:
...
@@ -32,9 +33,24 @@ class FastSelfAttnFunc(torch.autograd.Function) :
...
@@ -32,9 +33,24 @@ class FastSelfAttnFunc(torch.autograd.Function) :
output_biases
,
\
output_biases
,
\
pad_mask
if
use_mask
else
null_tensor
,
\
pad_mask
if
use_mask
else
null_tensor
,
\
dropout_prob
)
dropout_prob
)
ctx
.
save_for_backward
(
use_biases_t
,
\
heads_t
,
\
matmul2_results
,
\
dropout_results
,
\
softmax_results
,
\
null_tensor
,
\
null_tensor
,
\
mask_additive_t
,
\
input_lin_results
,
\
inputs
,
\
input_weights
,
\
output_weights
,
\
dropout_mask
,
\
dropout_prob_t
)
else
:
else
:
input_lin_results
,
\
input_lin_results
,
\
softmax
_results
,
\
bmm1
_results
,
\
dropout_results
,
\
dropout_results
,
\
dropout_mask
,
\
dropout_mask
,
\
matmul2_results
,
\
matmul2_results
,
\
...
@@ -51,6 +67,20 @@ class FastSelfAttnFunc(torch.autograd.Function) :
...
@@ -51,6 +67,20 @@ class FastSelfAttnFunc(torch.autograd.Function) :
output_biases
,
\
output_biases
,
\
pad_mask
if
use_mask
else
null_tensor
,
\
pad_mask
if
use_mask
else
null_tensor
,
\
dropout_prob
)
dropout_prob
)
ctx
.
save_for_backward
(
use_biases_t
,
\
heads_t
,
\
matmul2_results
,
\
dropout_results
,
\
null_tensor
,
\
bmm1_results
,
\
pad_mask
,
\
mask_additive_t
,
\
input_lin_results
,
\
inputs
,
\
input_weights
,
\
output_weights
,
\
dropout_mask
,
\
dropout_prob_t
)
else
:
else
:
...
@@ -70,20 +100,20 @@ class FastSelfAttnFunc(torch.autograd.Function) :
...
@@ -70,20 +100,20 @@ class FastSelfAttnFunc(torch.autograd.Function) :
output_weights
,
\
output_weights
,
\
pad_mask
if
use_mask
else
null_tensor
,
\
pad_mask
if
use_mask
else
null_tensor
,
\
dropout_prob
)
dropout_prob
)
ctx
.
save_for_backward
(
use_biases_t
,
\
ctx
.
save_for_backward
(
use_biases_t
,
\
heads_t
,
\
heads_t
,
\
matmul2_results
,
\
matmul2
_results
,
\
dropout
_results
,
\
dropout
_results
,
\
softmax
_results
,
\
softmax_results
,
\
null_tensor
,
\
input_lin_results
,
\
null_tensor
,
\
inputs
,
\
mask_additive_t
,
\
input_
weights
,
\
input_
lin_results
,
\
output_weights
,
\
inputs
,
\
dropout_mask
,
\
input_weights
,
\
dropout_prob_t
)
output_weights
,
\
dropout_mask
,
\
dropout_prob_t
)
return
outputs
.
detach
()
return
outputs
.
detach
()
@
staticmethod
@
staticmethod
...
@@ -93,6 +123,9 @@ class FastSelfAttnFunc(torch.autograd.Function) :
...
@@ -93,6 +123,9 @@ class FastSelfAttnFunc(torch.autograd.Function) :
matmul2_results
,
\
matmul2_results
,
\
dropout_results
,
\
dropout_results
,
\
softmax_results
,
\
softmax_results
,
\
bmm1_results
,
\
pad_mask
,
\
mask_additive_t
,
\
input_lin_results
,
\
input_lin_results
,
\
inputs
,
\
inputs
,
\
input_weights
,
\
input_weights
,
\
...
@@ -101,24 +134,45 @@ class FastSelfAttnFunc(torch.autograd.Function) :
...
@@ -101,24 +134,45 @@ class FastSelfAttnFunc(torch.autograd.Function) :
dropout_prob_t
=
ctx
.
saved_tensors
dropout_prob_t
=
ctx
.
saved_tensors
if
use_biases_t
[
0
]:
if
use_biases_t
[
0
]:
input_grads
,
\
if
not
mask_additive_t
[
0
]:
input_weight_grads
,
\
input_grads
,
\
output_weight_grads
,
\
input_weight_grads
,
\
input_bias_grads
,
\
output_weight_grads
,
\
output_bias_grads
=
\
input_bias_grads
,
\
fast_self_multihead_attn_bias
.
backward
(
\
output_bias_grads
=
\
heads_t
[
0
],
\
fast_self_multihead_attn_bias
.
backward
(
\
output_grads
,
\
heads_t
[
0
],
\
matmul2_results
,
\
output_grads
,
\
dropout_results
,
\
matmul2_results
,
\
softmax_results
,
\
dropout_results
,
\
input_lin_results
,
\
softmax_results
,
\
inputs
,
\
input_lin_results
,
\
input_weights
,
\
inputs
,
\
output_weights
,
\
input_weights
,
\
dropout_mask
,
\
output_weights
,
\
dropout_prob_t
[
0
])
dropout_mask
,
\
dropout_prob_t
[
0
])
else
:
input_grads
,
\
input_weight_grads
,
\
output_weight_grads
,
\
input_bias_grads
,
\
output_bias_grads
=
\
fast_self_multihead_attn_bias_additive_mask
.
backward
(
\
heads_t
[
0
],
\
output_grads
,
\
matmul2_results
,
\
dropout_results
,
\
bmm1_results
,
\
pad_mask
,
\
input_lin_results
,
\
inputs
,
\
input_weights
,
\
output_weights
,
\
dropout_mask
,
\
dropout_prob_t
[
0
])
else
:
else
:
input_bias_grads
=
None
input_bias_grads
=
None
output_bias_grads
=
None
output_bias_grads
=
None
...
...
apex/contrib/test/multihead_attn/test_fast_self_multihead_attn_bias.py
0 → 100644
View file @
3fe10b55
import
torch
import
unittest
from
apex.contrib.multihead_attn
import
SelfMultiheadAttn
class
SelfMultiheadAttnTest
(
unittest
.
TestCase
):
def
setUp
(
self
,
seed
=
1234
):
torch
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed_all
(
seed
)
self
.
seq_length
=
80
self
.
sequences
=
10
self
.
hidden_dim
=
1024
self
.
heads
=
16
self
.
dropout_prob
=
0.0
self
.
ref_layer
=
SelfMultiheadAttn
(
self
.
hidden_dim
,
self
.
heads
,
dropout
=
self
.
dropout_prob
,
bias
=
True
,
include_norm_add
=
False
,
separate_qkv_params
=
True
,
mask_additive
=
True
,
impl
=
'default'
)
self
.
ref_layer
.
cuda
().
half
()
self
.
ref_layer
.
reset_parameters
()
self
.
ref_inputs
=
torch
.
randn
(
self
.
seq_length
,
self
.
sequences
,
self
.
hidden_dim
,
dtype
=
torch
.
float16
,
device
=
torch
.
device
(
"cuda"
)).
requires_grad_
(
True
)
# Reset seed so parameters are identical
torch
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed_all
(
seed
)
self
.
tst_layer
=
SelfMultiheadAttn
(
self
.
hidden_dim
,
self
.
heads
,
dropout
=
self
.
dropout_prob
,
bias
=
True
,
include_norm_add
=
False
,
separate_qkv_params
=
True
,
mask_additive
=
True
,
impl
=
'fast'
)
self
.
tst_layer
.
cuda
().
half
()
self
.
tst_layer
.
reset_parameters
()
self
.
tst_inputs
=
torch
.
randn
(
self
.
seq_length
,
self
.
sequences
,
self
.
hidden_dim
,
dtype
=
torch
.
float16
,
device
=
torch
.
device
(
"cuda"
)).
requires_grad_
(
True
)
def
test_self_multihead_attn_additive_mask
(
self
)
:
grads
=
torch
.
randn_like
(
self
.
tst_inputs
)
mask
=
((
torch
.
randn
(
self
.
sequences
,
self
.
seq_length
)
>
0
)
*
-
10000.0
).
half
().
cuda
()
ref_outputs
,
_
=
self
.
ref_layer
.
forward
(
self
.
ref_inputs
,
self
.
ref_inputs
,
self
.
ref_inputs
,
key_padding_mask
=
mask
,
need_weights
=
False
,
attn_mask
=
None
,
is_training
=
True
)
tst_outputs
,
_
=
self
.
tst_layer
.
forward
(
self
.
tst_inputs
,
self
.
tst_inputs
,
self
.
tst_inputs
,
key_padding_mask
=
mask
,
need_weights
=
False
,
attn_mask
=
None
,
is_training
=
True
)
self
.
ref_inputs
.
backward
(
grads
)
self
.
tst_inputs
.
backward
(
grads
)
self
.
assertTrue
(
torch
.
allclose
(
self
.
ref_inputs
,
self
.
tst_inputs
,
atol
=
1e-5
,
rtol
=
1e-5
))
self
.
assertTrue
(
torch
.
allclose
(
ref_outputs
,
tst_outputs
,
atol
=
1e-3
,
rtol
=
1e-3
))
self
.
assertTrue
(
torch
.
allclose
(
self
.
ref_inputs
.
grad
,
self
.
tst_inputs
.
grad
,
atol
=
1e-3
,
rtol
=
1e-3
))
if
__name__
==
'__main__'
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment