OpenDAS / Fairseq / Commits

Commit 535ca991: Merge internal changes
Authored Sep 24, 2018 by Myle Ott
Parent: 28069cf4

Showing 5 changed files with 69 additions and 31 deletions:

  fairseq/distributed_utils.py                         +21  -18
  fairseq/models/transformer.py                        +23   -1
  fairseq/modules/multihead_attention.py               +16   -5
  fairseq/modules/sinusoidal_positional_embedding.py    +8   -6
  fairseq/tokenizer.py                                  +1   -1
fairseq/distributed_utils.py

@@ -9,8 +9,7 @@ from collections import namedtuple
 import pickle
 
 import torch
-from torch import distributed, nn
-from torch.distributed import group
+from torch import nn
 
 from fairseq import utils

@@ -33,6 +32,16 @@ else:
     c10d_status = C10dStatus(has_c10d=False, is_default=False)
 
+
+if c10d_status.is_default:
+    import torch.distributed as dist_c10d
+    import torch.distributed.deprecated as dist_no_c10d
+elif c10d_status.has_c10d:
+    import torch.distributed.c10d as dist_c10d
+    import torch.distributed as dist_no_c10d
+else:
+    import torch.distributed as dist_no_c10d
+
 
 def distributed_init(args):
     if args.distributed_world_size == 1:
         raise ValueError('Cannot initialize distributed with distributed_world_size=1')

@@ -44,15 +53,9 @@ def distributed_init(args):
         args.distributed_rank, args.distributed_init_method), flush=True)
 
     if _use_c10d[0]:
-        if c10d_status.is_default:
-            init_fn = distributed.init_process_group
-        else:
-            init_fn = distributed.c10d.init_process_group
+        init_fn = dist_c10d.init_process_group
     else:
-        if c10d_status.is_default:
-            init_fn = distributed.deprecated.init_process_group
-        else:
-            init_fn = distributed.init_process_group
+        init_fn = dist_no_c10d.init_process_group
 
     init_fn(
         backend=args.distributed_backend,

@@ -83,32 +86,32 @@ def suppress_output():
 
 def get_rank():
     if _use_c10d[0]:
-        return distributed.c10d.get_rank()
+        return dist_c10d.get_rank()
     else:
-        return distributed.get_rank()
+        return dist_no_c10d.get_rank()
 
 
 def get_world_size():
     if _use_c10d[0]:
-        return distributed.c10d.get_world_size()
+        return dist_c10d.get_world_size()
     else:
-        return distributed.get_world_size()
+        return dist_no_c10d.get_world_size()
 
 
 def get_default_group():
     if _use_c10d[0]:
-        return distributed.c10d.group.WORLD
+        return dist_c10d.group.WORLD
     else:
-        return distributed.group.WORLD
+        return dist_no_c10d.group.WORLD
 
 
 def all_reduce(tensor, group=None):
     if group is None:
         group = get_default_group()
     if _use_c10d[0]:
-        return dist_c10d.all_reduce(tensor, group=group)
+        return dist_c10d.all_reduce(tensor, group=group)
     else:
-        return distributed.all_reduce(tensor, group=group)
+        return dist_no_c10d.all_reduce(tensor, group=group)
 
 
 def all_gather_list(data, group=None, max_size=16384):
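Note on the change above: rather than dispatching through the `distributed.c10d.*` / `distributed.deprecated.*` attribute paths inside every helper, the module now binds `dist_c10d` / `dist_no_c10d` once at import time and each helper simply picks one of the two aliases. A minimal usage sketch, assuming `distributed_init(args)` has already been called in every worker; the helper names are the ones defined in this file, the rest is illustrative:

import torch
from fairseq import distributed_utils

def mean_across_workers(value):
    # Wrap a scalar in a tensor, sum it over the default process group with the
    # backend-agnostic helper, then average over the number of workers.
    t = torch.tensor([float(value)])
    distributed_utils.all_reduce(t)          # group defaults to get_default_group()
    return t.item() / distributed_utils.get_world_size()

Whether this routes through the c10d API or the deprecated torch.distributed API is decided once by the import block above, not re-checked on every call.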
fairseq/models/transformer.py

@@ -627,7 +627,13 @@ class TransformerDecoderLayer(nn.Module):
         self.final_layer_norm = LayerNorm(self.embed_dim)
         self.need_attn = True
 
-    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state, self_attn_mask=None,
+        self.onnx_trace = False
+
+    def prepare_for_onnx_export_(self):
+        self.onnx_trace = True
+
+    def forward(self, x, encoder_out, encoder_padding_mask, incremental_state,
+                prev_self_attn_state=None, prev_attn_state=None, self_attn_mask=None,
                 self_attn_padding_mask=None):
         """
         Args:

@@ -640,6 +646,12 @@ class TransformerDecoderLayer(nn.Module):
         """
         residual = x
         x = self.maybe_layer_norm(self.self_attn_layer_norm, x, before=True)
+        if prev_self_attn_state is not None:
+            if incremental_state is None:
+                incremental_state = {}
+            prev_key, prev_value = prev_self_attn_state
+            saved_state = {"prev_key": prev_key, "prev_value": prev_value}
+            self.self_attn._set_input_buffer(incremental_state, saved_state)
         x, _ = self.self_attn(
             query=x,
             key=x,

@@ -657,6 +669,12 @@ class TransformerDecoderLayer(nn.Module):
         if self.encoder_attn is not None:
             residual = x
             x = self.maybe_layer_norm(self.encoder_attn_layer_norm, x, before=True)
+            if prev_attn_state is not None:
+                if incremental_state is None:
+                    incremental_state = {}
+                prev_key, prev_value = prev_attn_state
+                saved_state = {"prev_key": prev_key, "prev_value": prev_value}
+                self.encoder_attn._set_input_buffer(incremental_state, saved_state)
             x, attn = self.encoder_attn(
                 query=x,
                 key=encoder_out,

@@ -678,6 +696,10 @@ class TransformerDecoderLayer(nn.Module):
         x = F.dropout(x, p=self.dropout, training=self.training)
         x = residual + x
         x = self.maybe_layer_norm(self.final_layer_norm, x, after=True)
+        if self.onnx_trace:
+            saved_state = self.self_attn._get_input_buffer(incremental_state)
+            self_attn_state = saved_state["prev_key"], saved_state["prev_value"]
+            return x, attn, self_attn_state
         return x, attn
 
     def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
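The decoder-layer change lets a caller carry the self-attention key/value cache explicitly (prev_self_attn_state in, self_attn_state out) instead of leaving it hidden inside incremental_state, which is what an ONNX-exported single-step decoder needs. A rough step-by-step decoding sketch, under stated assumptions: `layer` is a TransformerDecoderLayer with prepare_for_onnx_export_() already called (so forward returns the extra state), `x_step`, `encoder_out`, and `encoder_padding_mask` are tensors of the usual fairseq shapes, and `max_decode_steps` / `next_token_embedding` are placeholders, not fairseq APIs:

self_attn_state = None   # (prev_key, prev_value) pair carried across decoding steps
for step in range(max_decode_steps):
    out, attn, self_attn_state = layer(
        x_step,                  # embedding of the latest target token, (1, bsz, embed_dim)
        encoder_out,
        encoder_padding_mask,
        incremental_state={},    # per-step scratch dict; the explicit state is what persists
        prev_self_attn_state=self_attn_state,
    )
    x_step = next_token_embedding(out)   # placeholder for output projection + embedding lookup

On the first iteration prev_self_attn_state is None, so the layer starts a fresh cache; every later iteration re-injects the returned (prev_key, prev_value) pair via _set_input_buffer before running self-attention.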
fairseq/modules/multihead_attention.py

@@ -45,6 +45,11 @@ class MultiheadAttention(nn.Module):
         self.reset_parameters()
 
+        self.onnx_trace = False
+
+    def prepare_for_onnx_export_(self):
+        self.onnx_trace = True
+
     def reset_parameters(self):
         nn.init.xavier_uniform_(self.in_proj_weight)
         nn.init.xavier_uniform_(self.out_proj.weight)

@@ -94,9 +99,7 @@ class MultiheadAttention(nn.Module):
             q = self.in_proj_q(query)
             if key is None:
                 assert value is None
-                # this will allow us to concat it with previous value and get
-                # just get the previous value
-                k = v = q.new(0)
+                k = v = None
             else:
                 k, v = self.in_proj_kv(key)
         else:

@@ -106,12 +109,20 @@ class MultiheadAttention(nn.Module):
         q *= self.scaling
 
         if saved_state is not None:
             if 'prev_key' in saved_state:
-                k = torch.cat((saved_state['prev_key'], k), dim=0)
+                if static_kv:
+                    k = saved_state['prev_key']
+                else:
+                    k = torch.cat((saved_state['prev_key'], k), dim=0)
             if 'prev_value' in saved_state:
-                v = torch.cat((saved_state['prev_value'], v), dim=0)
+                if static_kv:
+                    v = saved_state['prev_value']
+                else:
+                    v = torch.cat((saved_state['prev_value'], v), dim=0)
             saved_state['prev_key'] = k
             saved_state['prev_value'] = v
             self._set_input_buffer(incremental_state, saved_state)
 
         if self.bias_k is not None:
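The static_kv branch above matters for encoder-decoder attention during incremental decoding: the source-side keys and values never change from step to step, so the cached tensors are reused as-is instead of being concatenated with a fresh projection; the related k = v = None change drops the old q.new(0) placeholder that only existed to make that concatenation work. A self-contained toy illustration of just that branch, with made-up (seq_len, bsz, embed_dim) shapes and no fairseq classes:

import torch

saved_state = {'prev_key': torch.randn(7, 2, 64)}   # cached keys: 7 source positions
k_new = torch.randn(1, 2, 64)                        # projection for the current step

# Encoder-decoder attention (static_kv=True): the cache is reused unchanged.
k = saved_state['prev_key']
print(k.shape)                                       # torch.Size([7, 2, 64])

# Decoder self-attention (static_kv=False): the new step is appended to the cache.
k = torch.cat((saved_state['prev_key'], k_new), dim=0)
print(k.shape)                                       # torch.Size([8, 2, 64])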
fairseq/modules/sinusoidal_positional_embedding.py

@@ -9,6 +9,7 @@ import math
 
 import torch
 import torch.nn as nn
+import torch.onnx.operators
 
 from fairseq import utils

@@ -55,12 +56,12 @@ class SinusoidalPositionalEmbedding(nn.Module):
             emb[padding_idx, :] = 0
         return emb
 
-    def forward(self, input, incremental_state=None):
+    def forward(self, input, incremental_state=None, timestep=None):
         """Input is expected to be of size [bsz x seqlen]."""
-        # recompute/expand embeddings if needed
-        bsz, seq_len = input.size()
+        bsz, seq_len = torch.onnx.operators.shape_as_tensor(input)
         max_pos = self.padding_idx + 1 + seq_len
         if self.weights is None or max_pos > self.weights.size(0):
+            # recompute/expand embeddings if needed
             self.weights = SinusoidalPositionalEmbedding.get_embedding(
                 max_pos,
                 self.embedding_dim,

@@ -70,12 +71,13 @@ class SinusoidalPositionalEmbedding(nn.Module):
         if incremental_state is not None:
             # positions is the same for every token when decoding a single step
-            return self.weights[self.padding_idx + seq_len, :].expand(bsz, 1, -1)
+            pos = (timestep.int() + 1).long() if timestep is not None else seq_len
+            if self.onnx_trace:
+                return self.weights[self.padding_idx + pos, :].unsqueeze(1).repeat(bsz, 1, 1)
+            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
 
         positions = utils.make_positions(input, self.padding_idx, self.left_pad, self.onnx_trace)
         if self.onnx_trace:
-            bsz = torch.onnx.operators.shape_as_tensor(input)[0]
-            seq_len = torch.onnx.operators.shape_as_tensor(input)[1]
             flat_embeddings = self.weights.detach().index_select(0, positions.view(-1))
             embedding_shape = torch.cat((bsz.view(1), seq_len.view(1), torch.LongTensor([-1])))
             embeddings = torch.onnx.operators.reshape_from_tensor_shape(flat_embeddings, embedding_shape)
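With the new timestep argument, an exported single-step decoder can pass the current position in as a tensor instead of relying on seq_len, and the embedding row is read at padding_idx + timestep + 1 (real positions start just past padding_idx in fairseq's numbering). A small standalone check of that index arithmetic, using a random stand-in for the sinusoidal table rather than the real class:

import torch

padding_idx = 1
weights = torch.randn(64, 16)         # stand-in for the precomputed sinusoidal table

timestep = torch.tensor([4])          # decoding step 4 (0-based), passed in as a tensor
pos = (timestep.int() + 1).long()     # -> 5, the position index used above
emb = weights[padding_idx + pos, :]   # row padding_idx + 5 = 6 of the table
print(pos.item(), emb.shape)          # 5 torch.Size([1, 16])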
fairseq/tokenizer.py

@@ -11,7 +11,7 @@ import os, re
 
 import torch
 
 from multiprocessing import Pool
 
-SPACE_NORMALIZER = re.compile("\s+")
+SPACE_NORMALIZER = re.compile(r"\s+")
 
 
 def tokenize_line(line):
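The one-character change switches the pattern to a raw string. Because "\s" is not a recognized Python string escape, the compiled regex is identical either way; the raw form simply avoids the invalid-escape DeprecationWarning that newer Python versions emit for the old spelling. A quick sketch of the normalization this pattern drives (the tokenize_line body here is a paraphrase for illustration, not the file's full contents):

import re

SPACE_NORMALIZER = re.compile(r"\s+")

def tokenize_line(line):
    # Collapse any run of whitespace to a single space, then split on spaces.
    line = SPACE_NORMALIZER.sub(' ', line)
    return line.strip().split()

print(tokenize_line("  hello \t  world\n"))   # ['hello', 'world']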