OpenDAS / Fairseq · Commit 60c4081b

Authored Apr 06, 2018 by Myle Ott
Parent: 36e360d9

    More improvements to weight init and FP16 support

Showing 2 changed files with 21 additions and 18 deletions (+21 -18):

  fairseq/modules/multihead_attention.py  +16 -18
  fairseq/utils.py                        +5  -0
fairseq/modules/multihead_attention.py

@@ -5,8 +5,6 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
 
-import math
-
 import torch
 from torch import nn
 from torch.nn import Parameter
@@ -30,20 +28,21 @@ class MultiheadAttention(nn.Module):
         self.scaling = self.head_dim**-0.5
         self._mask = None
 
-        self.in_proj_weight = Parameter(torch.Tensor(3*self.embed_dim, self.embed_dim))
+        self.in_proj_weight = Parameter(torch.Tensor(3*embed_dim, embed_dim))
         if bias:
-            self.in_proj_bias = Parameter(torch.Tensor(3*self.embed_dim))
+            self.in_proj_bias = Parameter(torch.Tensor(3*embed_dim))
         else:
             self.register_parameter('in_proj_bias', None)
-        self.out_proj = nn.Linear(self.embed_dim, embed_dim, bias=bias)
+        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
 
         self.reset_parameters()
 
     def reset_parameters(self):
-        nn.init.xavier_uniform(self.in_proj_weight.data)
-        nn.init.xavier_uniform(self.out_proj.weight.data)
+        nn.init.xavier_uniform(self.in_proj_weight)
+        nn.init.xavier_uniform(self.out_proj.weight)
         if self.in_proj_bias is not None:
-            self.in_proj_bias.data.zero_()
+            nn.init.constant(self.in_proj_bias, 0.)
+            nn.init.constant(self.out_proj.bias, 0.)
 
     def forward(self, query, key, value, mask_future_timesteps=False,
                 key_padding_mask=None, incremental_state=None,
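The hunk above sizes the projection parameters with the local embed_dim instead of self.embed_dim, initializes the Parameter objects directly rather than their .data, and switches the bias init to nn.init.constant for both the input and output projections. A minimal standalone sketch of the same pattern (toy code, not fairseq; it uses the underscore-suffixed init names that current PyTorch prefers):

import torch
from torch import nn
from torch.nn import Parameter


class ToyInProj(nn.Module):
    """Toy module mirroring the init pattern from the diff above."""

    def __init__(self, embed_dim, bias=True):
        super().__init__()
        # q, k and v projections stacked into one (3*embed_dim, embed_dim) weight
        self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
        if bias:
            self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
        else:
            self.register_parameter('in_proj_bias', None)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        # initialize the Parameters themselves (the diff drops the .data indirection)
        nn.init.xavier_uniform_(self.in_proj_weight)
        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias is not None:
            nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)


m = ToyInProj(embed_dim=8)
print(m.in_proj_bias.sum().item())  # 0.0, biases start at zero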
@@ -125,10 +124,10 @@ class MultiheadAttention(nn.Module):
         if key_padding_mask is not None:
             # don't attend to padding symbols
             attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
-            attn_weights = attn_weights.masked_fill(
+            attn_weights = attn_weights.float().masked_fill(
                 key_padding_mask.unsqueeze(1).unsqueeze(2),
-                -math.inf,
-            )
+                float('-inf'),
+            ).type_as(attn_weights)  # FP16 support: cast to float and back
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
         attn_weights = F.softmax(attn_weights, dim=-1)
         attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
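Per the in-line comment, the padding mask is now applied in float32 and the result cast back to the original dtype, so the same code path works when the attention weights are half-precision. A self-contained sketch of that pattern with toy shapes (not fairseq code):

import torch
import torch.nn.functional as F

bsz, num_heads, tgt_len, src_len = 2, 4, 5, 5
attn_weights = torch.randn(bsz * num_heads, tgt_len, src_len).half()
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
key_padding_mask[:, -1] = True  # pretend the last source position is padding

w = attn_weights.view(bsz, num_heads, tgt_len, src_len)
w = w.float().masked_fill(                       # cast to float ...
    key_padding_mask.unsqueeze(1).unsqueeze(2),  # (bsz, 1, 1, src_len), broadcasts
    float('-inf'),
).type_as(attn_weights)                          # ... and back, as in the diff
w = w.view(bsz * num_heads, tgt_len, src_len)

# float softmax here only so the toy runs on CPU; padded columns get zero weight
probs = F.softmax(w.float(), dim=-1)
print(probs[0, :, -1])  # all zeros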
@@ -178,14 +177,13 @@ class MultiheadAttention(nn.Module):
     def buffered_mask(self, tensor):
         dim = tensor.size(-1)
         if self._mask is None:
-            self._mask = torch.triu(tensor.new(dim, dim).fill_(-math.inf), 1)
+            self._mask = torch.triu(utils.fill_with_neg_inf(tensor.new(dim, dim)), 1)
         if self._mask.size(0) < dim:
-            self._mask = torch.triu(self._mask.resize_(dim, dim).fill_(-math.inf), 1)
+            self._mask = torch.triu(utils.fill_with_neg_inf(self._mask.resize_(dim, dim)), 1)
         return self._mask[:dim, :dim]
 
     def reorder_incremental_state(self, incremental_state, new_order):
         """Reorder buffered internal state (for incremental generation)."""
         input_buffer = self._get_input_buffer(incremental_state)
         if input_buffer is not None:
             for k in input_buffer.keys():
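For reference, the future-timestep mask that buffered_mask caches is just an upper-triangular matrix of -inf above the main diagonal; adding it to the attention scores zeroes out future positions after the softmax. A toy illustration (not fairseq code; the diff builds the same matrix via utils.fill_with_neg_inf so it also works for half-precision tensors):

import torch

dim = 4
# -inf strictly above the diagonal, zeros on and below it
future_mask = torch.triu(torch.full((dim, dim), float('-inf')), 1)

scores = torch.randn(dim, dim)                   # toy attention scores
probs = torch.softmax(scores + future_mask, dim=-1)
print(probs)  # row i puts zero weight on positions j > i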
@@ -194,10 +192,10 @@ class MultiheadAttention(nn.Module):
     def _get_input_buffer(self, incremental_state):
         return utils.get_incremental_state(
-            self,
-            incremental_state,
-            'attn_state',
-        ) or {}
+            self,
+            incremental_state,
+            'attn_state',
+        ) or {}
 
     def _set_input_buffer(self, incremental_state, buffer):
         utils.set_incremental_state(
fairseq/utils.py

@@ -375,3 +375,8 @@ def item(tensor):
     if hasattr(tensor, '__getitem__'):
         return tensor[0]
     return tensor
+
+
+def fill_with_neg_inf(t):
+    """FP16-compatible function that fills a tensor with -inf."""
+    return t.float().fill_(float('-inf')).type_as(t)
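A quick usage check for the new helper (a sketch that assumes this version of fairseq is importable; otherwise the three-line function above can be pasted in directly):

import torch
from fairseq import utils

x = torch.empty(2, 3).half()            # half-precision input
y = utils.fill_with_neg_inf(x)
print(y.dtype)                          # torch.float16, dtype is preserved
print((y == float('-inf')).all())       # every entry is -inf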