OpenDAS / Fairseq

Commit 60c4081b
Authored Apr 06, 2018 by Myle Ott
Parent: 36e360d9

More improvements to weight init and FP16 support
Showing 2 changed files with 21 additions and 18 deletions:

    fairseq/modules/multihead_attention.py   +16 -18
    fairseq/utils.py                          +5  -0
fairseq/modules/multihead_attention.py
@@ -5,8 +5,6 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
 
-import math
-
 import torch
 from torch import nn
 from torch.nn import Parameter
@@ -30,20 +28,21 @@ class MultiheadAttention(nn.Module):
         self.scaling = self.head_dim**-0.5
         self._mask = None
 
-        self.in_proj_weight = Parameter(torch.Tensor(3*self.embed_dim, self.embed_dim))
+        self.in_proj_weight = Parameter(torch.Tensor(3*embed_dim, embed_dim))
         if bias:
-            self.in_proj_bias = Parameter(torch.Tensor(3*self.embed_dim))
+            self.in_proj_bias = Parameter(torch.Tensor(3*embed_dim))
         else:
             self.register_parameter('in_proj_bias', None)
-        self.out_proj = nn.Linear(self.embed_dim, embed_dim, bias=bias)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
 
         self.reset_parameters()
 
     def reset_parameters(self):
-        nn.init.xavier_uniform(self.in_proj_weight.data)
-        nn.init.xavier_uniform(self.out_proj.weight.data)
+        nn.init.xavier_uniform(self.in_proj_weight)
+        nn.init.xavier_uniform(self.out_proj.weight)
         if self.in_proj_bias is not None:
-            self.in_proj_bias.data.zero_()
+            nn.init.constant(self.in_proj_bias, 0.)
+            nn.init.constant(self.out_proj.bias, 0.)
 
     def forward(self, query, key, value, mask_future_timesteps=False,
                 key_padding_mask=None, incremental_state=None,
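Note on the weight-init half of the commit: the projection parameters are now passed to the init functions directly rather than through their .data attribute, and the biases are zeroed with nn.init.constant instead of an in-place .zero_(). Below is a minimal, self-contained sketch of the same pattern, not part of the commit. It is written against current PyTorch, whose in-place initializers carry a trailing underscore (nn.init.xavier_uniform_ / nn.init.constant_), whereas the commit uses the older non-underscore names; embed_dim here is an arbitrary example value.

import torch
from torch import nn
from torch.nn import Parameter

embed_dim, bias = 8, True  # example sizes only

# Packed q/k/v input projection, as in MultiheadAttention.__init__
in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
in_proj_bias = Parameter(torch.Tensor(3 * embed_dim)) if bias else None
out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

# Initialize the Parameters themselves (not their .data), and zero the biases
# explicitly via nn.init, mirroring the new reset_parameters().
nn.init.xavier_uniform_(in_proj_weight)
nn.init.xavier_uniform_(out_proj.weight)
if in_proj_bias is not None:
    nn.init.constant_(in_proj_bias, 0.)
    nn.init.constant_(out_proj.bias, 0.)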
@@ -125,10 +124,10 @@ class MultiheadAttention(nn.Module):
         if key_padding_mask is not None:
             # don't attend to padding symbols
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights.masked_fill(
+            attn_weights = attn_weights.float().masked_fill(
                 key_padding_mask.unsqueeze(1).unsqueeze(2),
-                -math.inf,
-            )
+                float('-inf'),
+            ).type_as(attn_weights)  # FP16 support: cast to float and back
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
         attn_weights = F.softmax(attn_weights, dim=-1)
         attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
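Note on the FP16 half of the commit: the padding mask is applied in float32 and the result is cast back to the original dtype with type_as before the softmax. A small hypothetical sketch of that round-trip follows; the shapes, values, and variable names are made up for illustration, and the final softmax is computed in float32 here only so the sketch also runs on CPU builds without half-precision softmax (the module itself calls F.softmax on attn_weights directly).

import torch
import torch.nn.functional as F

bsz, num_heads, tgt_len, src_len = 2, 4, 3, 3
attn_weights = torch.randn(bsz * num_heads, tgt_len, src_len).half()
# True marks padding positions that must receive no attention
key_padding_mask = torch.tensor([[False, False, True],
                                 [False, True,  True]])

attn_weights = attn_weights.view(bsz, num_heads, tgt_len, src_len)
# Mask in float32, then cast back: the -inf fill is never done in fp16
attn_weights = attn_weights.float().masked_fill(
    key_padding_mask.unsqueeze(1).unsqueeze(2),
    float('-inf'),
).type_as(attn_weights)
attn_weights = attn_weights.view(bsz * num_heads, tgt_len, src_len)

attn_probs = F.softmax(attn_weights.float(), dim=-1).type_as(attn_weights)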
@@ -178,14 +177,13 @@ class MultiheadAttention(nn.Module):
     def buffered_mask(self, tensor):
         dim = tensor.size(-1)
         if self._mask is None:
-            self._mask = torch.triu(tensor.new(dim, dim).fill_(-math.inf), 1)
+            self._mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
         if self._mask.size(0) < dim:
-            self._mask = torch.triu(
-                self._mask.resize_(dim, dim).fill_(-math.inf), 1)
+            self._mask = torch.triu(utils.fill_with_neg_inf(self._mask.resize_(dim, dim)), 1)
         return self._mask[:dim, :dim]
 
     def reorder_incremental_state(self, incremental_state, new_order):
         """Reorder buffered internal state (for incremental generation)."""
         input_buffer = self._get_input_buffer(incremental_state)
         if input_buffer is not None:
             for k in input_buffer.keys():
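Note: buffered_mask now builds its future-timestep mask through the new utils.fill_with_neg_inf helper instead of fill_(-math.inf). A standalone sketch of what that produces, with the helper inlined (it is the one added to fairseq/utils.py below) and the tensor size chosen only for illustration:

import torch

def fill_with_neg_inf(t):
    # same helper as the one added to fairseq/utils.py in this commit
    return t.float().fill_(float('-inf')).type_as(t)

scores = torch.randn(5, 5).half()   # e.g. fp16 attention scores
dim = scores.size(-1)

# Strictly upper-triangular -inf mask: position i cannot attend to j > i.
mask = torch.triu(fill_with_neg_inf(scores.new(dim, dim)), 1)

# The module caches this mask and reuses a [:dim, :dim] slice for shorter inputs.
print(mask[:dim, :dim])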
@@ -194,10 +192,10 @@ class MultiheadAttention(nn.Module):
 
     def _get_input_buffer(self, incremental_state):
         return utils.get_incremental_state(
-                self,
-                incremental_state,
-                'attn_state',
-            ) or {}
+            self,
+            incremental_state,
+            'attn_state',
+        ) or {}
 
     def _set_input_buffer(self, incremental_state, buffer):
         utils.set_incremental_state(
fairseq/utils.py
@@ -375,3 +375,8 @@ def item(tensor):
     if hasattr(tensor, '__getitem__'):
         return tensor[0]
     return tensor
+
+
+def fill_with_neg_inf(t):
+    """FP16-compatible function that fills a tensor with -inf."""
+    return t.float().fill_(float('-inf')).type_as(t)
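A quick, hypothetical check of the new helper on a half-precision tensor (not part of the commit): the fill happens in float32 and the result is cast back with type_as, so the caller keeps its original dtype and device.

import torch

def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf."""
    return t.float().fill_(float('-inf')).type_as(t)

t = torch.empty(2, 2, dtype=torch.half)
out = fill_with_neg_inf(t)
print(out.dtype)         # torch.float16
print(out[0, 0].item())  # -inf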