Project: dcuai / dlexamples

Commit c0f05c10, authored Nov 29, 2022 by hepj
Commit message: Update the Transformer code
Parent: c056df78

Changes: 321 files. Only the first 20 changed files are shown below, with 4140 additions and 0 deletions (+4140 / -0).
PyTorch/NLP/new-Transformer/fairseq/iterative_refinement_generator.py  (+359 / -0)
PyTorch/NLP/new-Transformer/fairseq/logging/__init__.py  (+0 / -0)
PyTorch/NLP/new-Transformer/fairseq/logging/meters.py  (+321 / -0)
PyTorch/NLP/new-Transformer/fairseq/logging/metrics.py  (+316 / -0)
PyTorch/NLP/new-Transformer/fairseq/logging/progress_bar.py  (+582 / -0)
PyTorch/NLP/new-Transformer/fairseq/model_parallel/__init__.py  (+6 / -0)
PyTorch/NLP/new-Transformer/fairseq/model_parallel/criterions/__init__.py  (+14 / -0)
PyTorch/NLP/new-Transformer/fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py  (+87 / -0)
PyTorch/NLP/new-Transformer/fairseq/model_parallel/megatron_trainer.py  (+75 / -0)
PyTorch/NLP/new-Transformer/fairseq/model_parallel/models/__init__.py  (+20 / -0)
PyTorch/NLP/new-Transformer/fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py  (+6 / -0)
PyTorch/NLP/new-Transformer/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py  (+600 / -0)
PyTorch/NLP/new-Transformer/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py  (+789 / -0)
PyTorch/NLP/new-Transformer/fairseq/model_parallel/models/roberta/__init__.py  (+6 / -0)
PyTorch/NLP/new-Transformer/fairseq/model_parallel/models/roberta/model.py  (+225 / -0)
PyTorch/NLP/new-Transformer/fairseq/model_parallel/models/transformer.py  (+121 / -0)
PyTorch/NLP/new-Transformer/fairseq/model_parallel/models/transformer_lm.py  (+169 / -0)
PyTorch/NLP/new-Transformer/fairseq/model_parallel/modules/__init__.py  (+17 / -0)
PyTorch/NLP/new-Transformer/fairseq/model_parallel/modules/multihead_attention.py  (+349 / -0)
PyTorch/NLP/new-Transformer/fairseq/model_parallel/modules/transformer_layer.py  (+78 / -0)
PyTorch/NLP/new-Transformer/fairseq/iterative_refinement_generator.py  (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections import namedtuple

import numpy as np
import torch
from fairseq import utils


DecoderOut = namedtuple(
    "IterativeRefinementDecoderOut",
    [
        "output_tokens",
        "output_scores",
        "attn",
        "step",
        "max_step",
        "history",
    ],
)


class IterativeRefinementGenerator(object):
    def __init__(
        self,
        tgt_dict,
        models=None,
        eos_penalty=0.0,
        max_iter=10,
        max_ratio=2,
        beam_size=1,
        decoding_format=None,
        retain_dropout=False,
        adaptive=True,
        retain_history=False,
        reranking=False,
    ):
        """
        Generates translations based on iterative refinement.

        Args:
            tgt_dict: target dictionary
            eos_penalty: if > 0.0, it penalized early-stopping in decoding
            max_iter: maximum number of refinement iterations
            max_ratio: generate sequences of maximum length ax, where x is the source length
            decoding_format: decoding mode in {'unigram', 'ensemble', 'vote', 'dp', 'bs'}
            retain_dropout: retaining dropout in the inference
            adaptive: decoding with early stop
        """
        self.bos = tgt_dict.bos()
        self.pad = tgt_dict.pad()
        self.unk = tgt_dict.unk()
        self.eos = tgt_dict.eos()
        self.vocab_size = len(tgt_dict)
        self.eos_penalty = eos_penalty
        self.max_iter = max_iter
        self.max_ratio = max_ratio
        self.beam_size = beam_size
        self.reranking = reranking
        self.decoding_format = decoding_format
        self.retain_dropout = retain_dropout
        self.retain_history = retain_history
        self.adaptive = adaptive
        self.models = models

    def generate_batched_itr(
        self,
        data_itr,
        maxlen_a=None,
        maxlen_b=None,
        cuda=False,
        timer=None,
        prefix_size=0,
    ):
        """Iterate over a batched dataset and yield individual translations.

        Args:
            maxlen_a/b: generate sequences of maximum length ax + b,
                where x is the source sentence length.
            cuda: use GPU for generation
            timer: StopwatchMeter for timing generations.
        """
        for sample in data_itr:
            if "net_input" not in sample:
                continue
            if timer is not None:
                timer.start()
            with torch.no_grad():
                hypos = self.generate(
                    self.models,
                    sample,
                    prefix_tokens=sample["target"][:, :prefix_size]
                    if prefix_size > 0
                    else None,
                )
            if timer is not None:
                timer.stop(sample["ntokens"])
            for i, id in enumerate(sample["id"]):
                # remove padding
                src = utils.strip_pad(sample["net_input"]["src_tokens"][i, :], self.pad)
                ref = utils.strip_pad(sample["target"][i, :], self.pad)
                yield id, src, ref, hypos[i]

    @torch.no_grad()
    def generate(self, models, sample, prefix_tokens=None, constraints=None):
        if constraints is not None:
            raise NotImplementedError(
                "Constrained decoding with the IterativeRefinementGenerator is not supported"
            )

        # TODO: iterative refinement generator does not support ensemble for now.
        if not self.retain_dropout:
            for model in models:
                model.eval()

        model, reranker = models[0], None
        if self.reranking:
            assert len(models) > 1, "Assuming the last checkpoint is the reranker"
            assert (
                self.beam_size > 1
            ), "Reranking requires multiple translation for each example"

            reranker = models[-1]
            models = models[:-1]

        if len(models) > 1 and hasattr(model, "enable_ensemble"):
            assert model.allow_ensemble, "{} does not support ensembling".format(
                model.__class__.__name__
            )
            model.enable_ensemble(models)

        # TODO: better encoder inputs?
        src_tokens = sample["net_input"]["src_tokens"]
        src_lengths = sample["net_input"]["src_lengths"]
        bsz, src_len = src_tokens.size()

        # initialize
        encoder_out = model.forward_encoder([src_tokens, src_lengths])
        prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)

        if self.beam_size > 1:
            assert (
                model.allow_length_beam
            ), "{} does not support decoding with length beam.".format(
                model.__class__.__name__
            )

            # regenerate data based on length-beam
            length_beam_order = (
                utils.new_arange(src_tokens, self.beam_size, bsz).t().reshape(-1)
            )
            encoder_out = model.encoder.reorder_encoder_out(
                encoder_out, length_beam_order
            )
            prev_decoder_out = model.regenerate_length_beam(
                prev_decoder_out, self.beam_size
            )
            bsz = bsz * self.beam_size

        sent_idxs = torch.arange(bsz)
        prev_output_tokens = prev_decoder_out.output_tokens.clone()

        if self.retain_history:
            prev_decoder_out = prev_decoder_out._replace(history=[prev_output_tokens])

        finalized = [[] for _ in range(bsz)]

        def is_a_loop(x, y, s, a):
            b, l_x, l_y = x.size(0), x.size(1), y.size(1)
            if l_x > l_y:
                y = torch.cat([y, x.new_zeros(b, l_x - l_y).fill_(self.pad)], 1)
                s = torch.cat([s, s.new_zeros(b, l_x - l_y)], 1)
                if a is not None:
                    a = torch.cat([a, a.new_zeros(b, l_x - l_y, a.size(2))], 1)
            elif l_x < l_y:
                x = torch.cat([x, y.new_zeros(b, l_y - l_x).fill_(self.pad)], 1)
            return (x == y).all(1), y, s, a

        def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
            cutoff = prev_out_token.ne(self.pad)
            tokens = prev_out_token[cutoff]
            if prev_out_score is None:
                scores, score = None, None
            else:
                scores = prev_out_score[cutoff]
                score = scores.mean()

            if prev_out_attn is None:
                hypo_attn, alignment = None, None
            else:
                hypo_attn = prev_out_attn[cutoff]
                alignment = hypo_attn.max(dim=1)[1]
            return {
                "steps": step,
                "tokens": tokens,
                "positional_scores": scores,
                "score": score,
                "hypo_attn": hypo_attn,
                "alignment": alignment,
            }

        for step in range(self.max_iter + 1):

            decoder_options = {
                "eos_penalty": self.eos_penalty,
                "max_ratio": self.max_ratio,
                "decoding_format": self.decoding_format,
            }
            prev_decoder_out = prev_decoder_out._replace(
                step=step,
                max_step=self.max_iter + 1,
            )

            decoder_out = model.forward_decoder(
                prev_decoder_out, encoder_out, **decoder_options
            )

            if self.adaptive:
                # terminate if there is a loop
                terminated, out_tokens, out_scores, out_attn = is_a_loop(
                    prev_output_tokens,
                    decoder_out.output_tokens,
                    decoder_out.output_scores,
                    decoder_out.attn,
                )
                decoder_out = decoder_out._replace(
                    output_tokens=out_tokens,
                    output_scores=out_scores,
                    attn=out_attn,
                )
            else:
                terminated = decoder_out.output_tokens.new_zeros(
                    decoder_out.output_tokens.size(0)
                ).bool()

            if step == self.max_iter:  # reach last iteration, terminate
                terminated.fill_(1)

            # collect finalized sentences
            finalized_idxs = sent_idxs[terminated]
            finalized_tokens = decoder_out.output_tokens[terminated]
            finalized_scores = decoder_out.output_scores[terminated]
            finalized_attn = (
                None
                if (decoder_out.attn is None or decoder_out.attn.size(0) == 0)
                else decoder_out.attn[terminated]
            )

            if self.retain_history:
                finalized_history_tokens = [h[terminated] for h in decoder_out.history]

            for i in range(finalized_idxs.size(0)):
                finalized[finalized_idxs[i]] = [
                    finalized_hypos(
                        step,
                        finalized_tokens[i],
                        finalized_scores[i],
                        None if finalized_attn is None else finalized_attn[i],
                    )
                ]

                if self.retain_history:
                    finalized[finalized_idxs[i]][0]["history"] = []
                    for j in range(len(finalized_history_tokens)):
                        finalized[finalized_idxs[i]][0]["history"].append(
                            finalized_hypos(
                                step, finalized_history_tokens[j][i], None, None
                            )
                        )

            # check if all terminated
            if terminated.sum() == terminated.size(0):
                break

            # for next step
            not_terminated = ~terminated
            prev_decoder_out = decoder_out._replace(
                output_tokens=decoder_out.output_tokens[not_terminated],
                output_scores=decoder_out.output_scores[not_terminated],
                attn=decoder_out.attn[not_terminated]
                if (decoder_out.attn is not None and decoder_out.attn.size(0) > 0)
                else None,
                history=[h[not_terminated] for h in decoder_out.history]
                if decoder_out.history is not None
                else None,
            )
            encoder_out = model.encoder.reorder_encoder_out(
                encoder_out, not_terminated.nonzero(as_tuple=False).squeeze()
            )
            sent_idxs = sent_idxs[not_terminated]
            prev_output_tokens = prev_decoder_out.output_tokens.clone()

        if self.beam_size > 1:
            if reranker is not None:
                finalized = self.rerank(
                    reranker, finalized, [src_tokens, src_lengths], self.beam_size
                )

            # aggregate information from length beam
            finalized = [
                finalized[
                    np.argmax(
                        [
                            finalized[self.beam_size * i + j][0]["score"]
                            for j in range(self.beam_size)
                        ]
                    )
                    + self.beam_size * i
                ]
                for i in range(len(finalized) // self.beam_size)
            ]

        return finalized

    def rerank(self, reranker, finalized, encoder_input, beam_size):
        def rebuild_batch(finalized):
            finalized_tokens = [f[0]["tokens"] for f in finalized]
            finalized_maxlen = max(f.size(0) for f in finalized_tokens)
            final_output_tokens = (
                finalized_tokens[0]
                .new_zeros(len(finalized_tokens), finalized_maxlen)
                .fill_(self.pad)
            )
            for i, f in enumerate(finalized_tokens):
                final_output_tokens[i, : f.size(0)] = f
            return final_output_tokens

        final_output_tokens = rebuild_batch(finalized)
        final_output_tokens[:, 0] = self.eos  # autoregressive model assumes starting with EOS

        reranker_encoder_out = reranker.encoder(*encoder_input)
        length_beam_order = (
            utils.new_arange(
                final_output_tokens, beam_size, reranker_encoder_out.encoder_out.size(1)
            )
            .t()
            .reshape(-1)
        )
        reranker_encoder_out = reranker.encoder.reorder_encoder_out(
            reranker_encoder_out, length_beam_order
        )
        reranking_scores = reranker.get_normalized_probs(
            reranker.decoder(final_output_tokens[:, :-1], reranker_encoder_out),
            True,
            None,
        )
        reranking_scores = reranking_scores.gather(2, final_output_tokens[:, 1:, None])
        reranking_masks = final_output_tokens[:, 1:].ne(self.pad)
        reranking_scores = (
            reranking_scores[:, :, 0].masked_fill_(~reranking_masks, 0).sum(1)
        )
        reranking_scores = reranking_scores / reranking_masks.sum(1).type_as(
            reranking_scores
        )

        for i in range(len(finalized)):
            finalized[i][0]["score"] = reranking_scores[i]

        return finalized
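
The nested list comprehension that aggregates the length beam at the end of generate() is dense. The following sketch is an illustration only (not part of the commit) and restates the same selection on plain Python lists: for each source sentence, keep the single hypothesis with the highest "score" among its beam_size candidates, which are stored contiguously in finalized.

def pick_best_per_sentence(finalized, beam_size):
    # finalized[beam_size * i : beam_size * (i + 1)] holds the candidates for sentence i;
    # each candidate is a list whose first element is a hypothesis dict with a "score".
    best = []
    for i in range(len(finalized) // beam_size):
        candidates = finalized[beam_size * i : beam_size * (i + 1)]
        best.append(max(candidates, key=lambda hyps: hyps[0]["score"]))
    return best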
PyTorch/NLP/Transformer/.gitmodules → PyTorch/NLP/new-Transformer/fairseq/logging/__init__.py  (file moved)
PyTorch/NLP/new-Transformer/fairseq/logging/meters.py  (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import bisect
import time
from collections import OrderedDict
from typing import Dict, Optional

try:
    import torch

    def type_as(a, b):
        if torch.is_tensor(a) and torch.is_tensor(b):
            return a.to(b)
        else:
            return a


except ImportError:
    torch = None

    def type_as(a, b):
        return a


try:
    import numpy as np
except ImportError:
    np = None


class Meter(object):
    """Base class for Meters."""

    def __init__(self):
        pass

    def state_dict(self):
        return {}

    def load_state_dict(self, state_dict):
        pass

    def reset(self):
        raise NotImplementedError

    @property
    def smoothed_value(self) -> float:
        """Smoothed value used for logging."""
        raise NotImplementedError


def safe_round(number, ndigits):
    if hasattr(number, "__round__"):
        return round(number, ndigits)
    elif torch is not None and torch.is_tensor(number) and number.numel() == 1:
        return safe_round(number.item(), ndigits)
    elif np is not None and np.ndim(number) == 0 and hasattr(number, "item"):
        return safe_round(number.item(), ndigits)
    else:
        return number


class AverageMeter(Meter):
    """Computes and stores the average and current value"""

    def __init__(self, round: Optional[int] = None):
        self.round = round
        self.reset()

    def reset(self):
        self.val = None  # most recent update
        self.sum = 0  # sum from all updates
        self.count = 0  # total n from all updates

    def update(self, val, n=1):
        if val is not None:
            self.val = val
            if n > 0:
                self.sum = type_as(self.sum, val) + (val * n)
                self.count = type_as(self.count, n) + n

    def state_dict(self):
        return {
            "val": self.val,
            "sum": self.sum,
            "count": self.count,
            "round": self.round,
        }

    def load_state_dict(self, state_dict):
        self.val = state_dict["val"]
        self.sum = state_dict["sum"]
        self.count = state_dict["count"]
        self.round = state_dict.get("round", None)

    @property
    def avg(self):
        return self.sum / self.count if self.count > 0 else self.val

    @property
    def smoothed_value(self) -> float:
        val = self.avg
        if self.round is not None and val is not None:
            val = safe_round(val, self.round)
        return val


class SumMeter(Meter):
    """Computes and stores the sum"""

    def __init__(self, round: Optional[int] = None):
        self.round = round
        self.reset()

    def reset(self):
        self.sum = 0  # sum from all updates

    def update(self, val):
        if val is not None:
            self.sum = type_as(self.sum, val) + val

    def state_dict(self):
        return {
            "sum": self.sum,
            "round": self.round,
        }

    def load_state_dict(self, state_dict):
        self.sum = state_dict["sum"]
        self.round = state_dict.get("round", None)

    @property
    def smoothed_value(self) -> float:
        val = self.sum
        if self.round is not None and val is not None:
            val = safe_round(val, self.round)
        return val


class TimeMeter(Meter):
    """Computes the average occurrence of some event per second"""

    def __init__(
        self,
        init: int = 0,
        n: int = 0,
        round: Optional[int] = None,
    ):
        self.round = round
        self.reset(init, n)

    def reset(self, init=0, n=0):
        self.init = init
        self.start = time.perf_counter()
        self.n = n
        self.i = 0

    def update(self, val=1):
        self.n = type_as(self.n, val) + val
        self.i += 1

    def state_dict(self):
        return {
            "init": self.elapsed_time,
            "n": self.n,
            "round": self.round,
        }

    def load_state_dict(self, state_dict):
        if "start" in state_dict:
            # backwards compatibility for old state_dicts
            self.reset(init=state_dict["init"])
        else:
            self.reset(init=state_dict["init"], n=state_dict["n"])
            self.round = state_dict.get("round", None)

    @property
    def avg(self):
        return self.n / self.elapsed_time

    @property
    def elapsed_time(self):
        return self.init + (time.perf_counter() - self.start)

    @property
    def smoothed_value(self) -> float:
        val = self.avg
        if self.round is not None and val is not None:
            val = safe_round(val, self.round)
        return val


class StopwatchMeter(Meter):
    """Computes the sum/avg duration of some event in seconds"""

    def __init__(self, round: Optional[int] = None):
        self.round = round
        self.sum = 0
        self.n = 0
        self.start_time = None

    def start(self):
        self.start_time = time.perf_counter()

    def stop(self, n=1, prehook=None):
        if self.start_time is not None:
            if prehook is not None:
                prehook()
            delta = time.perf_counter() - self.start_time
            self.sum = self.sum + delta
            self.n = type_as(self.n, n) + n

    def reset(self):
        self.sum = 0  # cumulative time during which stopwatch was active
        self.n = 0  # total n across all start/stop
        self.start()

    def state_dict(self):
        return {
            "sum": self.sum,
            "n": self.n,
            "round": self.round,
        }

    def load_state_dict(self, state_dict):
        self.sum = state_dict["sum"]
        self.n = state_dict["n"]
        self.start_time = None
        self.round = state_dict.get("round", None)

    @property
    def avg(self):
        return self.sum / self.n if self.n > 0 else self.sum

    @property
    def elapsed_time(self):
        if self.start_time is None:
            return 0.0
        return time.perf_counter() - self.start_time

    @property
    def smoothed_value(self) -> float:
        val = self.avg if self.sum > 0 else self.elapsed_time
        if self.round is not None and val is not None:
            val = safe_round(val, self.round)
        return val


class MetersDict(OrderedDict):
    """A sorted dictionary of :class:`Meters`.

    Meters are sorted according to a priority that is given when the
    meter is first added to the dictionary.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.priorities = []

    def __setitem__(self, key, value):
        assert key not in self, "MetersDict doesn't support reassignment"
        priority, value = value
        bisect.insort(self.priorities, (priority, len(self.priorities), key))
        super().__setitem__(key, value)
        for _, _, key in self.priorities:  # reorder dict to match priorities
            self.move_to_end(key)

    def add_meter(self, key, meter, priority):
        self.__setitem__(key, (priority, meter))

    def state_dict(self):
        return [
            (pri, key, self[key].__class__.__name__, self[key].state_dict())
            for pri, _, key in self.priorities
            # can't serialize DerivedMeter instances
            if not isinstance(self[key], MetersDict._DerivedMeter)
        ]

    def load_state_dict(self, state_dict):
        self.clear()
        self.priorities.clear()
        for pri, key, meter_cls, meter_state in state_dict:
            meter = globals()[meter_cls]()
            meter.load_state_dict(meter_state)
            self.add_meter(key, meter, pri)

    def get_smoothed_value(self, key: str) -> float:
        """Get a single smoothed value."""
        meter = self[key]
        if isinstance(meter, MetersDict._DerivedMeter):
            return meter.fn(self)
        else:
            return meter.smoothed_value

    def get_smoothed_values(self) -> Dict[str, float]:
        """Get all smoothed values."""
        return OrderedDict(
            [
                (key, self.get_smoothed_value(key))
                for key in self.keys()
                if not key.startswith("_")
            ]
        )

    def reset(self):
        """Reset Meter instances."""
        for meter in self.values():
            if isinstance(meter, MetersDict._DerivedMeter):
                continue
            meter.reset()

    class _DerivedMeter(Meter):
        """A Meter whose values are derived from other Meters."""

        def __init__(self, fn):
            self.fn = fn

        def reset(self):
            pass
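
A short illustrative snippet (not part of the commit; the meter names and values are made up) showing how the priority passed to MetersDict.add_meter fixes the ordering used by get_smoothed_values() and state_dict():

meters = MetersDict()
meters.add_meter("loss", AverageMeter(round=3), priority=10)  # smaller priority sorts first
meters.add_meter("wps", TimeMeter(round=1), priority=50)
meters["loss"].update(2.5, n=4)                               # avg = (2.5 * 4) / 4 = 2.5
print(meters.get_smoothed_values())  # OrderedDict([('loss', 2.5), ('wps', <rate>)])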
PyTorch/NLP/new-Transformer/fairseq/logging/metrics.py  (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
A standalone module for aggregating metrics.

Metrics can be logged from anywhere using the `log_*` functions defined
in this module. The logged values will be aggregated dynamically based
on the aggregation context in which the logging occurs. See the
:func:`aggregate` context manager for more details.
"""

import contextlib
import uuid
from collections import defaultdict
from typing import Callable, List, Optional

from .meters import *


# Aggregation contexts are considered "active" when inside the scope
# created by the :func:`aggregate` context manager.
_aggregators = OrderedDict()
_active_aggregators = OrderedDict()
_active_aggregators_cnt = defaultdict(lambda: 0)


def reset() -> None:
    """Reset all metrics aggregators."""
    _aggregators.clear()
    _active_aggregators.clear()
    _active_aggregators_cnt.clear()

    # The "default" aggregator observes all logged values.
    _aggregators["default"] = MetersDict()
    _active_aggregators["default"] = _aggregators["default"]
    _active_aggregators_cnt["default"] = 1


reset()


@contextlib.contextmanager
def aggregate(name: Optional[str] = None, new_root: bool = False):
    """Context manager to aggregate metrics under a given name.

    Aggregations can be nested. If *new_root* is ``False``, then logged
    metrics will be recorded along the entire stack of nested
    aggregators, including a global "default" aggregator. If *new_root*
    is ``True``, then this aggregator will be the root of a new
    aggregation stack, thus bypassing any parent aggregators.

    Note that aggregation contexts are uniquely identified by their
    *name* (e.g., train, valid). Creating a context with an existing
    name will reuse the corresponding :class:`MetersDict` instance.
    If no name is given, then a temporary aggregator will be created.

    Usage::

        with metrics.aggregate("train"):
            for step, batch in enumerate(epoch):
                with metrics.aggregate("train_inner") as agg:
                    metrics.log_scalar("loss", get_loss(batch))
                    if step % log_interval == 0:
                        print(agg.get_smoothed_value("loss"))
                        agg.reset()
        print(metrics.get_smoothed_values("train")["loss"])

    Args:
        name (str): name of the aggregation. Defaults to a
            random/temporary name if not given explicitly.
        new_root (bool): make this aggregation the root of a new
            aggregation stack.
    """
    if name is None:
        # generate a temporary name
        name = str(uuid.uuid4())
        assert name not in _aggregators
        agg = MetersDict()
    else:
        assert name != "default"
        agg = _aggregators.setdefault(name, MetersDict())

    if new_root:
        backup_aggregators = _active_aggregators.copy()
        _active_aggregators.clear()
        backup_aggregators_cnt = _active_aggregators_cnt.copy()
        _active_aggregators_cnt.clear()

    _active_aggregators[name] = agg
    _active_aggregators_cnt[name] += 1

    yield agg

    _active_aggregators_cnt[name] -= 1
    if _active_aggregators_cnt[name] == 0 and name in _active_aggregators:
        del _active_aggregators[name]

    if new_root:
        _active_aggregators.clear()
        _active_aggregators.update(backup_aggregators)
        _active_aggregators_cnt.clear()
        _active_aggregators_cnt.update(backup_aggregators_cnt)


def get_active_aggregators() -> List[MetersDict]:
    return list(_active_aggregators.values())


def log_scalar(
    key: str,
    value: float,
    weight: float = 1,
    priority: int = 10,
    round: Optional[int] = None,
):
    """Log a scalar value.

    Args:
        key (str): name of the field to log
        value (float): value to log
        weight (float): weight that this value contributes to the average.
            A weight of 0 will always log the latest value.
        priority (int): smaller values are logged earlier in the output
        round (Optional[int]): number of digits to round to when displaying
    """
    for agg in get_active_aggregators():
        if key not in agg:
            agg.add_meter(key, AverageMeter(round=round), priority)
        agg[key].update(value, weight)


def log_scalar_sum(
    key: str,
    value: float,
    priority: int = 10,
    round: Optional[int] = None,
):
    """Log a scalar value that is summed for reporting.

    Args:
        key (str): name of the field to log
        value (float): value to log
        priority (int): smaller values are logged earlier in the output
        round (Optional[int]): number of digits to round to when displaying
    """
    for agg in get_active_aggregators():
        if key not in agg:
            agg.add_meter(key, SumMeter(round=round), priority)
        agg[key].update(value)


def log_derived(key: str, fn: Callable[[MetersDict], float], priority: int = 20):
    """Log a scalar value derived from other meters.

    Args:
        key (str): name of the field to log
        fn (Callable[[MetersDict], float]): function that takes a single
            argument *meters* and returns the derived value
        priority (int): smaller values are logged earlier in the output
    """
    for agg in get_active_aggregators():
        if key not in agg:
            agg.add_meter(key, MetersDict._DerivedMeter(fn), priority)


def log_speed(
    key: str,
    value: float,
    priority: int = 30,
    round: Optional[int] = None,
):
    """Log the rate of some quantity per second.

    Args:
        key (str): name of the field to log
        value (float): value to log
        priority (int): smaller values are logged earlier in the output
        round (Optional[int]): number of digits to round to when displaying
    """
    for agg in get_active_aggregators():
        if key not in agg:
            agg.add_meter(key, TimeMeter(round=round), priority)
            agg[key].reset()  # reset meter on the first call
        else:
            agg[key].update(value)


def log_start_time(key: str, priority: int = 40, round: Optional[int] = None):
    """Log the duration of some event in seconds.

    The duration will be computed once :func:`log_stop_time` is called.

    Args:
        key (str): name of the field to log
        priority (int): smaller values are logged earlier in the output
        round (Optional[int]): number of digits to round to when displaying
    """
    for agg in get_active_aggregators():
        if key not in agg:
            agg.add_meter(key, StopwatchMeter(round=round), priority)
        agg[key].start()


def log_stop_time(key: str, weight: float = 0.0, prehook=None):
    """Log the duration of some event in seconds.

    The duration will be computed since :func:`log_start_time` was called.
    Set weight > 0 to report the average time instead of the sum.

    Args:
        key (str): name of the field to log
        weight (float): weight that this time contributes to the average
        prehook (function, no arguments): will be called before the timer
        is stopped. For example, use prehook=torch.cuda.synchronize to
        make sure all gpu operations are done before timer is stopped.
    """
    for agg in get_active_aggregators():
        if key in agg:
            agg[key].stop(weight, prehook)


def log_custom(
    new_meter_fn: Callable[[], Meter],
    key: str,
    *args,
    priority: int = 50,
    **kwargs,
):
    """Log using a custom Meter.

    Any extra *args* or *kwargs* will be passed through to the Meter's
    *update* method.

    Args:
        new_meter_fn (Callable[[], Meter]): function that returns a new
            Meter instance
        key (str): name of the field to log
        priority (int): smaller values are logged earlier in the output
    """
    for agg in get_active_aggregators():
        if key not in agg:
            agg.add_meter(key, new_meter_fn(), priority)
        agg[key].update(*args, **kwargs)


def reset_meter(name: str, key: str) -> None:
    """Reset Meter instance aggregated under a given *name* and *key*."""
    meter = get_meter(name, key)
    if meter is not None:
        meter.reset()


def reset_meters(name: str) -> None:
    """Reset Meter instances aggregated under a given *name*."""
    meters = get_meters(name)
    if meters is not None:
        meters.reset()


def get_meter(name: str, key: str) -> Meter:
    """Get a single Meter instance aggregated under *name* and *key*.

    Returns:
        Meter or None if no metrics have been logged under *name* and *key*.
    """
    if name not in _aggregators:
        return None
    return _aggregators[name].get(key, None)


def get_meters(name: str) -> MetersDict:
    """Get Meter instances aggregated under a given *name*.

    Returns:
        MetersDict or None if no metrics have been logged under *name*.
    """
    return _aggregators.get(name, None)


def get_smoothed_value(name: str, key: str) -> float:
    """Get a single smoothed value.

    Raises:
        KeyError: if no metrics have been logged under *name* and *key*.
    """
    return _aggregators[name].get_smoothed_value(key)


def get_smoothed_values(name: str) -> Dict[str, float]:
    """Get smoothed values aggregated under a given *name*.

    Raises:
        KeyError: if no metrics have been logged under *name*.
    """
    return _aggregators[name].get_smoothed_values()


def state_dict():
    return OrderedDict([(name, agg.state_dict()) for name, agg in _aggregators.items()])


def load_state_dict(state_dict):
    for name, agg_state in state_dict.items():
        _aggregators[name] = MetersDict()
        _aggregators[name].load_state_dict(agg_state)


def xla_metrics_report():
    try:
        import torch_xla.debug.metrics as met

        print(met.metrics_report())
    except ImportError:
        return
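
To make the aggregation-context behaviour concrete, here is a minimal sketch (not part of the commit, assuming the package is importable as fairseq.logging): a value logged inside nested aggregate() scopes is recorded in every active aggregator, including the global "default" one, unless new_root=True starts a fresh stack.

from fairseq.logging import metrics

with metrics.aggregate("train"):
    with metrics.aggregate("train_inner"):
        metrics.log_scalar("loss", 4.0)

# The same value is visible under every aggregator that was active at log time.
print(metrics.get_smoothed_value("train", "loss"))        # 4.0
print(metrics.get_smoothed_value("train_inner", "loss"))  # 4.0
print(metrics.get_smoothed_value("default", "loss"))      # 4.0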
PyTorch/NLP/new-Transformer/fairseq/logging/progress_bar.py  (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Wrapper around various loggers and progress bars (e.g., tqdm).
"""

import atexit
import json
import logging
import os
import sys
from collections import OrderedDict
from contextlib import contextmanager
from numbers import Number
from typing import Optional

import torch

from .meters import AverageMeter, StopwatchMeter, TimeMeter


logger = logging.getLogger(__name__)


def progress_bar(
    iterator,
    log_format: Optional[str] = None,
    log_interval: int = 100,
    log_file: Optional[str] = None,
    epoch: Optional[int] = None,
    prefix: Optional[str] = None,
    aim_repo: Optional[str] = None,
    aim_run_hash: Optional[str] = None,
    aim_param_checkpoint_dir: Optional[str] = None,
    tensorboard_logdir: Optional[str] = None,
    default_log_format: str = "tqdm",
    wandb_project: Optional[str] = None,
    wandb_run_name: Optional[str] = None,
    azureml_logging: Optional[bool] = False,
):
    if log_format is None:
        log_format = default_log_format
    if log_file is not None:
        handler = logging.FileHandler(filename=log_file)
        logger.addHandler(handler)

    if log_format == "tqdm" and not sys.stderr.isatty():
        log_format = "simple"

    if log_format == "json":
        bar = JsonProgressBar(iterator, epoch, prefix, log_interval)
    elif log_format == "none":
        bar = NoopProgressBar(iterator, epoch, prefix)
    elif log_format == "simple":
        bar = SimpleProgressBar(iterator, epoch, prefix, log_interval)
    elif log_format == "tqdm":
        bar = TqdmProgressBar(iterator, epoch, prefix)
    else:
        raise ValueError("Unknown log format: {}".format(log_format))

    if aim_repo:
        bar = AimProgressBarWrapper(
            bar,
            aim_repo=aim_repo,
            aim_run_hash=aim_run_hash,
            aim_param_checkpoint_dir=aim_param_checkpoint_dir,
        )

    if tensorboard_logdir:
        try:
            # [FB only] custom wrapper for TensorBoard
            import palaas  # noqa

            from .fb_tbmf_wrapper import FbTbmfWrapper

            bar = FbTbmfWrapper(bar, log_interval)
        except ImportError:
            bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir)

    if wandb_project:
        bar = WandBProgressBarWrapper(bar, wandb_project, run_name=wandb_run_name)

    if azureml_logging:
        bar = AzureMLProgressBarWrapper(bar)

    return bar


def build_progress_bar(
    args,
    iterator,
    epoch: Optional[int] = None,
    prefix: Optional[str] = None,
    default: str = "tqdm",
    no_progress_bar: str = "none",
):
    """Legacy wrapper that takes an argparse.Namespace."""
    if getattr(args, "no_progress_bar", False):
        default = no_progress_bar
    if getattr(args, "distributed_rank", 0) == 0:
        tensorboard_logdir = getattr(args, "tensorboard_logdir", None)
    else:
        tensorboard_logdir = None
    return progress_bar(
        iterator,
        log_format=args.log_format,
        log_interval=args.log_interval,
        epoch=epoch,
        prefix=prefix,
        tensorboard_logdir=tensorboard_logdir,
        default_log_format=default,
    )


def format_stat(stat):
    if isinstance(stat, Number):
        stat = "{:g}".format(stat)
    elif isinstance(stat, AverageMeter):
        stat = "{:.3f}".format(stat.avg)
    elif isinstance(stat, TimeMeter):
        stat = "{:g}".format(round(stat.avg))
    elif isinstance(stat, StopwatchMeter):
        stat = "{:g}".format(round(stat.sum))
    elif torch.is_tensor(stat):
        stat = stat.tolist()
    return stat


class BaseProgressBar(object):
    """Abstract class for progress bars."""

    def __init__(self, iterable, epoch=None, prefix=None):
        self.iterable = iterable
        self.n = getattr(iterable, "n", 0)
        self.epoch = epoch
        self.prefix = ""
        if epoch is not None:
            self.prefix += "epoch {:03d}".format(epoch)
        if prefix is not None:
            self.prefix += (" | " if self.prefix != "" else "") + prefix

    def __len__(self):
        return len(self.iterable)

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        return False

    def __iter__(self):
        raise NotImplementedError

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        raise NotImplementedError

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        raise NotImplementedError

    def update_config(self, config):
        """Log latest configuration."""
        pass

    def _str_commas(self, stats):
        return ", ".join(key + "=" + stats[key].strip() for key in stats.keys())

    def _str_pipes(self, stats):
        return " | ".join(key + " " + stats[key].strip() for key in stats.keys())

    def _format_stats(self, stats):
        postfix = OrderedDict(stats)
        # Preprocess stats according to datatype
        for key in postfix.keys():
            postfix[key] = str(format_stat(postfix[key]))
        return postfix


@contextmanager
def rename_logger(logger, new_name):
    old_name = logger.name
    if new_name is not None:
        logger.name = new_name
    yield logger
    logger.name = old_name


class JsonProgressBar(BaseProgressBar):
    """Log output in JSON format."""

    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
        super().__init__(iterable, epoch, prefix)
        self.log_interval = log_interval
        self.i = None
        self.size = None

    def __iter__(self):
        self.size = len(self.iterable)
        for i, obj in enumerate(self.iterable, start=self.n):
            self.i = i
            yield obj

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        step = step or self.i or 0
        if step > 0 and self.log_interval is not None and step % self.log_interval == 0:
            update = (
                self.epoch - 1 + (self.i + 1) / float(self.size)
                if self.epoch is not None
                else None
            )
            stats = self._format_stats(stats, epoch=self.epoch, update=update)
            with rename_logger(logger, tag):
                logger.info(json.dumps(stats))

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        self.stats = stats
        if tag is not None:
            self.stats = OrderedDict(
                [(tag + "_" + k, v) for k, v in self.stats.items()]
            )
        stats = self._format_stats(self.stats, epoch=self.epoch)
        with rename_logger(logger, tag):
            logger.info(json.dumps(stats))

    def _format_stats(self, stats, epoch=None, update=None):
        postfix = OrderedDict()
        if epoch is not None:
            postfix["epoch"] = epoch
        if update is not None:
            postfix["update"] = round(update, 3)
        # Preprocess stats according to datatype
        for key in stats.keys():
            postfix[key] = format_stat(stats[key])
        return postfix


class NoopProgressBar(BaseProgressBar):
    """No logging."""

    def __init__(self, iterable, epoch=None, prefix=None):
        super().__init__(iterable, epoch, prefix)

    def __iter__(self):
        for obj in self.iterable:
            yield obj

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        pass

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        pass


class SimpleProgressBar(BaseProgressBar):
    """A minimal logger for non-TTY environments."""

    def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000):
        super().__init__(iterable, epoch, prefix)
        self.log_interval = log_interval
        self.i = None
        self.size = None

    def __iter__(self):
        self.size = len(self.iterable)
        for i, obj in enumerate(self.iterable, start=self.n):
            self.i = i
            yield obj

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        step = step or self.i or 0
        if step > 0 and self.log_interval is not None and step % self.log_interval == 0:
            stats = self._format_stats(stats)
            postfix = self._str_commas(stats)
            with rename_logger(logger, tag):
                logger.info(
                    "{}: {:5d} / {:d} {}".format(
                        self.prefix, self.i + 1, self.size, postfix
                    )
                )

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        postfix = self._str_pipes(self._format_stats(stats))
        with rename_logger(logger, tag):
            logger.info("{} | {}".format(self.prefix, postfix))


class TqdmProgressBar(BaseProgressBar):
    """Log to tqdm."""

    def __init__(self, iterable, epoch=None, prefix=None):
        super().__init__(iterable, epoch, prefix)
        from tqdm import tqdm

        self.tqdm = tqdm(
            iterable,
            self.prefix,
            leave=False,
            disable=(logger.getEffectiveLevel() > logging.INFO),
        )

    def __iter__(self):
        return iter(self.tqdm)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats according to log_interval."""
        self.tqdm.set_postfix(self._format_stats(stats), refresh=False)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        postfix = self._str_pipes(self._format_stats(stats))
        with rename_logger(logger, tag):
            logger.info("{} | {}".format(self.prefix, postfix))


try:
    import functools

    from aim import Repo as AimRepo

    @functools.lru_cache()
    def get_aim_run(repo, run_hash):
        from aim import Run

        return Run(run_hash=run_hash, repo=repo)

except ImportError:
    get_aim_run = None
    AimRepo = None


class AimProgressBarWrapper(BaseProgressBar):
    """Log to Aim."""

    def __init__(self, wrapped_bar, aim_repo, aim_run_hash, aim_param_checkpoint_dir):
        self.wrapped_bar = wrapped_bar

        if get_aim_run is None:
            self.run = None
            logger.warning("Aim not found, please install with: pip install aim")
        else:
            logger.info(f"Storing logs at Aim repo: {aim_repo}")

            if not aim_run_hash:
                # Find run based on save_dir parameter
                query = f"run.checkpoint.save_dir == '{aim_param_checkpoint_dir}'"
                try:
                    runs_generator = AimRepo(aim_repo).query_runs(query)
                    run = next(runs_generator.iter_runs())
                    aim_run_hash = run.run.hash
                except Exception:
                    pass

            if aim_run_hash:
                logger.info(f"Appending to run: {aim_run_hash}")

            self.run = get_aim_run(aim_repo, aim_run_hash)

    def __iter__(self):
        return iter(self.wrapped_bar)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats to Aim."""
        self._log_to_aim(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        self._log_to_aim(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def update_config(self, config):
        """Log latest configuration."""
        if self.run is not None:
            for key in config:
                self.run.set(key, config[key], strict=False)
        self.wrapped_bar.update_config(config)

    def _log_to_aim(self, stats, tag=None, step=None):
        if self.run is None:
            return

        if step is None:
            step = stats["num_updates"]

        if "train" in tag:
            context = {"tag": tag, "subset": "train"}
        elif "val" in tag:
            context = {"tag": tag, "subset": "val"}
        else:
            context = {"tag": tag}

        for key in stats.keys() - {"num_updates"}:
            self.run.track(stats[key], name=key, step=step, context=context)


try:
    _tensorboard_writers = {}
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    try:
        from tensorboardX import SummaryWriter
    except ImportError:
        SummaryWriter = None


def _close_writers():
    for w in _tensorboard_writers.values():
        w.close()


atexit.register(_close_writers)


class TensorboardProgressBarWrapper(BaseProgressBar):
    """Log to tensorboard."""

    def __init__(self, wrapped_bar, tensorboard_logdir):
        self.wrapped_bar = wrapped_bar
        self.tensorboard_logdir = tensorboard_logdir

        if SummaryWriter is None:
            logger.warning(
                "tensorboard not found, please install with: pip install tensorboard"
            )

    def _writer(self, key):
        if SummaryWriter is None:
            return None
        _writers = _tensorboard_writers
        if key not in _writers:
            _writers[key] = SummaryWriter(os.path.join(self.tensorboard_logdir, key))
            _writers[key].add_text("sys.argv", " ".join(sys.argv))
        return _writers[key]

    def __iter__(self):
        return iter(self.wrapped_bar)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats to tensorboard."""
        self._log_to_tensorboard(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        self._log_to_tensorboard(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def update_config(self, config):
        """Log latest configuration."""
        # TODO add hparams to Tensorboard
        self.wrapped_bar.update_config(config)

    def _log_to_tensorboard(self, stats, tag=None, step=None):
        writer = self._writer(tag or "")
        if writer is None:
            return
        if step is None:
            step = stats["num_updates"]
        for key in stats.keys() - {"num_updates"}:
            if isinstance(stats[key], AverageMeter):
                writer.add_scalar(key, stats[key].val, step)
            elif isinstance(stats[key], Number):
                writer.add_scalar(key, stats[key], step)
            elif torch.is_tensor(stats[key]) and stats[key].numel() == 1:
                writer.add_scalar(key, stats[key].item(), step)
        writer.flush()


try:
    import wandb
except ImportError:
    wandb = None


class WandBProgressBarWrapper(BaseProgressBar):
    """Log to Weights & Biases."""

    def __init__(self, wrapped_bar, wandb_project, run_name=None):
        self.wrapped_bar = wrapped_bar
        if wandb is None:
            logger.warning("wandb not found, pip install wandb")
            return

        # reinit=False to ensure if wandb.init() is called multiple times
        # within one process it still references the same run
        wandb.init(project=wandb_project, reinit=False, name=run_name)

    def __iter__(self):
        return iter(self.wrapped_bar)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats to tensorboard."""
        self._log_to_wandb(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats."""
        self._log_to_wandb(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def update_config(self, config):
        """Log latest configuration."""
        if wandb is not None:
            wandb.config.update(config)
        self.wrapped_bar.update_config(config)

    def _log_to_wandb(self, stats, tag=None, step=None):
        if wandb is None:
            return
        if step is None:
            step = stats["num_updates"]

        prefix = "" if tag is None else tag + "/"

        for key in stats.keys() - {"num_updates"}:
            if isinstance(stats[key], AverageMeter):
                wandb.log({prefix + key: stats[key].val}, step=step)
            elif isinstance(stats[key], Number):
                wandb.log({prefix + key: stats[key]}, step=step)


try:
    from azureml.core import Run
except ImportError:
    Run = None


class AzureMLProgressBarWrapper(BaseProgressBar):
    """Log to Azure ML"""

    def __init__(self, wrapped_bar):
        self.wrapped_bar = wrapped_bar
        if Run is None:
            logger.warning("azureml.core not found, pip install azureml-core")
            return
        self.run = Run.get_context()

    def __exit__(self, *exc):
        if Run is not None:
            self.run.complete()
        return False

    def __iter__(self):
        return iter(self.wrapped_bar)

    def log(self, stats, tag=None, step=None):
        """Log intermediate stats to AzureML"""
        self._log_to_azureml(stats, tag, step)
        self.wrapped_bar.log(stats, tag=tag, step=step)

    def print(self, stats, tag=None, step=None):
        """Print end-of-epoch stats"""
        self._log_to_azureml(stats, tag, step)
        self.wrapped_bar.print(stats, tag=tag, step=step)

    def update_config(self, config):
        """Log latest configuration."""
        self.wrapped_bar.update_config(config)

    def _log_to_azureml(self, stats, tag=None, step=None):
        if Run is None:
            return
        if step is None:
            step = stats["num_updates"]

        prefix = "" if tag is None else tag + "/"

        for key in stats.keys() - {"num_updates"}:
            name = prefix + key
            if isinstance(stats[key], AverageMeter):
                self.run.log_row(name=name, **{"step": step, key: stats[key].val})
            elif isinstance(stats[key], Number):
                self.run.log_row(name=name, **{"step": step, key: stats[key]})
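
A usage sketch (illustrative only, not part of the commit; the iterator, paths, and stats are invented) showing how the progress_bar() factory stacks wrappers so that a single log()/print() call fans out to the console bar and to every enabled backend:

epoch_itr = range(1000)  # stand-in for a fairseq epoch iterator
bar = progress_bar(
    epoch_itr,
    log_format="simple",
    log_interval=100,
    tensorboard_logdir="/tmp/tb",  # adds TensorboardProgressBarWrapper around the simple bar
)
for i, batch in enumerate(bar):
    bar.log({"loss": 2.0, "num_updates": i}, tag="train_inner", step=i)
bar.print({"loss": 2.0, "num_updates": len(epoch_itr)}, tag="train")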
PyTorch/NLP/new-Transformer/fairseq/model_parallel/__init__.py  (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import criterions, models, modules  # noqa
PyTorch/NLP/new-Transformer/fairseq/model_parallel/criterions/__init__.py  (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import os


# automatically import any Python files in the criterions/ directory
for file in sorted(os.listdir(os.path.dirname(__file__))):
    if file.endswith(".py") and not file.startswith("_"):
        module = file[: file.find(".py")]
        importlib.import_module("fairseq.model_parallel.criterions." + module)
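
Because of the loop above, registering a new model-parallel criterion only requires dropping a file into this directory; the hypothetical sketch below (file name, criterion name, and class are invented for illustration) would be imported automatically and its decorator executed at package import time.

# hypothetical file: fairseq/model_parallel/criterions/my_parallel_criterion.py
from fairseq.criterions import FairseqCriterion, register_criterion


@register_criterion("my_parallel_criterion")
class MyParallelCriterion(FairseqCriterion):
    def forward(self, model, sample, reduce=True):
        raise NotImplementedError  # placeholder body, for illustration only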
PyTorch/NLP/new-Transformer/fairseq/model_parallel/criterions/vocab_parallel_cross_entropy.py  (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion

try:
    from fairseq.model_parallel.megatron.mpu.cross_entropy import (
        vocab_parallel_cross_entropy,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False


@register_criterion("vocab_parallel_cross_entropy")
class VocabParallelCrossEntropyCriterion(FairseqCriterion):
    def __init__(self, task, sentence_avg):
        super().__init__(task)
        self.sentence_avg = sentence_avg
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n  git submodule update --init "
                "fairseq/model_parallel/megatron"
            )

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """
        net_output = model(**sample["net_input"])
        target = sample["target"]

        loss = vocab_parallel_cross_entropy(net_output[0].float(), target)
        loss = (loss * (target != self.padding_idx)).sum()
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        metrics.log_scalar(
            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
        )
        if sample_size != ntokens:
            metrics.log_scalar(
                "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
            )
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)
            )
        else:
            metrics.log_derived(
                "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
            )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True will improves distributed training speed.
        """
        return True
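
For intuition, the sketch below (not part of the commit) reproduces the same masked token-level reduction with torch.nn.functional.cross_entropy in place of Megatron's vocab-parallel kernel; the tensor names are illustrative.

import torch.nn.functional as F


def token_level_loss(logits, target, padding_idx):
    # logits: (batch, seq, vocab); target: (batch, seq); pad positions are excluded from the sum.
    per_token = F.cross_entropy(
        logits.float().transpose(1, 2), target, reduction="none"
    )  # (batch, seq) negative log-likelihood per token
    return (per_token * (target != padding_idx)).sum()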
PyTorch/NLP/new-Transformer/fairseq/model_parallel/megatron_trainer.py  (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Train a network across multiple GPUs.
"""

from fairseq.dataclass.configs import FairseqConfig
from fairseq.distributed import utils as distributed_utils
from fairseq.trainer import Trainer

try:
    from fairseq.model_parallel.megatron.mpu import (
        get_data_parallel_rank,
        get_data_parallel_world_size,
        get_model_parallel_src_rank,
        get_cuda_rng_tracker,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False


class MegatronTrainer(Trainer):
    """Main class for model parallel with data parallel training."""

    def __init__(self, cfg: FairseqConfig, task, model, criterion, **kwargs):
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n  git submodule update --init "
                "fairseq/model_parallel/megatron"
            )
        super().__init__(cfg, task, model, criterion, **kwargs)

    def clip_grad_norm(self, clip_norm):
        def _aggregate_model_parallel_grad_norm(total_norm):
            total_norm = total_norm**2
            distributed_utils.all_reduce(
                total_norm, group=distributed_utils.get_model_parallel_group()
            )
            total_norm = total_norm**0.5
            return total_norm

        return self.optimizer.clip_grad_norm(
            clip_norm,
            aggregate_norm_fn=_aggregate_model_parallel_grad_norm,
        )

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        extra_state["rng_tracker_states"] = get_cuda_rng_tracker().get_states()
        super().save_checkpoint(filename, extra_state)

    def load_checkpoint(
        self,
        filename,
        reset_optimizer=False,
        reset_lr_scheduler=False,
        optimizer_overrides=None,
        reset_meters=False,
    ):
        extra_state = super().load_checkpoint(
            filename,
            reset_optimizer=reset_optimizer,
            reset_lr_scheduler=reset_lr_scheduler,
            optimizer_overrides=optimizer_overrides,
            reset_meters=reset_meters,
        )
        if extra_state is not None and "rng_tracker_states" in extra_state:
            get_cuda_rng_tracker().set_states(extra_state["rng_tracker_states"])
        return extra_state
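
The clip_grad_norm override relies on the fact that a global L2 norm can be recovered from per-partition norms by summing squares across the model-parallel group and taking a square root. A minimal sketch of that identity (not part of the commit, using plain torch.distributed instead of Megatron's mpu helpers):

import torch
import torch.distributed as dist


def global_grad_norm(local_norm: torch.Tensor, group=None) -> torch.Tensor:
    # Each model-parallel rank passes the L2 norm of its own parameter shard;
    # sqrt(sum_i local_norm_i ** 2) equals the norm of the full gradient vector.
    total = local_norm ** 2
    dist.all_reduce(total, op=dist.ReduceOp.SUM, group=group)
    return total ** 0.5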
PyTorch/NLP/new-Transformer/fairseq/model_parallel/models/__init__.py  (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import os


# automatically import any Python files in the models/ directory
models_dir = os.path.dirname(__file__)
for file in os.listdir(models_dir):
    path = os.path.join(models_dir, file)
    if (
        not file.startswith("_")
        and not file.startswith(".")
        and (file.endswith(".py") or os.path.isdir(path))
    ):
        model_name = file[: file.find(".py")] if file.endswith(".py") else file
        module = importlib.import_module("fairseq.model_parallel.models." + model_name)
PyTorch/NLP/new-Transformer/fairseq/model_parallel/models/pipeline_parallel_transformer/__init__.py  (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .model import *  # noqa
PyTorch/NLP/new-Transformer/fairseq/model_parallel/models/pipeline_parallel_transformer/layers.py  (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import options, utils
from fairseq.modules import (
    AdaptiveSoftmax,
    LayerNorm,
    MultiheadAttention,
    PositionalEmbedding,
)


EncoderOut = namedtuple(
    "TransformerEncoderOut",
    [
        "encoder_out",  # T x B x C
        "encoder_padding_mask",  # B x T
        "encoder_embedding",  # B x T x C
        "encoder_states",  # List[T x B x C]
    ],
)


class TransformerEncoderEmbedding(nn.Module):
    """Encoder Embedding + Positional Embedding"""

    def __init__(self, args, embed_tokens):
        super().__init__()
        self.dropout = args.dropout
        self.max_source_positions = args.max_source_positions
        self.embed_tokens = embed_tokens
        if isinstance(embed_tokens, nn.ModuleList):
            self.padding_idx = embed_tokens[0].padding_idx
            embed_dim = sum(e.embedding_dim for e in embed_tokens)
        else:
            self.padding_idx = embed_tokens.padding_idx
            embed_dim = embed_tokens.embedding_dim
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = (
            PositionalEmbedding(
                args.max_source_positions,
                embed_dim,
                self.padding_idx,
                learned=args.encoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )
        if getattr(args, "layernorm_embedding", False):
            self.layernorm_embedding = LayerNorm(embed_dim)
        else:
            self.layernorm_embedding = None

    def forward(self, input):
        # embed tokens and positions
        src_tokens = input[0]
        prev_output_tokens = input[2]
        if isinstance(self.embed_tokens, nn.ModuleList):
            x_embed_list = []
            for embed_tokens_part in self.embed_tokens:
                x_embed_list.append(embed_tokens_part(src_tokens))

            embedded = torch.cat(x_embed_list, dim=-1)
        else:
            embedded = self.embed_tokens(src_tokens)
        x = embed = self.embed_scale * embedded
        if self.embed_positions is not None:
            x = embed + self.embed_positions(src_tokens)
        if self.layernorm_embedding:
            x = self.layernorm_embedding(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        return (x, encoder_padding_mask, prev_output_tokens)


class TransformerEncoderLayerNorm(nn.Module):
    """
    Layer norm at the the end of all encoder layers if
    args.encoder_enormalize_before = True
    """

    def __init__(self, args, embed_dim):
        super().__init__()
        if args.encoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(self, input):
        x = input[0]
        encoder_padding_mask = input[1]
        prev_output_tokens = input[2]
        if self.layer_norm:
            x = self.layer_norm(x)
        # keeping track of the incremental_state is not supported yet
        return (x, encoder_padding_mask, prev_output_tokens)


class TransformerDecoderEmbedding(nn.Module):
    """Decoder Embedding + Positional Embedding"""

    def __init__(self, args, embed_tokens):
        super().__init__()
        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed
        input_embed_dim = (
            sum(e.embedding_dim for e in embed_tokens)
            if isinstance(embed_tokens, nn.ModuleList)
            else embed_tokens.embedding_dim
        )
        embed_dim = args.decoder_embed_dim
        self.output_embed_dim = args.decoder_output_dim

        padding_idx = (
            embed_tokens[0].padding_idx
            if isinstance(embed_tokens, nn.ModuleList)
            else embed_tokens.padding_idx
        )
        self.max_target_positions = args.max_target_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)  # todo: try with input_embed_dim

        self.project_in_dim = (
            Linear(input_embed_dim, embed_dim, bias=False)
            if embed_dim != input_embed_dim
            else None
        )

        self.embed_positions = (
            PositionalEmbedding(
                args.max_target_positions,
                embed_dim,
                padding_idx,
                learned=args.decoder_learned_pos,
            )
            if not args.no_token_positional_embeddings
            else None
        )

    def forward(self, input):
        mt_task = False
        if isinstance(input, tuple):
            if len(input) == 3:
                encoder_out = input[0]
                encoder_padding_mask = input[1]
                prev_output_tokens = input[2]
                incremental_state = None  # Hardcoding to avoid passing of None objects
                mt_task = True
            else:
                # HACK for now, need to fix (TODO sidgoyal)
                prev_output_tokens = input[0]
                # discard "src_lengths"
                encoder_out = None
                encoder_padding_mask = None
                incremental_state = None
        else:
            prev_output_tokens = input
            encoder_out = None
            encoder_padding_mask = None
            incremental_state = None

        positions = (
            self.embed_positions(
                prev_output_tokens,
                incremental_state=incremental_state,
            )
            if self.embed_positions is not None
            else None
        )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        if isinstance(self.embed_tokens, nn.ModuleList):
            x_embed_list = []
            for embed_tokens_part in self.embed_tokens:
                x_embed_list.append(embed_tokens_part(prev_output_tokens))

            x = self.embed_scale * torch.cat(x_embed_list, dim=-1)
        else:
            x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        if mt_task:
return
(
x
,
encoder_out
,
encoder_padding_mask
)
return
x
class
TransformerDecoderOutputLayer
(
nn
.
Module
):
def
__init__
(
self
,
args
,
embed_tokens
,
dictionary
):
super
().
__init__
()
self
.
share_input_output_embed
=
args
.
share_decoder_input_output_embed
self
.
embed_tokens
=
embed_tokens
self
.
output_embed_dim
=
args
.
decoder_output_dim
embed_dim
=
args
.
decoder_embed_dim
self
.
project_out_dim
=
(
Linear
(
embed_dim
,
self
.
output_embed_dim
,
bias
=
False
)
if
embed_dim
!=
self
.
output_embed_dim
and
not
args
.
tie_adaptive_weights
else
None
)
self
.
adaptive_softmax
=
None
if
args
.
adaptive_softmax_cutoff
is
not
None
:
assert
not
isinstance
(
embed_tokens
,
nn
.
ModuleList
)
self
.
adaptive_softmax
=
AdaptiveSoftmax
(
len
(
dictionary
),
self
.
output_embed_dim
,
options
.
eval_str_list
(
args
.
adaptive_softmax_cutoff
,
type
=
int
),
dropout
=
args
.
adaptive_softmax_dropout
,
adaptive_inputs
=
embed_tokens
if
args
.
tie_adaptive_weights
else
None
,
factor
=
args
.
adaptive_softmax_factor
,
tie_proj
=
args
.
tie_adaptive_proj
,
)
elif
not
self
.
share_input_output_embed
:
self
.
embed_tokens
=
nn
.
Parameter
(
torch
.
Tensor
(
len
(
dictionary
),
self
.
output_embed_dim
)
)
nn
.
init
.
normal_
(
self
.
embed_tokens
,
mean
=
0
,
std
=
self
.
output_embed_dim
**-
0.5
)
if
args
.
decoder_normalize_before
and
not
getattr
(
args
,
"no_decoder_final_norm"
,
False
):
self
.
layer_norm
=
LayerNorm
(
embed_dim
)
else
:
self
.
layer_norm
=
None
def
forward
(
self
,
input
,
apply_final_proj
=
True
):
if
isinstance
(
input
,
tuple
):
x
=
input
[
0
]
else
:
x
=
input
if
self
.
layer_norm
:
x
=
self
.
layer_norm
(
x
)
# T x B x C -> B x T x C
x
=
x
.
transpose
(
0
,
1
)
if
self
.
project_out_dim
is
not
None
:
x
=
self
.
project_out_dim
(
x
)
if
apply_final_proj
:
x
=
self
.
output_layer
(
x
)
return
x
def
output_layer
(
self
,
features
,
**
kwargs
):
"""Project features to the vocabulary size."""
if
self
.
adaptive_softmax
is
None
:
# project back to size of vocabulary
if
self
.
share_input_output_embed
:
if
isinstance
(
self
.
embed_tokens
,
nn
.
ModuleList
):
output
=
None
for
i
,
emb
in
enumerate
(
self
.
embed_tokens
):
sidx
=
i
*
emb
.
embedding_dim
eidx
=
(
i
+
1
)
*
emb
.
embedding_dim
if
output
is
None
:
output
=
F
.
linear
(
features
[:,
:,
sidx
:
eidx
],
emb
.
weight
)
else
:
output
+=
F
.
linear
(
features
[:,
:,
sidx
:
eidx
],
emb
.
weight
)
return
output
else
:
return
F
.
linear
(
features
,
self
.
embed_tokens
.
weight
)
else
:
return
F
.
linear
(
features
,
self
.
embed_tokens
)
else
:
return
features
class
TransformerEncoderLayer
(
nn
.
Module
):
"""Encoder layer block.
In the original paper each operation (multi-head attention or FFN) is
postprocessed with: `dropout -> add residual -> layernorm`. In the
tensor2tensor code they suggest that learning is more robust when
preprocessing each layer with layernorm and postprocessing with:
`dropout -> add residual`. We default to the approach in the paper, but the
tensor2tensor approach can be enabled by setting
*args.encoder_normalize_before* to ``True``.
Args:
args (argparse.Namespace): parsed command-line arguments
"""
def
__init__
(
self
,
args
):
super
().
__init__
()
self
.
embed_dim
=
args
.
encoder_embed_dim
self
.
self_attn
=
MultiheadAttention
(
self
.
embed_dim
,
args
.
encoder_attention_heads
,
dropout
=
args
.
attention_dropout
,
self_attention
=
True
,
)
self
.
self_attn_layer_norm
=
LayerNorm
(
self
.
embed_dim
)
self
.
dropout
=
args
.
dropout
self
.
activation_fn
=
utils
.
get_activation_fn
(
activation
=
getattr
(
args
,
"activation_fn"
,
"relu"
)
)
self
.
activation_dropout
=
getattr
(
args
,
"activation_dropout"
,
0
)
if
self
.
activation_dropout
==
0
:
# for backwards compatibility with models that use args.relu_dropout
self
.
activation_dropout
=
getattr
(
args
,
"relu_dropout"
,
0
)
self
.
normalize_before
=
args
.
encoder_normalize_before
self
.
fc1
=
Linear
(
self
.
embed_dim
,
args
.
encoder_ffn_embed_dim
)
self
.
fc2
=
Linear
(
args
.
encoder_ffn_embed_dim
,
self
.
embed_dim
)
self
.
final_layer_norm
=
LayerNorm
(
self
.
embed_dim
)
def
upgrade_state_dict_named
(
self
,
state_dict
,
name
):
"""
Rename layer norm states from `...layer_norms.0.weight` to
`...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
`...final_layer_norm.weight`
"""
layer_norm_map
=
{
"0"
:
"self_attn_layer_norm"
,
"1"
:
"final_layer_norm"
}
for
old
,
new
in
layer_norm_map
.
items
():
for
m
in
(
"weight"
,
"bias"
):
k
=
"{}.layer_norms.{}.{}"
.
format
(
name
,
old
,
m
)
if
k
in
state_dict
:
state_dict
[
"{}.{}.{}"
.
format
(
name
,
new
,
m
)]
=
state_dict
[
k
]
del
state_dict
[
k
]
def
forward
(
self
,
input
):
"""
Args:
input (Tuple):
input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
input[1] (ByteTensor/FloatTensor): encoder padding mask -
binary ByteTensor of shape `(batch, src_len)` where padding elements
are indicated by ``1``.
input[2] (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing)
Returns:
output (Tuple):
output[0] (Tensor): encoded output of shape `(batch, src_len, embed_dim)`
output[1] (ByteTensor/FloatTensor): encoder padding mask
output[2] (LongTensor): previous decoder outputs
"""
x
=
input
[
0
]
encoder_padding_mask
=
input
[
1
]
prev_output_tokens
=
input
[
2
]
residual
=
x
x
=
self
.
maybe_layer_norm
(
self
.
self_attn_layer_norm
,
x
,
before
=
True
)
x
,
_
=
self
.
self_attn
(
query
=
x
,
key
=
x
,
value
=
x
,
key_padding_mask
=
encoder_padding_mask
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
residual
+
x
x
=
self
.
maybe_layer_norm
(
self
.
self_attn_layer_norm
,
x
,
after
=
True
)
residual
=
x
x
=
self
.
maybe_layer_norm
(
self
.
final_layer_norm
,
x
,
before
=
True
)
x
=
self
.
activation_fn
(
self
.
fc1
(
x
))
x
=
F
.
dropout
(
x
,
p
=
self
.
activation_dropout
,
training
=
self
.
training
)
x
=
self
.
fc2
(
x
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
residual
+
x
x
=
self
.
maybe_layer_norm
(
self
.
final_layer_norm
,
x
,
after
=
True
)
return
(
x
,
encoder_padding_mask
,
prev_output_tokens
)
def
maybe_layer_norm
(
self
,
layer_norm
,
x
,
before
=
False
,
after
=
False
):
assert
before
^
after
if
after
^
self
.
normalize_before
:
return
layer_norm
(
x
)
else
:
return
x
class
TransformerDecoderLayer
(
nn
.
Module
):
"""Decoder layer block.
In the original paper each operation (multi-head attention, encoder
attention or FFN) is postprocessed with: `dropout -> add residual ->
layernorm`. In the tensor2tensor code they suggest that learning is more
robust when preprocessing each layer with layernorm and postprocessing with:
`dropout -> add residual`. We default to the approach in the paper, but the
tensor2tensor approach can be enabled by setting
*args.decoder_normalize_before* to ``True``.
Args:
args (argparse.Namespace): parsed command-line arguments
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def
__init__
(
self
,
args
,
no_encoder_attn
=
False
,
add_bias_kv
=
False
,
add_zero_attn
=
False
):
super
().
__init__
()
self
.
embed_dim
=
args
.
decoder_embed_dim
self
.
self_attn
=
MultiheadAttention
(
embed_dim
=
self
.
embed_dim
,
num_heads
=
args
.
decoder_attention_heads
,
dropout
=
args
.
attention_dropout
,
add_bias_kv
=
add_bias_kv
,
add_zero_attn
=
add_zero_attn
,
self_attention
=
True
,
)
self
.
dropout
=
args
.
dropout
self
.
activation_fn
=
utils
.
get_activation_fn
(
activation
=
getattr
(
args
,
"activation_fn"
,
"relu"
)
)
self
.
activation_dropout
=
getattr
(
args
,
"activation_dropout"
,
0
)
if
self
.
activation_dropout
==
0
:
# for backwards compatibility with models that use args.relu_dropout
self
.
activation_dropout
=
getattr
(
args
,
"relu_dropout"
,
0
)
self
.
normalize_before
=
args
.
decoder_normalize_before
# use layerNorm rather than FusedLayerNorm for exporting.
# char_inputs can be used to determint this.
# TODO remove this once we update apex with the fix
export
=
getattr
(
args
,
"char_inputs"
,
False
)
self
.
self_attn_layer_norm
=
LayerNorm
(
self
.
embed_dim
,
export
=
export
)
if
no_encoder_attn
:
self
.
encoder_attn
=
None
self
.
encoder_attn_layer_norm
=
None
else
:
self
.
encoder_attn
=
MultiheadAttention
(
self
.
embed_dim
,
args
.
decoder_attention_heads
,
kdim
=
getattr
(
args
,
"encoder_embed_dim"
,
None
),
vdim
=
getattr
(
args
,
"encoder_embed_dim"
,
None
),
dropout
=
args
.
attention_dropout
,
encoder_decoder_attention
=
True
,
)
self
.
encoder_attn_layer_norm
=
LayerNorm
(
self
.
embed_dim
,
export
=
export
)
self
.
fc1
=
Linear
(
self
.
embed_dim
,
args
.
decoder_ffn_embed_dim
)
self
.
fc2
=
Linear
(
args
.
decoder_ffn_embed_dim
,
self
.
embed_dim
)
self
.
final_layer_norm
=
LayerNorm
(
self
.
embed_dim
,
export
=
export
)
self
.
need_attn
=
True
self
.
onnx_trace
=
False
def
prepare_for_onnx_export_
(
self
):
self
.
onnx_trace
=
True
def
forward
(
self
,
input
):
"""
Args:
input (Tuple):
input[0] (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
input[1] (Tensor): encoder output of shape `(batch, src_len, embed_dim)`
input[2] (ByteTensor/FloatTensor): encoder padding mask -
binary ByteTensor of shape `(batch, src_len)` where padding elements
are indicated by ``1``.
Returns:
output (Tuple):
output[0] (Tensor): encoded output of shape `(batch, src_len, embed_dim)`
output[1] (ByteTensor/FloatTensor): encoder padding mask
output[2] (LongTensor): previous decoder outputs
"""
# Note: incremental state is not yet supported
mt_task
=
False
if
isinstance
(
input
,
tuple
):
x
=
input
[
0
]
encoder_out
=
input
[
1
]
encoder_padding_mask
=
input
[
2
]
incremental_state
=
None
mt_task
=
True
else
:
x
=
input
encoder_out
=
None
encoder_padding_mask
=
None
incremental_state
=
None
if
incremental_state
is
None
:
self_attn_mask
=
self
.
buffered_future_mask
(
x
)
else
:
self_attn_mask
=
None
# TODO: add back prev_self_attn_state, prev_attn_state,
# self_attn_padding_mask
prev_self_attn_state
=
None
prev_attn_state
=
None
self_attn_padding_mask
=
None
residual
=
x
x
=
self
.
maybe_layer_norm
(
self
.
self_attn_layer_norm
,
x
,
before
=
True
)
if
prev_self_attn_state
is
not
None
:
if
incremental_state
is
None
:
incremental_state
=
{}
prev_key
,
prev_value
=
prev_self_attn_state
saved_state
=
{
"prev_key"
:
prev_key
,
"prev_value"
:
prev_value
}
self
.
self_attn
.
_set_input_buffer
(
incremental_state
,
saved_state
)
x
,
attn
=
self
.
self_attn
(
query
=
x
,
key
=
x
,
value
=
x
,
key_padding_mask
=
self_attn_padding_mask
,
incremental_state
=
incremental_state
,
need_weights
=
False
,
attn_mask
=
self_attn_mask
,
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
residual
+
x
x
=
self
.
maybe_layer_norm
(
self
.
self_attn_layer_norm
,
x
,
after
=
True
)
if
self
.
encoder_attn
is
not
None
:
residual
=
x
x
=
self
.
maybe_layer_norm
(
self
.
encoder_attn_layer_norm
,
x
,
before
=
True
)
if
prev_attn_state
is
not
None
:
if
incremental_state
is
None
:
incremental_state
=
{}
prev_key
,
prev_value
=
prev_attn_state
saved_state
=
{
"prev_key"
:
prev_key
,
"prev_value"
:
prev_value
}
self
.
encoder_attn
.
_set_input_buffer
(
incremental_state
,
saved_state
)
x
,
attn
=
self
.
encoder_attn
(
query
=
x
,
key
=
encoder_out
,
value
=
encoder_out
,
key_padding_mask
=
encoder_padding_mask
,
incremental_state
=
incremental_state
,
static_kv
=
True
,
need_weights
=
(
not
self
.
training
and
self
.
need_attn
),
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
residual
+
x
x
=
self
.
maybe_layer_norm
(
self
.
encoder_attn_layer_norm
,
x
,
after
=
True
)
residual
=
x
x
=
self
.
maybe_layer_norm
(
self
.
final_layer_norm
,
x
,
before
=
True
)
x
=
self
.
activation_fn
(
self
.
fc1
(
x
))
x
=
F
.
dropout
(
x
,
p
=
self
.
activation_dropout
,
training
=
self
.
training
)
x
=
self
.
fc2
(
x
)
x
=
F
.
dropout
(
x
,
p
=
self
.
dropout
,
training
=
self
.
training
)
x
=
residual
+
x
x
=
self
.
maybe_layer_norm
(
self
.
final_layer_norm
,
x
,
after
=
True
)
if
mt_task
:
return
(
x
,
encoder_out
,
encoder_padding_mask
)
return
x
def
buffered_future_mask
(
self
,
tensor
):
dim
=
tensor
.
size
(
0
)
if
(
not
hasattr
(
self
,
"_future_mask"
)
or
self
.
_future_mask
is
None
or
self
.
_future_mask
.
device
!=
tensor
.
device
):
self
.
_future_mask
=
torch
.
triu
(
utils
.
fill_with_neg_inf
(
tensor
.
new
(
dim
,
dim
)),
1
)
if
self
.
_future_mask
.
size
(
0
)
<
dim
:
self
.
_future_mask
=
torch
.
triu
(
utils
.
fill_with_neg_inf
(
self
.
_future_mask
.
resize_
(
dim
,
dim
)),
1
)
return
self
.
_future_mask
[:
dim
,
:
dim
]
def
maybe_layer_norm
(
self
,
layer_norm
,
x
,
before
=
False
,
after
=
False
):
assert
before
^
after
if
after
^
self
.
normalize_before
:
return
layer_norm
(
x
)
else
:
return
x
def
make_generation_fast_
(
self
,
need_attn
=
False
,
**
kwargs
):
self
.
need_attn
=
need_attn
def
Embedding
(
num_embeddings
,
embedding_dim
,
padding_idx
):
m
=
nn
.
Embedding
(
num_embeddings
,
embedding_dim
,
padding_idx
=
padding_idx
)
nn
.
init
.
normal_
(
m
.
weight
,
mean
=
0
,
std
=
embedding_dim
**-
0.5
)
nn
.
init
.
constant_
(
m
.
weight
[
padding_idx
],
0
)
return
m
def
Linear
(
in_features
,
out_features
,
bias
=
True
):
m
=
nn
.
Linear
(
in_features
,
out_features
,
bias
)
nn
.
init
.
xavier_uniform_
(
m
.
weight
)
if
bias
:
nn
.
init
.
constant_
(
m
.
bias
,
0.0
)
return
m
PyTorch/NLP/new-Transformer/fairseq/model_parallel/models/pipeline_parallel_transformer/model.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
logging
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
fairseq
import
utils
from
fairseq.model_parallel.models.pipeline_parallel_transformer.layers
import
(
Embedding
,
TransformerDecoderEmbedding
,
TransformerDecoderLayer
,
TransformerDecoderOutputLayer
,
TransformerEncoderEmbedding
,
TransformerEncoderLayer
,
TransformerEncoderLayerNorm
,
)
from
fairseq.models
import
(
BaseFairseqModel
,
FairseqDecoder
,
FairseqEncoder
,
register_model
,
register_model_architecture
,
)
from
fairseq.models.fairseq_encoder
import
EncoderOut
from
fairseq.models.transformer
import
(
base_architecture
,
transformer_iwslt_de_en
,
transformer_wmt_en_de_big
,
)
from
fairseq.modules
import
SinusoidalPositionalEmbedding
logger
=
logging
.
getLogger
(
__name__
)
DEFAULT_MAX_SOURCE_POSITIONS
=
1024
DEFAULT_MAX_TARGET_POSITIONS
=
1024
TORCH_PIPE
=
False
RPC_INIT
=
False
def
import_pipe
():
global
TORCH_PIPE
global
RPC_INIT
try
:
from
torch.distributed.pipeline.sync
import
Pipe
# noqa
global
Pipe
from
torch.distributed.pipeline.sync.utils
import
partition_model
global
partition_model
from
torch.distributed
import
rpc
import
tempfile
TORCH_PIPE
=
True
# Initialize single process RPC agent since TORCH_PIPE requires
# RRef. RRef depends on RPC being initialized and as a result we initialize
# RPC with a single node.
tmpfile
=
tempfile
.
NamedTemporaryFile
()
if
not
RPC_INIT
:
rpc
.
init_rpc
(
name
=
"worker"
,
rank
=
0
,
world_size
=
1
,
rpc_backend_options
=
rpc
.
TensorPipeRpcBackendOptions
(
init_method
=
"file://{}"
.
format
(
tmpfile
.
name
),
),
)
RPC_INIT
=
True
logger
.
info
(
"Using torch pipe"
)
except
ImportError
:
try
:
from
fairscale.nn
import
Pipe
# noqa
logger
.
info
(
"Using fairscale pipe"
)
except
ImportError
:
raise
ImportError
(
"Please install fairscale with: pip install fairscale"
)
@
register_model
(
"pipeline_parallel_transformer"
)
class
PipelineParallelTransformerModel
(
BaseFairseqModel
):
def
__init__
(
self
,
encoder
,
decoder
,
balance
,
devices
,
chunks
,
checkpoint
):
import_pipe
()
super
().
__init__
()
assert
isinstance
(
encoder
,
FairseqEncoder
)
assert
isinstance
(
decoder
,
FairseqDecoder
)
encoder_module_list
=
(
[
encoder
.
embedding_layer
]
+
list
(
encoder
.
encoder_layers
)
+
[
encoder
.
final_layer_norm
]
)
self
.
num_encoder_modules
=
len
(
encoder_module_list
)
decoder_module_list
=
(
[
decoder
.
embedding_layer
]
+
list
(
decoder
.
decoder_layers
)
+
[
decoder
.
decoder_output_layer
]
)
self
.
num_decoder_modules
=
len
(
decoder_module_list
)
module_list
=
encoder_module_list
+
decoder_module_list
self
.
devices
=
devices
if
TORCH_PIPE
:
self
.
model
=
Pipe
(
partition_model
(
nn
.
Sequential
(
*
module_list
),
balance
,
devices
),
chunks
=
chunks
,
checkpoint
=
checkpoint
,
)
else
:
self
.
model
=
Pipe
(
nn
.
Sequential
(
*
module_list
),
balance
=
balance
,
devices
=
devices
,
chunks
=
chunks
,
checkpoint
=
checkpoint
,
)
self
.
encoder_max_positions
=
self
.
max_positions_helper
(
encoder
.
embedding_layer
,
"max_source_positions"
)
self
.
decoder_max_positions
=
self
.
max_positions_helper
(
decoder
.
embedding_layer
,
"max_target_positions"
)
self
.
adaptive_softmax
=
getattr
(
decoder
,
"adaptive_softmax"
,
None
)
# Note: To be populated during inference
self
.
encoder
=
None
self
.
decoder
=
None
def
forward
(
self
,
src_tokens
,
src_lengths
,
prev_output_tokens
):
if
self
.
training
:
input_lst
=
[
src_tokens
,
src_lengths
,
prev_output_tokens
]
input
=
tuple
(
i
.
to
(
self
.
devices
[
0
],
non_blocking
=
True
)
for
i
in
input_lst
)
if
TORCH_PIPE
:
return
self
.
model
(
input
).
local_value
()
else
:
return
self
.
model
(
input
)
else
:
assert
self
.
encoder
is
not
None
and
self
.
decoder
is
not
None
,
(
"encoder and decoder need to be initialized by "
+
"calling the `prepare_for_inference_()` method"
)
encoder_output_tuple
=
self
.
encoder
(
input
)
return
self
.
decoder
(
encoder_output_tuple
)
def
prepare_for_inference_
(
self
,
cfg
):
if
self
.
encoder
is
not
None
and
self
.
decoder
is
not
None
:
logger
.
info
(
"Encoder and Decoder already initialized"
)
return
encoder_module_list
=
[]
decoder_module_list
=
[]
module_count
=
0
for
partition
in
self
.
model
.
partitions
:
for
module
in
partition
:
if
module_count
<
self
.
num_encoder_modules
:
encoder_module_list
.
append
(
module
)
else
:
decoder_module_list
.
append
(
module
)
module_count
+=
1
self
.
model
=
None
self
.
encoder
=
TransformerEncoder
(
cfg
.
distributed_training
,
None
,
None
,
encoder_module_list
)
self
.
decoder
=
TransformerDecoder
(
cfg
.
distributed_training
,
None
,
None
,
decoder_module_list
=
decoder_module_list
,
)
@
staticmethod
def
add_args
(
parser
):
"""Add model-specific arguments to the parser."""
# fmt: off
parser
.
add_argument
(
'--activation-fn'
,
choices
=
utils
.
get_available_activation_fns
(),
help
=
'activation function to use'
)
parser
.
add_argument
(
'--dropout'
,
type
=
float
,
metavar
=
'D'
,
help
=
'dropout probability'
)
parser
.
add_argument
(
'--attention-dropout'
,
type
=
float
,
metavar
=
'D'
,
help
=
'dropout probability for attention weights'
)
parser
.
add_argument
(
'--activation-dropout'
,
'--relu-dropout'
,
type
=
float
,
metavar
=
'D'
,
help
=
'dropout probability after activation in FFN.'
)
parser
.
add_argument
(
'--encoder-embed-path'
,
type
=
str
,
metavar
=
'STR'
,
help
=
'path to pre-trained encoder embedding'
)
parser
.
add_argument
(
'--encoder-embed-dim'
,
type
=
int
,
metavar
=
'N'
,
help
=
'encoder embedding dimension'
)
parser
.
add_argument
(
'--encoder-ffn-embed-dim'
,
type
=
int
,
metavar
=
'N'
,
help
=
'encoder embedding dimension for FFN'
)
parser
.
add_argument
(
'--encoder-layers'
,
type
=
int
,
metavar
=
'N'
,
help
=
'num encoder layers'
)
parser
.
add_argument
(
'--encoder-attention-heads'
,
type
=
int
,
metavar
=
'N'
,
help
=
'num encoder attention heads'
)
parser
.
add_argument
(
'--encoder-normalize-before'
,
action
=
'store_true'
,
help
=
'apply layernorm before each encoder block'
)
parser
.
add_argument
(
'--encoder-learned-pos'
,
action
=
'store_true'
,
help
=
'use learned positional embeddings in the encoder'
)
parser
.
add_argument
(
'--decoder-embed-path'
,
type
=
str
,
metavar
=
'STR'
,
help
=
'path to pre-trained decoder embedding'
)
parser
.
add_argument
(
'--decoder-embed-dim'
,
type
=
int
,
metavar
=
'N'
,
help
=
'decoder embedding dimension'
)
parser
.
add_argument
(
'--decoder-ffn-embed-dim'
,
type
=
int
,
metavar
=
'N'
,
help
=
'decoder embedding dimension for FFN'
)
parser
.
add_argument
(
'--decoder-layers'
,
type
=
int
,
metavar
=
'N'
,
help
=
'num decoder layers'
)
parser
.
add_argument
(
'--decoder-attention-heads'
,
type
=
int
,
metavar
=
'N'
,
help
=
'num decoder attention heads'
)
parser
.
add_argument
(
'--decoder-learned-pos'
,
action
=
'store_true'
,
help
=
'use learned positional embeddings in the decoder'
)
parser
.
add_argument
(
'--decoder-normalize-before'
,
action
=
'store_true'
,
help
=
'apply layernorm before each decoder block'
)
parser
.
add_argument
(
'--share-decoder-input-output-embed'
,
action
=
'store_true'
,
help
=
'share decoder input and output embeddings'
)
parser
.
add_argument
(
'--share-all-embeddings'
,
action
=
'store_true'
,
help
=
'share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)'
)
parser
.
add_argument
(
'--no-token-positional-embeddings'
,
default
=
False
,
action
=
'store_true'
,
help
=
'if set, disables positional embeddings (outside self attention)'
)
parser
.
add_argument
(
'--adaptive-softmax-cutoff'
,
metavar
=
'EXPR'
,
help
=
'comma separated list of adaptive softmax cutoff points. '
'Must be used with adaptive_loss criterion'
),
parser
.
add_argument
(
'--adaptive-softmax-dropout'
,
type
=
float
,
metavar
=
'D'
,
help
=
'sets adaptive softmax dropout for the tail projections'
)
parser
.
add_argument
(
'--num-embedding-chunks'
,
type
=
int
,
metavar
=
'N'
,
default
=
1
,
help
=
'Number of embedding layer chunks (enables more even distribution'
'of optimizer states across data parallel nodes'
'when using optimizer state sharding and'
'a big embedding vocabulary)'
)
# fmt: on
@
classmethod
def
build_model_base
(
cls
,
args
,
task
):
"""Build a new model instance."""
# make sure all arguments are present in older models
base_architecture
(
args
)
if
not
hasattr
(
args
,
"max_source_positions"
):
args
.
max_source_positions
=
DEFAULT_MAX_SOURCE_POSITIONS
if
not
hasattr
(
args
,
"max_target_positions"
):
args
.
max_target_positions
=
DEFAULT_MAX_TARGET_POSITIONS
src_dict
,
tgt_dict
=
task
.
source_dictionary
,
task
.
target_dictionary
def
build_embedding
(
dictionary
,
embed_dim
,
path
=
None
,
num_embed_chunks
=
1
):
assert
embed_dim
%
num_embed_chunks
==
0
,
(
f
"Number of embedding chunks =
{
num_embed_chunks
}
should be "
+
f
"divisible by the embedding dimension =
{
embed_dim
}
"
)
assert
path
is
None
or
num_embed_chunks
==
1
,
(
"Loading embedding from a path with number of embedding chunks > 1"
+
" is not yet supported"
)
num_embeddings
=
len
(
dictionary
)
padding_idx
=
dictionary
.
pad
()
# if provided, load from preloaded dictionaries
if
path
:
emb
=
Embedding
(
num_embeddings
,
embed_dim
,
padding_idx
)
embed_dict
=
utils
.
parse_embedding
(
path
)
utils
.
load_embedding
(
embed_dict
,
dictionary
,
emb
)
else
:
embed_chunk_dim
=
embed_dim
//
num_embed_chunks
emb
=
nn
.
ModuleList
()
for
i
in
range
(
num_embed_chunks
):
emb
.
append
(
Embedding
(
num_embeddings
,
embed_chunk_dim
,
padding_idx
))
return
emb
num_embed_chunks
=
args
.
num_embedding_chunks
if
args
.
share_all_embeddings
:
if
src_dict
!=
tgt_dict
:
raise
ValueError
(
"--share-all-embeddings requires a joined dictionary"
)
if
args
.
encoder_embed_dim
!=
args
.
decoder_embed_dim
:
raise
ValueError
(
"--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
)
if
args
.
decoder_embed_path
and
(
args
.
decoder_embed_path
!=
args
.
encoder_embed_path
):
raise
ValueError
(
"--share-all-embeddings not compatible with --decoder-embed-path"
)
encoder_embed_tokens
=
build_embedding
(
src_dict
,
args
.
encoder_embed_dim
,
args
.
encoder_embed_path
,
num_embed_chunks
,
)
decoder_embed_tokens
=
encoder_embed_tokens
args
.
share_decoder_input_output_embed
=
True
else
:
assert
args
.
share_decoder_input_output_embed
or
num_embed_chunks
==
1
,
(
"Not sharing decoder I/O embeddings is not yet supported with number of "
+
"embedding chunks > 1"
)
encoder_embed_tokens
=
build_embedding
(
src_dict
,
args
.
encoder_embed_dim
,
args
.
encoder_embed_path
,
num_embed_chunks
,
)
decoder_embed_tokens
=
build_embedding
(
tgt_dict
,
args
.
decoder_embed_dim
,
args
.
decoder_embed_path
,
num_embed_chunks
,
)
encoder
=
cls
.
build_encoder
(
args
,
src_dict
,
encoder_embed_tokens
)
decoder
=
cls
.
build_decoder
(
args
,
tgt_dict
,
decoder_embed_tokens
)
return
(
encoder
,
decoder
)
@
classmethod
def
build_encoder
(
cls
,
args
,
src_dict
,
embed_tokens
):
return
TransformerEncoder
(
args
,
src_dict
,
embed_tokens
)
@
classmethod
def
build_decoder
(
cls
,
args
,
tgt_dict
,
embed_tokens
):
return
TransformerDecoder
(
args
,
tgt_dict
,
embed_tokens
)
@
classmethod
def
build_model
(
cls
,
args
,
task
):
encoder
,
decoder
=
cls
.
build_model_base
(
args
,
task
)
return
PipelineParallelTransformerModel
(
encoder
=
encoder
,
decoder
=
decoder
,
balance
=
utils
.
eval_str_list
(
args
.
pipeline_balance
,
type
=
int
),
devices
=
utils
.
eval_str_list
(
args
.
pipeline_devices
,
type
=
int
),
chunks
=
args
.
pipeline_chunks
,
checkpoint
=
args
.
pipeline_checkpoint
,
)
def
output_layer
(
self
,
features
,
**
kwargs
):
"""Project features to the default output size (typically vocabulary size)."""
return
self
.
decoder
.
output_layer
(
features
,
**
kwargs
)
def
max_positions
(
self
):
"""Maximum length supported by the model."""
return
(
self
.
encoder_max_positions
,
self
.
decoder_max_positions
)
def
max_positions_helper
(
self
,
embedding_layer
,
max_positions_field
=
"max_source_positions"
):
"""Maximum input length supported by the encoder or decoder."""
if
embedding_layer
.
embed_positions
is
None
:
return
getattr
(
embedding_layer
,
max_positions_field
)
return
min
(
getattr
(
embedding_layer
,
max_positions_field
),
embedding_layer
.
embed_positions
.
max_positions
,
)
def
get_normalized_probs
(
self
,
net_output
,
log_probs
,
sample
=
None
):
"""Get normalized probabilities (or log probs) from a net's output."""
if
hasattr
(
self
,
"adaptive_softmax"
)
and
self
.
adaptive_softmax
is
not
None
:
if
sample
is
not
None
:
assert
"target"
in
sample
target
=
sample
[
"target"
]
else
:
target
=
None
out
=
self
.
adaptive_softmax
.
get_log_prob
(
net_output
,
target
=
target
)
return
out
.
exp_
()
if
not
log_probs
else
out
# A Pipe() module returns a tuple of tensors as the output.
# In this case, the tuple has one element - the output tensor of logits
logits
=
net_output
if
isinstance
(
net_output
,
torch
.
Tensor
)
else
net_output
[
0
]
if
log_probs
:
return
utils
.
log_softmax
(
logits
,
dim
=-
1
,
onnx_trace
=
False
)
else
:
return
utils
.
softmax
(
logits
,
dim
=-
1
,
onnx_trace
=
False
)
def
max_decoder_positions
(
self
):
"""Maximum length supported by the decoder."""
return
self
.
decoder_max_positions
def
load_state_dict
(
self
,
state_dict
,
strict
=
True
,
model_cfg
=
None
):
"""Copies parameters and buffers from *state_dict* into this module and
its descendants.
Overrides the method in :class:`nn.Module`. Compared with that method
this additionally "upgrades" *state_dicts* from old checkpoints.
"""
self
.
upgrade_state_dict
(
state_dict
)
is_regular_transformer
=
not
any
(
"model.partitions"
in
k
for
k
in
state_dict
)
if
is_regular_transformer
:
state_dict
=
self
.
convert_to_pipeline_parallel_state_dict
(
state_dict
)
return
super
().
load_state_dict
(
state_dict
,
strict
)
def
convert_to_pipeline_parallel_state_dict
(
self
,
state_dict
):
new_state_dict
=
self
.
state_dict
()
encoder_layer_idx
=
0
decoder_layer_idx
=
0
encoder_key_suffixes
=
[
"self_attn.k_proj.weight"
,
"self_attn.k_proj.bias"
,
"self_attn.v_proj.weight"
,
"self_attn.v_proj.bias"
,
"self_attn.q_proj.weight"
,
"self_attn.q_proj.bias"
,
"self_attn.out_proj.weight"
,
"self_attn.out_proj.bias"
,
"self_attn_layer_norm.weight"
,
"self_attn_layer_norm.bias"
,
"fc1.weight"
,
"fc1.bias"
,
"fc2.weight"
,
"fc2.bias"
,
"final_layer_norm.weight"
,
"final_layer_norm.bias"
,
]
decoder_key_suffixes
=
[
"self_attn.k_proj.weight"
,
"self_attn.k_proj.bias"
,
"self_attn.v_proj.weight"
,
"self_attn.v_proj.bias"
,
"self_attn.q_proj.weight"
,
"self_attn.q_proj.bias"
,
"self_attn.out_proj.weight"
,
"self_attn.out_proj.bias"
,
"self_attn_layer_norm.weight"
,
"self_attn_layer_norm.bias"
,
"encoder_attn.k_proj.weight"
,
"encoder_attn.k_proj.bias"
,
"encoder_attn.v_proj.weight"
,
"encoder_attn.v_proj.bias"
,
"encoder_attn.q_proj.weight"
,
"encoder_attn.q_proj.bias"
,
"encoder_attn.out_proj.weight"
,
"encoder_attn.out_proj.bias"
,
"encoder_attn_layer_norm.weight"
,
"encoder_attn_layer_norm.bias"
,
"fc1.weight"
,
"fc1.bias"
,
"fc2.weight"
,
"fc2.bias"
,
"final_layer_norm.weight"
,
"final_layer_norm.bias"
,
]
for
pid
,
partition
in
enumerate
(
self
.
model
.
partitions
):
logger
.
info
(
f
"Begin Partition
{
pid
}
"
)
for
mid
,
module
in
enumerate
(
partition
):
# fmt: off
if
isinstance
(
module
,
TransformerEncoderEmbedding
):
new_state_dict
[
f
'model.partitions.
{
pid
}
.
{
mid
}
.embed_tokens.weight'
]
=
state_dict
[
'encoder.embed_tokens.weight'
]
new_state_dict
[
f
'model.partitions.
{
pid
}
.
{
mid
}
.embed_positions._float_tensor'
]
=
state_dict
[
'encoder.embed_positions._float_tensor'
]
if
isinstance
(
module
,
TransformerEncoderLayer
):
for
suffix
in
encoder_key_suffixes
:
new_state_dict
[
f
'model.partitions.
{
pid
}
.
{
mid
}
.
{
suffix
}
'
]
=
state_dict
[
f
'encoder.layers.
{
encoder_layer_idx
}
.
{
suffix
}
'
]
encoder_layer_idx
+=
1
if
isinstance
(
module
,
TransformerDecoderLayer
):
for
suffix
in
decoder_key_suffixes
:
new_state_dict
[
f
'model.partitions.
{
pid
}
.
{
mid
}
.
{
suffix
}
'
]
=
state_dict
[
f
'decoder.layers.
{
decoder_layer_idx
}
.
{
suffix
}
'
]
decoder_layer_idx
+=
1
if
isinstance
(
module
,
TransformerEncoderLayerNorm
):
if
'encoder.layer_norm.weight'
in
state_dict
:
new_state_dict
[
f
'model.partitions.
{
pid
}
.
{
mid
}
.layer_norm.weight'
]
=
state_dict
[
'encoder.layer_norm.weight'
]
new_state_dict
[
f
'model.partitions.
{
pid
}
.
{
mid
}
.layer_norm.bias'
]
=
state_dict
[
'encoder.layer_norm.bias'
]
if
isinstance
(
module
,
TransformerDecoderEmbedding
):
new_state_dict
[
f
'model.partitions.
{
pid
}
.
{
mid
}
.embed_tokens.weight'
]
=
state_dict
[
'decoder.embed_tokens.weight'
]
new_state_dict
[
f
'model.partitions.
{
pid
}
.
{
mid
}
.embed_positions._float_tensor'
]
=
state_dict
[
'decoder.embed_positions._float_tensor'
]
if
isinstance
(
module
,
TransformerDecoderOutputLayer
):
new_state_dict
[
f
'model.partitions.
{
pid
}
.
{
mid
}
.output_projection.weight'
]
=
state_dict
[
'decoder.output_projection.weight'
]
# fmt: on
return
new_state_dict
class
TransformerEncoder
(
FairseqEncoder
):
"""
Transformer encoder consisting of *args.encoder_layers* layers. Each layer
is a :class:`TransformerEncoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): encoding dictionary
embed_tokens (torch.nn.Embedding): input embedding
"""
def
__init__
(
self
,
args
,
dictionary
,
embed_tokens
,
encoder_module_list
=
None
):
super
().
__init__
(
dictionary
)
self
.
register_buffer
(
"version"
,
torch
.
Tensor
([
3
]))
import_pipe
()
self
.
use_pipeline
=
encoder_module_list
is
not
None
if
not
self
.
use_pipeline
:
self
.
embedding_layer
=
TransformerEncoderEmbedding
(
args
,
embed_tokens
)
self
.
encoder_layers
=
nn
.
Sequential
(
*
[
TransformerEncoderLayer
(
args
)
for
i
in
range
(
args
.
encoder_layers
)]
)
if
isinstance
(
embed_tokens
,
nn
.
ModuleList
):
emb_dim
=
sum
(
e
.
embedding_dim
for
e
in
embed_tokens
)
else
:
emb_dim
=
embed_tokens
.
embedding_dim
self
.
final_layer_norm
=
TransformerEncoderLayerNorm
(
args
,
emb_dim
)
else
:
encoder_balance
=
utils
.
eval_str_list
(
args
.
pipeline_encoder_balance
,
type
=
int
)
encoder_devices
=
utils
.
eval_str_list
(
args
.
pipeline_encoder_devices
,
type
=
int
)
assert
sum
(
encoder_balance
)
==
len
(
encoder_module_list
),
(
f
"Sum of encoder_balance=
{
encoder_balance
}
is not equal "
+
f
"to num_encoder_modules=
{
len
(
encoder_module_list
)
}
"
)
if
TORCH_PIPE
:
self
.
model
=
Pipe
(
module
=
partition_model
(
nn
.
Sequential
(
*
encoder_module_list
),
encoder_balance
,
encoder_devices
,
),
chunks
=
args
.
pipeline_chunks
,
checkpoint
=
args
.
pipeline_checkpoint
,
)
else
:
self
.
model
=
Pipe
(
module
=
nn
.
Sequential
(
*
encoder_module_list
),
balance
=
encoder_balance
,
devices
=
encoder_devices
,
chunks
=
args
.
pipeline_chunks
,
checkpoint
=
args
.
pipeline_checkpoint
,
)
def
forward
(
self
,
src_tokens
,
src_lengths
):
"""
Args:
input_tuple(
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
)
Returns:
output_tuple(
- **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
- prev_output_tokens
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *return_all_hiddens* is True.
)
"""
dummy_prev_output_tokens
=
torch
.
zeros
(
1
,
dtype
=
src_tokens
.
dtype
,
device
=
src_tokens
.
device
)
input_tuple
=
(
src_tokens
,
src_lengths
,
dummy_prev_output_tokens
)
if
self
.
use_pipeline
:
input_tuple
=
tuple
(
i
.
to
(
self
.
model
.
devices
[
0
])
for
i
in
input_tuple
)
if
TORCH_PIPE
:
encoder_out
=
self
.
model
(
input_tuple
).
local_value
()
else
:
encoder_out
=
self
.
model
(
input_tuple
)
else
:
encoder_embed_output_tuple
=
self
.
embedding_layer
(
input_tuple
)
encoder_layers_output
=
self
.
encoder_layers
(
encoder_embed_output_tuple
)
encoder_out
=
self
.
final_layer_norm
(
encoder_layers_output
)
# first element is the encoder output
# second element is the encoder padding mask
# the remaining elements of EncoderOut are not computed by
# the PipelineParallelTransformer
return
EncoderOut
(
encoder_out
[
0
],
encoder_out
[
1
],
None
,
None
,
None
,
None
)
def
reorder_encoder_out
(
self
,
encoder_out
,
new_order
):
"""
Reorder encoder output according to *new_order*.
Args:
encoder_out: output from the ``forward()`` method
new_order (LongTensor): desired order
Returns:
*encoder_out* rearranged according to *new_order*
"""
if
encoder_out
.
encoder_out
is
not
None
:
encoder_out
=
encoder_out
.
_replace
(
encoder_out
=
encoder_out
.
encoder_out
.
index_select
(
1
,
new_order
)
)
if
encoder_out
.
encoder_padding_mask
is
not
None
:
encoder_out
=
encoder_out
.
_replace
(
encoder_padding_mask
=
encoder_out
.
encoder_padding_mask
.
index_select
(
0
,
new_order
)
)
if
encoder_out
.
encoder_embedding
is
not
None
:
encoder_out
=
encoder_out
.
_replace
(
encoder_embedding
=
encoder_out
.
encoder_embedding
.
index_select
(
0
,
new_order
)
)
if
encoder_out
.
encoder_states
is
not
None
:
for
idx
,
state
in
enumerate
(
encoder_out
.
encoder_states
):
encoder_out
.
encoder_states
[
idx
]
=
state
.
index_select
(
1
,
new_order
)
return
encoder_out
def
max_positions
(
self
):
"""Maximum input length supported by the encoder."""
if
self
.
embedding_layer
.
embed_positions
is
None
:
return
self
.
embedding_layer
.
max_source_positions
return
min
(
self
.
embedding_layer
.
max_source_positions
,
self
.
embedding_layer
.
embed_positions
.
max_positions
,
)
class
TransformerDecoder
(
FairseqDecoder
):
"""
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
is a :class:`TransformerDecoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def
__init__
(
self
,
args
,
dictionary
,
embed_tokens
,
no_encoder_attn
=
False
,
decoder_module_list
=
None
,
):
super
().
__init__
(
dictionary
)
self
.
register_buffer
(
"version"
,
torch
.
Tensor
([
3
]))
import_pipe
()
self
.
use_pipeline
=
decoder_module_list
is
not
None
if
not
self
.
use_pipeline
:
self
.
embedding_layer
=
TransformerDecoderEmbedding
(
args
,
embed_tokens
)
self
.
decoder_layers
=
nn
.
Sequential
(
*
[
TransformerDecoderLayer
(
args
,
no_encoder_attn
)
for
_
in
range
(
args
.
decoder_layers
)
]
)
self
.
decoder_output_layer
=
TransformerDecoderOutputLayer
(
args
,
embed_tokens
,
dictionary
)
else
:
decoder_balance
=
utils
.
eval_str_list
(
args
.
pipeline_decoder_balance
,
type
=
int
)
decoder_devices
=
utils
.
eval_str_list
(
args
.
pipeline_decoder_devices
,
type
=
int
)
assert
sum
(
decoder_balance
)
==
len
(
decoder_module_list
),
(
f
"Sum of decoder_balance=
{
decoder_balance
}
is not equal "
+
f
"to num_decoder_modules=
{
len
(
decoder_module_list
)
}
"
)
if
TORCH_PIPE
:
self
.
model
=
Pipe
(
module
=
partition_model
(
nn
.
Sequential
(
*
decoder_module_list
),
decoder_balance
,
decoder_devices
,
),
chunks
=
args
.
pipeline_chunks
,
checkpoint
=
args
.
pipeline_checkpoint
,
)
else
:
self
.
model
=
Pipe
(
module
=
nn
.
Sequential
(
*
decoder_module_list
),
balance
=
decoder_balance
,
devices
=
decoder_devices
,
chunks
=
args
.
pipeline_chunks
,
checkpoint
=
args
.
pipeline_checkpoint
,
)
def
forward
(
self
,
prev_output_tokens
,
encoder_out
=
None
,
):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
encoder_out (optional): output from the encoder, used for
encoder-side attention
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
features_only (bool, optional): only return features without
applying output layer (default: False).
Returns:
tuple:
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
input_tuple
=
(
encoder_out
.
encoder_out
,
encoder_out
.
encoder_padding_mask
,
prev_output_tokens
,
)
if
self
.
use_pipeline
:
input_tuple
=
tuple
(
i
.
to
(
self
.
model
.
devices
[
0
])
for
i
in
input_tuple
)
if
TORCH_PIPE
:
return
(
self
.
model
(
input_tuple
).
local_value
(),)
else
:
return
(
self
.
model
(
input_tuple
),)
else
:
embed_layer_output
=
self
.
embedding_layer
(
input_tuple
)
state
=
self
.
decoder_layers
(
embed_layer_output
)
return
(
self
.
decoder_output_layer
(
state
),)
def
output_layer
(
self
,
features
,
**
kwargs
):
"""Project features to the vocabulary size."""
if
self
.
adaptive_softmax
is
None
:
# project back to size of vocabulary
if
self
.
share_input_output_embed
:
return
F
.
linear
(
features
,
self
.
embed_tokens
.
weight
)
else
:
return
F
.
linear
(
features
,
self
.
embed_out
)
else
:
return
features
def
max_positions
(
self
):
"""Maximum output length supported by the decoder."""
if
self
.
embedding_layer
.
embed_positions
is
None
:
return
self
.
embedding_layer
.
max_target_positions
return
min
(
self
.
embedding_layer
.
max_target_positions
,
self
.
embedding_layer
.
embed_positions
.
max_positions
,
)
def
buffered_future_mask
(
self
,
tensor
):
dim
=
tensor
.
size
(
0
)
if
(
not
hasattr
(
self
,
"_future_mask"
)
or
self
.
_future_mask
is
None
or
self
.
_future_mask
.
device
!=
tensor
.
device
or
self
.
_future_mask
.
size
(
0
)
<
dim
):
self
.
_future_mask
=
torch
.
triu
(
utils
.
fill_with_neg_inf
(
tensor
.
new
(
dim
,
dim
)),
1
)
return
self
.
_future_mask
[:
dim
,
:
dim
]
def
upgrade_state_dict_named
(
self
,
state_dict
,
name
):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
if
isinstance
(
self
.
embed_positions
,
SinusoidalPositionalEmbedding
):
weights_key
=
"{}.embed_positions.weights"
.
format
(
name
)
if
weights_key
in
state_dict
:
del
state_dict
[
weights_key
]
state_dict
[
"{}.embed_positions._float_tensor"
.
format
(
name
)
]
=
torch
.
FloatTensor
(
1
)
for
i
in
range
(
len
(
self
.
layers
)):
# update layer norms
layer_norm_map
=
{
"0"
:
"self_attn_layer_norm"
,
"1"
:
"encoder_attn_layer_norm"
,
"2"
:
"final_layer_norm"
,
}
for
old
,
new
in
layer_norm_map
.
items
():
for
m
in
(
"weight"
,
"bias"
):
k
=
"{}.layers.{}.layer_norms.{}.{}"
.
format
(
name
,
i
,
old
,
m
)
if
k
in
state_dict
:
state_dict
[
"{}.layers.{}.{}.{}"
.
format
(
name
,
i
,
new
,
m
)
]
=
state_dict
[
k
]
del
state_dict
[
k
]
version_key
=
"{}.version"
.
format
(
name
)
if
utils
.
item
(
state_dict
.
get
(
version_key
,
torch
.
Tensor
([
1
]))[
0
])
<=
2
:
# earlier checkpoints did not normalize after the stack of layers
self
.
layer_norm
=
None
self
.
normalize
=
False
state_dict
[
version_key
]
=
torch
.
Tensor
([
1
])
return
state_dict
@
register_model_architecture
(
"pipeline_parallel_transformer"
,
"transformer_iwslt_de_en_pipeline_parallel"
)
def
transformer_iwslt_de_en_dist
(
args
):
transformer_iwslt_de_en
(
args
)
@
register_model_architecture
(
"pipeline_parallel_transformer"
,
"transformer_wmt_en_de_big_pipeline_parallel"
)
def
transformer_wmt_en_de_big_dist
(
args
):
transformer_wmt_en_de_big
(
args
)
PyTorch/NLP/new-Transformer/fairseq/model_parallel/models/roberta/__init__.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from
.model
import
*
# noqa
PyTorch/NLP/new-Transformer/fairseq/model_parallel/models/roberta/model.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
RoBERTa: A Robustly Optimized BERT Pretraining Approach.
"""
import
logging
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
fairseq
import
utils
from
fairseq.model_parallel.models.transformer
import
ModelParallelTransformerEncoder
from
fairseq.models
import
register_model
,
register_model_architecture
from
fairseq.models.roberta
import
(
roberta_base_architecture
,
roberta_prenorm_architecture
,
RobertaEncoder
,
RobertaModel
,
)
from
fairseq.modules
import
LayerNorm
try
:
from
fairseq.model_parallel.megatron.mpu
import
(
copy_to_model_parallel_region
,
gather_from_model_parallel_region
,
ColumnParallelLinear
,
VocabParallelEmbedding
,
)
has_megatron_submodule
=
True
except
(
ImportError
,
ModuleNotFoundError
):
has_megatron_submodule
=
False
logger
=
logging
.
getLogger
(
__name__
)
@
register_model
(
"model_parallel_roberta"
)
class
ModelParallelRobertaModel
(
RobertaModel
):
def
__init__
(
self
,
args
,
encoder
):
super
().
__init__
(
args
,
encoder
)
self
.
classification_heads
=
nn
.
ModuleDict
()
@
staticmethod
def
add_args
(
parser
):
RobertaModel
.
add_args
(
parser
)
parser
.
add_argument
(
"--no-final-layer-norm"
,
action
=
"store_true"
,
help
=
(
"don't add final layernorm (only applicable when "
"--encoder-normalize-before=True"
),
)
@
classmethod
def
build_model
(
cls
,
args
,
task
):
"""Build a new model instance."""
# make sure all arguments are present
base_architecture
(
args
)
task
.
source_dictionary
.
pad_to_multiple_
(
args
.
model_parallel_size
*
8
)
task
.
target_dictionary
.
pad_to_multiple_
(
args
.
model_parallel_size
*
8
)
if
not
hasattr
(
args
,
"max_positions"
):
args
.
max_positions
=
args
.
tokens_per_sample
if
getattr
(
args
,
"untie_weights_roberta"
,
False
):
raise
NotImplementedError
(
"--untie-weights-roberta is not supported in model parallel mode"
)
encoder
=
ModelParallelRobertaEncoder
(
args
,
task
.
source_dictionary
)
return
cls
(
args
,
encoder
)
def
forward
(
self
,
src_tokens
,
features_only
=
False
,
return_all_hiddens
=
False
,
classification_head_name
=
None
,
**
kwargs
):
if
classification_head_name
is
not
None
:
features_only
=
True
x
,
extra
=
self
.
encoder
(
src_tokens
,
features_only
,
return_all_hiddens
,
**
kwargs
)
if
classification_head_name
is
not
None
:
x
=
self
.
classification_heads
[
classification_head_name
](
x
)
return
x
,
extra
def
register_classification_head
(
self
,
name
,
num_classes
=
None
,
inner_dim
=
None
,
**
kwargs
):
"""Register a classification head."""
if
name
in
self
.
classification_heads
:
prev_num_classes
=
self
.
classification_heads
[
name
].
out_proj
.
out_features
prev_inner_dim
=
self
.
classification_heads
[
name
].
dense
.
out_features
if
num_classes
!=
prev_num_classes
or
inner_dim
!=
prev_inner_dim
:
logger
.
warning
(
're-registering head "{}" with num_classes {} (prev: {}) '
"and inner_dim {} (prev: {})"
.
format
(
name
,
num_classes
,
prev_num_classes
,
inner_dim
,
prev_inner_dim
)
)
self
.
classification_heads
[
name
]
=
ModelParallelRobertaClassificationHead
(
self
.
args
.
encoder_embed_dim
,
inner_dim
or
self
.
args
.
encoder_embed_dim
,
num_classes
,
self
.
args
.
pooler_activation_fn
,
self
.
args
.
pooler_dropout
,
)
class
ModelParallelRobertaLMHead
(
nn
.
Module
):
"""Head for masked language modeling."""
def
__init__
(
self
,
embed_dim
,
output_dim
,
activation_fn
,
weight
=
None
):
super
().
__init__
()
self
.
dense
=
ColumnParallelLinear
(
embed_dim
,
embed_dim
,
gather_output
=
True
)
self
.
activation_fn
=
utils
.
get_activation_fn
(
activation_fn
)
self
.
layer_norm
=
LayerNorm
(
embed_dim
)
if
weight
is
None
:
weight
=
nn
.
Linear
(
embed_dim
,
output_dim
,
bias
=
False
).
weight
self
.
weight
=
weight
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
(
output_dim
))
def
forward
(
self
,
features
,
masked_tokens
=
None
,
**
kwargs
):
# Only project the unmasked tokens while training,
# saves both memory and computation
if
masked_tokens
is
not
None
:
features
=
features
[
masked_tokens
,
:]
x
=
self
.
dense
(
features
)
x
=
self
.
activation_fn
(
x
)
x
=
self
.
layer_norm
(
x
)
x
=
copy_to_model_parallel_region
(
x
)
# project back to size of vocabulary with bias
x
=
F
.
linear
(
x
,
self
.
weight
)
x
=
gather_from_model_parallel_region
(
x
).
contiguous
()
x
=
x
+
self
.
bias
return
x
class
ModelParallelRobertaClassificationHead
(
nn
.
Module
):
"""Head for sentence-level classification tasks."""
def
__init__
(
self
,
input_dim
,
inner_dim
,
num_classes
,
activation_fn
,
pooler_dropout
):
super
().
__init__
()
self
.
dense
=
ColumnParallelLinear
(
input_dim
,
inner_dim
,
gather_output
=
True
)
self
.
activation_fn
=
utils
.
get_activation_fn
(
activation_fn
)
self
.
dropout
=
nn
.
Dropout
(
p
=
pooler_dropout
)
self
.
out_proj
=
nn
.
Linear
(
inner_dim
,
num_classes
)
def
forward
(
self
,
features
,
**
kwargs
):
x
=
features
[:,
0
,
:]
# take <s> token (equiv. to [CLS])
x
=
self
.
dropout
(
x
)
x
=
self
.
dense
(
x
)
x
=
self
.
activation_fn
(
x
)
x
=
self
.
dropout
(
x
)
x
=
self
.
out_proj
(
x
)
return
x
class
ModelParallelRobertaEncoder
(
RobertaEncoder
):
"""RoBERTa encoder."""
def
__init__
(
self
,
args
,
dictionary
):
super
().
__init__
(
args
,
dictionary
)
assert
not
self
.
args
.
untie_weights_roberta
def
build_embedding
(
self
,
vocab_size
,
embedding_dim
,
padding_idx
):
return
VocabParallelEmbedding
(
vocab_size
,
embedding_dim
,
padding_idx
)
def
build_encoder
(
self
,
args
,
dictionary
,
embed_tokens
):
return
ModelParallelTransformerEncoder
(
args
,
dictionary
,
embed_tokens
)
def
build_lm_head
(
self
,
embed_dim
,
output_dim
,
activation_fn
,
weight
):
return
ModelParallelRobertaLMHead
(
embed_dim
,
output_dim
,
activation_fn
,
weight
)
@
register_model_architecture
(
"model_parallel_roberta"
,
"model_parallel_roberta"
)
def
base_architecture
(
args
):
args
.
no_final_layer_norm
=
getattr
(
args
,
"no_final_layer_norm"
,
False
)
# model parallel RoBERTa defaults to "Pre-LN" formulation
roberta_prenorm_architecture
(
args
)
# earlier versions of model parallel RoBERTa removed the final layer norm
@
register_model_architecture
(
"model_parallel_roberta"
,
"model_parallel_roberta_v1"
)
def
model_parallel_roberta_v1_architecture
(
args
):
args
.
no_final_layer_norm
=
getattr
(
args
,
"no_final_layer_norm"
,
True
)
base_architecture
(
args
)
@
register_model_architecture
(
"model_parallel_roberta"
,
"model_parallel_roberta_postnorm"
)
def
model_parallel_roberta_postnorm_architecture
(
args
):
# the original BERT/RoBERTa uses the "Post-LN" formulation
roberta_base_architecture
(
args
)
@
register_model_architecture
(
"model_parallel_roberta"
,
"model_parallel_roberta_base"
)
def
model_parallel_roberta_base_architecture
(
args
):
base_architecture
(
args
)
@
register_model_architecture
(
"model_parallel_roberta"
,
"model_parallel_roberta_large"
)
def
model_parallel_roberta_large_architecture
(
args
):
args
.
encoder_layers
=
getattr
(
args
,
"encoder_layers"
,
24
)
args
.
encoder_embed_dim
=
getattr
(
args
,
"encoder_embed_dim"
,
1024
)
args
.
encoder_ffn_embed_dim
=
getattr
(
args
,
"encoder_ffn_embed_dim"
,
4096
)
args
.
encoder_attention_heads
=
getattr
(
args
,
"encoder_attention_heads"
,
16
)
base_architecture
(
args
)
PyTorch/NLP/new-Transformer/fairseq/model_parallel/models/transformer.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
logging
import
torch.nn
as
nn
from
fairseq.model_parallel.modules
import
(
ModelParallelTransformerDecoderLayer
,
ModelParallelTransformerEncoderLayer
,
)
from
fairseq.models
import
register_model
from
fairseq.models.transformer
import
(
TransformerDecoder
,
TransformerEncoder
,
TransformerModel
,
)
try
:
from
fairseq.model_parallel.megatron.mpu
import
(
VocabParallelEmbedding
,
copy_to_model_parallel_region
,
gather_from_model_parallel_region
,
)
has_megatron_submodule
=
True
except
(
ImportError
,
ModuleNotFoundError
):
has_megatron_submodule
=
False
logger
=
logging
.
getLogger
(
__name__
)
@
register_model
(
"model_parallel_transformer"
)
class
ModelParallelTransformerModel
(
TransformerModel
):
"""
Model parallel Transformer model.
"""
@
classmethod
def
build_embedding
(
cls
,
args
,
dictionary
,
embed_dim
,
path
=
None
):
if
not
has_megatron_submodule
:
raise
ImportError
(
"
\n\n
Please install the megatron submodule:"
"
\n\n
git submodule update --init "
"fairseq/model_parallel/megatron"
)
dictionary
.
pad_to_multiple_
(
args
.
model_parallel_size
*
8
)
num_embeddings
=
len
(
dictionary
)
padding_idx
=
dictionary
.
pad
()
def
_vocab_init
(
tensor
,
**
kwargs
):
nn
.
init
.
normal_
(
tensor
,
mean
=
0
,
std
=
num_embeddings
**-
0.5
)
nn
.
init
.
constant_
(
tensor
[
1
],
0
)
emb
=
VocabParallelEmbedding
(
num_embeddings
,
embed_dim
,
padding_idx
,
init_method
=
_vocab_init
)
# if provided, load from preloaded dictionaries
if
path
:
raise
NotImplementedError
(
"Loading of embedding from path is not supported for model parallel"
)
return
emb
@
classmethod
def
build_encoder
(
cls
,
args
,
src_dict
,
embed_tokens
):
return
ModelParallelTransformerEncoder
(
args
,
src_dict
,
embed_tokens
)
@
classmethod
def
build_decoder
(
cls
,
args
,
tgt_dict
,
embed_tokens
):
return
ModelParallelTransformerDecoder
(
args
,
tgt_dict
,
embed_tokens
,
no_encoder_attn
=
getattr
(
args
,
"no_cross_attention"
,
False
),
)
class
ModelParallelTransformerEncoder
(
TransformerEncoder
):
"""
Model parallel Transformer encoder consisting of *args.encoder_layers* layers. Each layer
is a :class:`ModelParallelTransformerEncoderLayer`.
"""
def
__init__
(
self
,
args
,
dictionary
,
embed_tokens
):
super
().
__init__
(
args
,
dictionary
,
embed_tokens
)
if
args
.
no_final_layer_norm
:
self
.
layer_norm
=
None
def
build_encoder_layer
(
self
,
args
):
return
ModelParallelTransformerEncoderLayer
(
args
)
class
ModelParallelTransformerDecoder
(
TransformerDecoder
):
"""
Model Parallel Transformer decoder consisting of *args.decoder_layers* layers. Each layer
is a :class:`ModelParallelTransformerDecoderLayer`.
"""
def
build_decoder_layer
(
self
,
args
,
no_encoder_attn
=
False
):
return
ModelParallelTransformerDecoderLayer
(
args
,
no_encoder_attn
)
def
output_layer
(
self
,
features
,
**
kwargs
):
"""Project features to the vocabulary size."""
if
not
self
.
share_input_output_embed
:
raise
NotImplementedError
(
"Model parallel training currently requires --share-decoder-input-output-embed"
)
features
=
copy_to_model_parallel_region
(
features
)
# project back to size of vocabulary
x
=
self
.
output_projection
(
features
)
if
getattr
(
self
.
args
,
"criterion"
)
!=
"vocab_parallel_cross_entropy"
:
x
=
gather_from_model_parallel_region
(
x
).
contiguous
()
return
x
PyTorch/NLP/new-Transformer/fairseq/model_parallel/models/transformer_lm.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn

from fairseq.model_parallel.models.transformer import ModelParallelTransformerDecoder
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer_lm import TransformerLanguageModel


try:
    from fairseq.model_parallel.megatron.mpu import VocabParallelEmbedding

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False


DEFAULT_MAX_TARGET_POSITIONS = 1024


@register_model("model_parallel_transformer_lm")
class ModelParallelTransformerLanguageModel(TransformerLanguageModel):
    @staticmethod
    def add_args(parser):
        TransformerLanguageModel.add_args(parser)

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n  git submodule update --init "
                "fairseq/model_parallel/megatron"
            )

        # make sure all arguments are present in older models
        base_lm_architecture(args)

        task.source_dictionary.pad_to_multiple_(args.model_parallel_size * 8)
        task.target_dictionary.pad_to_multiple_(args.model_parallel_size * 8)

        if args.decoder_layers_to_keep:
            args.decoder_layers = len(args.decoder_layers_to_keep.split(","))

        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = getattr(
                args, "tokens_per_sample", DEFAULT_MAX_TARGET_POSITIONS
            )

        if args.character_embeddings:
            raise NotImplementedError(
                "Character embeddings is not supported for model parallel"
            )
        elif args.adaptive_input:
            raise NotImplementedError(
                "Adaptive input is not supported for model parallel"
            )
        else:
            embed_tokens = cls.build_embedding(
                args, task.source_dictionary, args.decoder_input_dim
            )

        decoder = ModelParallelTransformerDecoder(
            args,
            task.target_dictionary,
            embed_tokens,
            no_encoder_attn=True,
        )
        return cls(decoder)

    @classmethod
    def build_embedding(cls, args, dictionary, embed_dim, path=None):
        def _vocab_init(tensor, **kwargs):
            nn.init.normal_(tensor, mean=0, std=embed_dim ** -0.5)
            nn.init.constant_(tensor[1], 0)

        embed_tokens = VocabParallelEmbedding(
            len(dictionary), embed_dim, dictionary.pad(), init_method=_vocab_init
        )
        return embed_tokens


def base_lm_architecture(args):
    # backward compatibility for older model checkpoints
    if hasattr(args, "no_tie_adaptive_proj"):
        # previous models defined --no-tie-adaptive-proj, so use the existence of
        # that option to determine if this is an "old" model checkpoint
        args.no_decoder_final_norm = True  # old models always set this to True
        if args.no_tie_adaptive_proj is False:
            args.tie_adaptive_proj = True
    if hasattr(args, "decoder_final_norm"):
        args.no_decoder_final_norm = not args.decoder_final_norm

    args.activation_fn = getattr(args, "activation_fn", "relu")
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.activation_dropout = getattr(args, "activation_dropout", 0.0)
    args.relu_dropout = getattr(args, "relu_dropout", 0.0)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_output_dim = getattr(args, "decoder_output_dim", args.decoder_embed_dim)
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 2048)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    # Model training is not stable without this
    args.decoder_normalize_before = True
    args.no_decoder_final_norm = getattr(args, "no_decoder_final_norm", False)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.adaptive_softmax_factor = getattr(args, "adaptive_softmax_factor", 4)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.character_embeddings = getattr(args, "character_embeddings", False)
    args.character_filters = getattr(
        args,
        "character_filters",
        "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]",
    )
    args.character_embedding_dim = getattr(args, "character_embedding_dim", 4)
    args.char_embedder_highway_layers = getattr(args, "char_embedder_highway_layers", 2)
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.adaptive_input_factor = getattr(args, "adaptive_input_factor", 4)
    args.adaptive_input_cutoff = getattr(args, "adaptive_input_cutoff", None)
    args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)
    args.tie_adaptive_proj = getattr(args, "tie_adaptive_proj", False)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0.0)
    args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
    args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0.0)
    args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
    args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0.0)
    args.add_bos_token = getattr(args, "add_bos_token", False)


@register_model_architecture("model_parallel_transformer_lm", "transformer_lm_megatron")
def transformer_lm_megatron(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 3072)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3072 * 4)
    args.decoder_layers = getattr(args, "decoder_layers", 72)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.activation_fn = getattr(args, "activation_fn", "gelu")
    base_lm_architecture(args)


@register_model_architecture(
    "model_parallel_transformer_lm", "transformer_lm_megatron_11b"
)
def transformer_lm_megatron_11b(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 3072)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 3072 * 6)
    args.decoder_layers = getattr(args, "decoder_layers", 72)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 32)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.activation_fn = getattr(args, "activation_fn", "gelu")
    base_lm_architecture(args)
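The architecture functions above rely on the getattr(args, name, default) idiom: a default is applied only when the checkpoint's args namespace does not already define the attribute. A small self-contained sketch of that pattern (standalone example, not fairseq code):

# Sketch of the getattr-based default pattern used by base_lm_architecture:
# an attribute is only filled in when the (possibly old) args namespace does
# not already carry a value.
from argparse import Namespace

def apply_defaults(args: Namespace) -> Namespace:
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    return args

old_checkpoint_args = Namespace(decoder_layers=72)   # e.g. a Megatron-sized model
apply_defaults(old_checkpoint_args)
print(old_checkpoint_args.decoder_layers)            # 72  (kept, not overwritten)
print(old_checkpoint_args.decoder_embed_dim)         # 512 (filled in as a default)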
PyTorch/NLP/new-Transformer/fairseq/model_parallel/modules/__init__.py
0 → 100644
View file @ c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

from .multihead_attention import ModelParallelMultiheadAttention
from .transformer_layer import (
    ModelParallelTransformerEncoderLayer,
    ModelParallelTransformerDecoderLayer,
)

__all__ = [
    "ModelParallelMultiheadAttention",
    "ModelParallelTransformerEncoderLayer",
    "ModelParallelTransformerDecoderLayer",
]
PyTorch/NLP/new-Transformer/fairseq/model_parallel/modules/multihead_attention.py
0 → 100644
View file @ c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from fairseq import utils
from fairseq.incremental_decoding_utils import with_incremental_state
from fairseq.modules.fairseq_dropout import FairseqDropout


try:
    from fairseq.model_parallel.megatron.mpu import (
        ColumnParallelLinear,
        RowParallelLinear,
        get_cuda_rng_tracker,
        get_model_parallel_world_size,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False


@with_incremental_state
class ModelParallelMultiheadAttention(nn.Module):
    """Model parallel multi-headed attention.

    This performs the multi-headed attention over multiple GPUs.

    See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        self_attention=False,
        encoder_decoder_attention=False,
    ):
        super().__init__()
        if not has_megatron_submodule:
            raise ImportError(
                "\n\nPlease install the megatron submodule:"
                "\n\n  git submodule update --init "
                "fairseq/model_parallel/megatron"
            )
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.model_parallel_size = get_model_parallel_world_size()

        self.num_heads_partition = num_heads // self.model_parallel_size
        assert (
            self.num_heads_partition * self.model_parallel_size == num_heads
        ), "Number of heads must be divisible by model parallel size"

        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )

        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert (
            not self.self_attention or self.qkv_same_dim
        ), "Self-attention requires query, key and value to be of the same size"

        self.k_proj = ColumnParallelLinear(
            self.kdim, embed_dim, bias=bias, gather_output=False
        )
        self.v_proj = ColumnParallelLinear(
            self.vdim, embed_dim, bias=bias, gather_output=False
        )
        self.q_proj = ColumnParallelLinear(
            embed_dim, embed_dim, bias=bias, gather_output=False
        )
        self.out_proj = RowParallelLinear(
            embed_dim, embed_dim, bias=bias, input_is_parallel=True
        )

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        **unused_kwargs,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
        """
        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]

        is_tpu = query.device.type == "xla"

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling

        q = (
            q.contiguous()
            .view(tgt_len, bsz * self.num_heads_partition, self.head_dim)
            .transpose(0, 1)
        )
        if k is not None:
            k = (
                k.contiguous()
                .view(-1, bsz * self.num_heads_partition, self.head_dim)
                .transpose(0, 1)
            )
        if v is not None:
            v = (
                v.contiguous()
                .view(-1, bsz * self.num_heads_partition, self.head_dim)
                .transpose(0, 1)
            )

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads_partition, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(
                    bsz * self.num_heads_partition, -1, self.head_dim
                )
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(
                    bsz * self.num_heads_partition, -1, self.head_dim
                )
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = (
                ModelParallelMultiheadAttention._append_prev_key_padding_mask(
                    key_padding_mask=key_padding_mask,
                    prev_key_padding_mask=prev_key_padding_mask,
                    batch_size=bsz,
                    src_len=k.size(1),
                    static_kv=static_kv,
                )
            )

            saved_state["prev_key"] = k.view(
                bsz, self.num_heads_partition, -1, self.head_dim
            )
            saved_state["prev_value"] = v.view(
                bsz, self.num_heads_partition, -1, self.head_dim
            )
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        src_len = k.size(1)

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        attn_weights = torch.bmm(q, k.transpose(1, 2))

        assert list(attn_weights.size()) == [
            bsz * self.num_heads_partition,
            tgt_len,
            src_len,
        ]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(
                bsz, self.num_heads_partition, tgt_len, src_len
            )
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(
                bsz * self.num_heads_partition, tgt_len, src_len
            )

        attn_weights_float = utils.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights_float.type_as(attn_weights)

        with get_cuda_rng_tracker().fork():
            attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [
            bsz * self.num_heads_partition,
            tgt_len,
            self.head_dim,
        ]
        embed_dim_partition = embed_dim // self.model_parallel_size
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim_partition)
        attn = self.out_proj(attn)
        # return attn_weights None to keep the return type same as single gpu multihead attention
        # This will be deprecated.
        attn_weights: Optional[Tensor] = None

        return attn, attn_weights

    @staticmethod
    def _append_prev_key_padding_mask(
        key_padding_mask: Optional[Tensor],
        prev_key_padding_mask: Optional[Tensor],
        batch_size: int,
        src_len: int,
        static_kv: bool,
    ) -> Optional[Tensor]:
        # saved key padding masks have shape (bsz, seq_len)
        if prev_key_padding_mask is not None and static_kv:
            new_key_padding_mask = prev_key_padding_mask
        elif prev_key_padding_mask is not None and key_padding_mask is not None:
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
            )
        # During incremental decoding, as the padding token enters and
        # leaves the frame, there will be a time when prev or current
        # is None
        elif prev_key_padding_mask is not None:
            filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1))
            if prev_key_padding_mask.is_cuda:
                filler = filler.cuda()
            new_key_padding_mask = torch.cat(
                [prev_key_padding_mask.float(), filler.float()], dim=1
            )
        elif key_padding_mask is not None:
            filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1))
            if key_padding_mask.is_cuda:
                filler = filler.cuda()
            new_key_padding_mask = torch.cat(
                [filler.float(), key_padding_mask.float()], dim=1
            )
        else:
            new_key_padding_mask = prev_key_padding_mask
        return new_key_padding_mask

    def reorder_incremental_state(
        self, incremental_state: Dict[str, Dict[str, Optional[Tensor]]], new_order
    ):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                if input_buffer[k] is not None:
                    input_buffer[k] = input_buffer[k].index_select(0, new_order)
            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
        return incremental_state

    def _get_input_buffer(
        self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
    ) -> Dict[str, Optional[Tensor]]:
        result = self.get_incremental_state(incremental_state, "attn_state")
        if result is not None:
            return result
        else:
            empty_result: Dict[str, Optional[Tensor]] = {}
            return empty_result

    def _set_input_buffer(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        buffer: Dict[str, Optional[Tensor]],
    ):
        return self.set_incremental_state(incremental_state, "attn_state", buffer)
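The shape bookkeeping behind the partitioned attention above can be summarized in a few lines. An illustrative sketch (plain Python arithmetic, no distributed setup; the example sizes are made up):

# Each rank computes num_heads // world_size heads, so its attention output
# has embed_dim // world_size channels; the RowParallelLinear out_proj then
# sums the partial projections across ranks to recover the full embed_dim.
embed_dim, num_heads, world_size = 1024, 16, 4
head_dim = embed_dim // num_heads               # 64, identical on every rank
heads_per_rank = num_heads // world_size        # 4 heads handled locally
partition_dim = heads_per_rank * head_dim       # 256 channels per rank

assert partition_dim == embed_dim // world_size
print(head_dim, heads_per_rank, partition_dim)  # 64 4 256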
PyTorch/NLP/new-Transformer/fairseq/model_parallel/modules/transformer_layer.py
0 → 100644
View file @ c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.model_parallel.modules import ModelParallelMultiheadAttention
from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer


try:
    from fairseq.model_parallel.megatron.mpu import (
        ColumnParallelLinear,
        RowParallelLinear,
    )

    has_megatron_submodule = True
except (ImportError, ModuleNotFoundError):
    has_megatron_submodule = False


class ModelParallelTransformerEncoderLayer(TransformerEncoderLayer):
    """Encoder layer block over multiple GPUs.

    See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
    """

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        if q_noise > 0:
            raise NotImplementedError
        return ColumnParallelLinear(input_dim, output_dim, gather_output=False)

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        if q_noise > 0:
            raise NotImplementedError
        return RowParallelLinear(input_dim, output_dim, input_is_parallel=True)

    def build_self_attention(self, embed_dim, args, **unused_kwargs):
        return ModelParallelMultiheadAttention(
            embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
        )


class ModelParallelTransformerDecoderLayer(TransformerDecoderLayer):
    """Decoder layer block.

    See "Megatron-LM: https://arxiv.org/pdf/1909.08053.pdf" for more details.
    """

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        if q_noise > 0:
            raise NotImplementedError
        return ColumnParallelLinear(input_dim, output_dim, gather_output=False)

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        if q_noise > 0:
            raise NotImplementedError
        return RowParallelLinear(input_dim, output_dim, input_is_parallel=True)

    def build_self_attention(self, embed_dim, args, **unused_kwargs):
        return ModelParallelMultiheadAttention(
            embed_dim=embed_dim,
            num_heads=args.decoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=not getattr(args, "cross_self_attention", False),
        )

    def build_encoder_attention(self, embed_dim, args, **unused_kwargs):
        return ModelParallelMultiheadAttention(
            embed_dim=embed_dim,
            num_heads=args.decoder_attention_heads,
            kdim=getattr(args, "encoder_embed_dim", None),
            vdim=getattr(args, "encoder_embed_dim", None),
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
        )
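Pairing a column-parallel fc1 with a row-parallel fc2, as the layers above do, keeps the intermediate FFN activations sharded across ranks: no gather is needed between the two linears, and only the partial outputs of fc2 must be summed. An illustrative single-process sketch of that identity (plain PyTorch standing in for two model-parallel ranks; not Megatron code):

# fc1 is split by output columns, fc2 by the matching input rows; summing
# the per-"rank" partial results (an all-reduce in the real setup) recovers
# the unpartitioned FFN output exactly. Activation omitted for brevity.
import torch

torch.manual_seed(0)
d_model, d_ffn, ranks = 8, 32, 2
x = torch.randn(3, d_model)
w1 = torch.randn(d_model, d_ffn)   # fc1 weight
w2 = torch.randn(d_ffn, d_model)   # fc2 weight

ref = x @ w1 @ w2                  # reference: unpartitioned FFN

shard = d_ffn // ranks
partial = [
    (x @ w1[:, r * shard:(r + 1) * shard]) @ w2[r * shard:(r + 1) * shard, :]
    for r in range(ranks)
]
assert torch.allclose(ref, sum(partial), atol=1e-5)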