Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
dcuai
dlexamples
Commits
c0f05c10
Commit
c0f05c10
authored
Nov 29, 2022
by
hepj
Browse files
更新transformer代码
parent
c056df78
Changes
321
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
0 additions
and
3177 deletions
+0
-3177
PyTorch/NLP/Transformer/fairseq/ddp_trainer.py
PyTorch/NLP/Transformer/fairseq/ddp_trainer.py
+0
-305
PyTorch/NLP/Transformer/fairseq/distributed_utils.py
PyTorch/NLP/Transformer/fairseq/distributed_utils.py
+0
-111
PyTorch/NLP/Transformer/fairseq/log_helper.py
PyTorch/NLP/Transformer/fairseq/log_helper.py
+0
-204
PyTorch/NLP/Transformer/fairseq/meters.py
PyTorch/NLP/Transformer/fairseq/meters.py
+0
-87
PyTorch/NLP/Transformer/fairseq/models/__init__.py
PyTorch/NLP/Transformer/fairseq/models/__init__.py
+0
-55
PyTorch/NLP/Transformer/fairseq/models/fairseq_incremental_decoder.py
...Transformer/fairseq/models/fairseq_incremental_decoder.py
+0
-42
PyTorch/NLP/Transformer/fairseq/models/fused_layer_norm.py
PyTorch/NLP/Transformer/fairseq/models/fused_layer_norm.py
+0
-159
PyTorch/NLP/Transformer/fairseq/models/transformer.py
PyTorch/NLP/Transformer/fairseq/models/transformer.py
+0
-621
PyTorch/NLP/Transformer/fairseq/modules/__init__.py
PyTorch/NLP/Transformer/fairseq/modules/__init__.py
+0
-18
PyTorch/NLP/Transformer/fairseq/modules/learned_positional_embedding.py
...ansformer/fairseq/modules/learned_positional_embedding.py
+0
-31
PyTorch/NLP/Transformer/fairseq/modules/multihead_attention.py
...ch/NLP/Transformer/fairseq/modules/multihead_attention.py
+0
-460
PyTorch/NLP/Transformer/fairseq/modules/strided_batched_gemm/strided_batched_gemm.cpp
...seq/modules/strided_batched_gemm/strided_batched_gemm.cpp
+0
-61
PyTorch/NLP/Transformer/fairseq/modules/strided_batched_gemm/strided_batched_gemm_cuda.cu
...modules/strided_batched_gemm/strided_batched_gemm_cuda.cu
+0
-345
PyTorch/NLP/Transformer/fairseq/optim/__init__.py
PyTorch/NLP/Transformer/fairseq/optim/__init__.py
+0
-46
PyTorch/NLP/Transformer/fairseq/optim/adam.py
PyTorch/NLP/Transformer/fairseq/optim/adam.py
+0
-54
PyTorch/NLP/Transformer/fairseq/optim/fairseq_optimizer.py
PyTorch/NLP/Transformer/fairseq/optim/fairseq_optimizer.py
+0
-94
PyTorch/NLP/Transformer/fairseq/optim/lr_scheduler/__init__.py
...ch/NLP/Transformer/fairseq/optim/lr_scheduler/__init__.py
+0
-39
PyTorch/NLP/Transformer/fairseq/optim/lr_scheduler/fixed_schedule.py
.../Transformer/fairseq/optim/lr_scheduler/fixed_schedule.py
+0
-57
PyTorch/NLP/Transformer/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py
...former/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py
+0
-46
PyTorch/NLP/Transformer/fairseq/options.py
PyTorch/NLP/Transformer/fairseq/options.py
+0
-342
No files found.
Too many changes to show.
To preserve performance only
321 of 321+
files are displayed.
Plain diff
Email patch
PyTorch/NLP/Transformer/fairseq/ddp_trainer.py
deleted
100644 → 0
View file @
c056df78
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Train a network across multiple GPUs.
"""
import
math
from
collections
import
defaultdict
from
itertools
import
chain
import
torch
import
torch.nn.functional
as
F
from
torch.cuda
import
amp
from
apex.parallel
import
DistributedDataParallel
as
DDP
from
fairseq
import
distributed_utils
,
optim
,
utils
from
fairseq.optim
import
lr_scheduler
from
fairseq.meters
import
TimeMeter
,
AverageMeter
from
fairseq.criterions
import
CRITERION_REGISTRY
import
dllogger
as
DLLogger
class DDPTrainer():
    """Main class for data parallel training.

    This class supports data parallel training, where multiple workers each
    have a full model replica and gradients are accumulated synchronously via
    torch.distributed.all_reduce.
    """

    def __init__(self, args, model):
        """Move the model to the GPU and build criterion, optimizer and scaler.

        Args:
            args: parsed options. This class reads ``criterion``, ``amp``,
                ``seed``, ``distributed_world_size`` and ``distributed_rank``;
                the optimizer / scheduler builders receive the whole object.
            model: the network to train; it is moved to CUDA and, when
                ``distributed_world_size > 1``, wrapped in ``DDP``.

        Raises:
            NotImplementedError: if CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise NotImplementedError('Training on CPU is not supported')

        self.args = args

        self.model = model.cuda()
        self.criterion = CRITERION_REGISTRY[args.criterion](args).cuda()
        self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
        self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
        self.scaler = amp.GradScaler(enabled=self.args.amp, init_scale=2**15)

        if self.args.distributed_world_size > 1:
            self.model = DDP(model)

        # Per-step statistics accumulated between two parameter updates.
        self._buffered_stats = defaultdict(list)
        self._num_updates = 0
        self._optim_history = None
        self.throughput_meter = TimeMeter()
        self.avg_loss_meter = AverageMeter()

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file.

        Only the master worker (see ``distributed_utils.is_master``) writes.
        """
        if distributed_utils.is_master(self.args):  # only save one checkpoint
            utils.save_state(
                filename, self.args, self.get_model(), self.criterion, self.optimizer,
                self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
            )

    def load_checkpoint(self, filename, load_optim=True):
        """Load all training state from a checkpoint file.

        Args:
            filename: checkpoint path handed to ``utils.load_model_state``.
            load_optim: when true, also restore the optimizer history, the
                scheduler / optimizer state (if their recorded class names
                match the current ones) and the update counter.

        Returns:
            The ``extra_state`` object stored in the checkpoint.
        """
        extra_state, optim_history, last_optim_state = \
            utils.load_model_state(filename, self.get_model())

        if last_optim_state is not None:
            # rebuild the lr scheduler after loading the model; the optimizer
            # itself is deliberately kept (its rebuild is disabled upstream)
            self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)

            if load_optim:
                self._optim_history = optim_history
                # only reload optimizer and lr_scheduler if they match
                last_optim = self._optim_history[-1]
                if last_optim['criterion_name'] == self.criterion.__class__.__name__:
                    self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
                    if last_optim['optimizer_name'] == self.optimizer.__class__.__name__:
                        self.optimizer.load_state_dict(last_optim_state)

                self._num_updates = last_optim['num_updates']

        return extra_state

    def train_step(self, sample, update_params=True, last_step=False):
        """Do forward, backward and parameter update.

        Args:
            sample: batch dictionary (``net_input``, ``target``, ``ntokens``)
                or an empty / ``None`` value when this worker has no data.
            update_params: when false, only accumulate gradients and stats.
            last_step: when true, all-reduce is disabled on a ``DDP`` model
                and no parameter update is performed.
        """
        # Set seed based on args.seed and the update number so that we get
        # reproducible results when resuming from checkpoints
        seed = self.args.seed + self.get_num_updates()
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        self.model.train()
        if isinstance(self.model, DDP):
            if last_step:
                self.model.disable_allreduce()
            else:
                self.model.enable_allreduce()

        # forward and backward pass
        sample = self._prepare_sample(sample)
        loss, oom_fwd = self._forward(sample)

        # If this is a last batch forward pass is skipped on some workers
        # Batch with sample_size 0 is not accounted for in weighted loss
        logging_output = {
            'ntokens': sample['ntokens'] if sample is not None else 0,
            'nsentences': sample['target'].size(0) if sample is not None else 0,
            'loss': utils.item(loss.data) if loss is not None else 0,
        }
        sample_size = sample['ntokens'] if sample is not None else 0
        oom_bwd = self._backward(loss)

        # buffer stats and logging outputs
        self._buffered_stats['sample_sizes'].append(sample_size)
        self._buffered_stats['logging_outputs'].append(logging_output)
        self._buffered_stats['ooms_fwd'].append(oom_fwd)
        self._buffered_stats['ooms_bwd'].append(oom_bwd)

        # update parameters
        if update_params and not last_step:
            # gather logging outputs from all replicas
            sample_sizes = self._buffered_stats['sample_sizes']
            logging_outputs = self._buffered_stats['logging_outputs']
            ooms_fwd = self._buffered_stats['ooms_fwd']
            ooms_bwd = self._buffered_stats['ooms_bwd']
            if self.args.distributed_world_size > 1:
                gathered = distributed_utils.all_gather_list(
                    (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
                )
                # flatten the per-worker lists of each statistic
                sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = (
                    list(chain.from_iterable(per_worker))
                    for per_worker in zip(*gathered)
                )
            ooms_fwd = sum(ooms_fwd)
            ooms_bwd = sum(ooms_bwd)
            ooms = ooms_fwd + ooms_bwd  # this is always <= distributed_world_size

            if ooms == self.args.distributed_world_size:
                print('| WARNING: OOM in all workers, skipping batch')
                self.zero_grad()
                # Drop the stats of the skipped batch; otherwise its OOM flags
                # would be summed again on the next update and trigger
                # another skip.
                self.clear_buffered_stats()
                return

            # aggregate stats and logging outputs
            grad_denom = sum(sample_sizes)
            for p in self.model.parameters():
                if p.requires_grad and p.grad is not None:
                    p.grad /= grad_denom

            self._opt()

            # Handle logging
            ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
            self.throughput_meter.update(ntokens)
            info_log_data = {
                'tokens/s': self.throughput_meter.avg,
                'tokens': ntokens,
                # per-token loss divided by ln(2)
                'loss': sum(log.get('loss', 0) for log in logging_outputs) / ntokens / math.log(2)
            }
            self.avg_loss_meter.update(info_log_data['loss'])
            debug_log_data = {
                'batch_size': sum(log.get('nsentences', 0) for log in logging_outputs),
                'lr': self.get_lr(),
                'grad_denom': grad_denom,
                'updates': 1
            }

            DLLogger.log(step=self._num_updates, data=info_log_data, verbosity=0)
            DLLogger.log(step=self._num_updates, data=debug_log_data, verbosity=1)

            self.clear_buffered_stats()

    def _warn_oom(self):
        """Print the out-of-memory warning on this worker.

        ``force=True`` is understood only by the ``print`` replacements that
        ``distributed_utils.suppress_output`` installs. With the stock builtin
        it raises ``TypeError`` before anything is written, so fall back to a
        plain call in that case.
        """
        message = '| WARNING: ran out of memory in worker {}, skipping batch'.format(
            self.args.distributed_rank)
        try:
            print(message, force=True)
        except TypeError:
            print(message)

    def _forward(self, sample):
        """Run the model and criterion on ``sample``.

        Returns:
            ``(loss, oom)``: ``loss`` is ``None`` when ``sample`` is ``None``
            or the pass ran out of memory; ``oom`` is 1 in the latter case,
            otherwise 0.

        Raises:
            RuntimeError: any runtime error that is not an out-of-memory one.
        """
        loss = None
        oom = 0
        try:
            if sample is not None:
                with amp.autocast(enabled=self.args.amp):
                    # calculate loss and sample size
                    logits, _ = self.model(**sample['net_input'])
                target = sample['target']
                probs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                loss = self.criterion(probs, target)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                self._warn_oom()
                oom = 1
                loss = None
            else:
                raise

        return loss, oom

    def _backward(self, loss):
        """Back-propagate the scaled ``loss``.

        Returns:
            1 if the backward pass ran out of memory (gradients are then
            zeroed), otherwise 0. A ``None`` loss is a no-op returning 0.

        Raises:
            RuntimeError: any runtime error that is not an out-of-memory one.
        """
        oom = 0
        if loss is not None:
            try:
                self.scaler.scale(loss).backward()
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    self._warn_oom()
                    oom = 1
                    self.zero_grad()
                else:
                    raise

        return oom

    def _opt(self):
        """Take one optimizer step through the grad scaler and advance the LR."""
        # take an optimization step
        self.scaler.step(self.optimizer.optimizer)
        self.scaler.update()
        self.zero_grad()
        self._num_updates += 1

        # update learning rate
        self.lr_scheduler.step_update(self._num_updates)

    def valid_step(self, sample):
        """Do forward pass in evaluation mode.

        Returns:
            The summed loss of all workers divided by the total token count
            and by ln(2).
        """
        self.model.eval()
        # forward pass
        sample = self._prepare_sample(sample)
        with torch.no_grad():
            loss, oom_fwd = self._forward(sample)
        logging_output = {
            'ntokens': sample['ntokens'] if sample is not None else 0,
            'nsentences': sample['target'].size(0) if sample is not None else 0,
        }
        loss = loss.item() if loss is not None else 0
        assert not oom_fwd, 'Ran out of memory during validation'

        # gather logging outputs from all GPUs
        if self.args.distributed_world_size > 1:
            losses, logging_outputs = zip(*distributed_utils.all_gather_list(
                (loss, logging_output)
            ))
        else:
            losses = [loss]
            logging_outputs = [logging_output]

        weight = sum(log.get('ntokens', 0) for log in logging_outputs)
        scaled_loss = sum(losses) / weight / math.log(2)

        return scaled_loss

    def dummy_train_step(self, dummy_batch):
        """Dummy training step for warming caching allocator."""
        self.train_step(dummy_batch, update_params=False)
        self.zero_grad()
        self.clear_buffered_stats()

    def zero_grad(self):
        """Reset the gradients held by the optimizer."""
        self.optimizer.zero_grad()

    def clear_buffered_stats(self):
        """Discard the statistics buffered since the last parameter update."""
        self._buffered_stats.clear()

    def lr_step(self, epoch, val_loss=None):
        """Adjust the learning rate based on the validation loss."""
        return self.lr_scheduler.step(epoch, val_loss)

    def lr_step_update(self, num_updates):
        """Update the learning rate after each update."""
        return self.lr_scheduler.step_update(num_updates)

    def get_lr(self):
        """Get the current learning rate."""
        return self.optimizer.get_lr()

    def get_throughput_meter(self):
        """Get the throughput meter"""
        return self.throughput_meter

    def get_model(self):
        """Get the model replica."""
        return self.model.module if isinstance(self.model, DDP) else self.model

    def get_num_updates(self):
        """Get the number of parameters updates."""
        return self._num_updates

    def _prepare_sample(self, sample):
        """Move ``sample`` to CUDA; a falsy sample is mapped to ``None``."""
        if not sample:
            return None
        return utils.move_to_cuda(sample)
PyTorch/NLP/Transformer/fairseq/distributed_utils.py
deleted
100644 → 0
View file @
c056df78
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
pickle
import
os
import
socket
import
torch.distributed
from
fairseq
import
utils
def is_master(args):
    """Return True when ``args.distributed_rank`` is 0."""
    rank = args.distributed_rank
    return rank == 0
def distributed_init(args):
    """Initialise ``torch.distributed`` from environment variables.

    Reads ``MASTER_ADDR``, ``MASTER_PORT``, ``WORLD_SIZE`` and ``RANK`` from
    the environment, creates the process group with ``init_method='env://'``
    and then overwrites ``args.distributed_world_size``,
    ``args.distributed_rank`` and ``args.device_id`` (``LOCAL_RANK`` from the
    environment, else ``args.local_rank``). Finally ``suppress_output`` is
    applied.

    Returns:
        The rank reported by ``torch.distributed.get_rank()``.

    Raises:
        ValueError: if ``args.distributed_world_size`` is 1.
        KeyError: if one of the four environment variables is missing.
    """
    if args.distributed_world_size == 1:
        raise ValueError('Cannot initialize distributed with distributed_world_size=1')

    env = os.environ

    banner = '| distributed init (rank {}): {}'.format(
        args.distributed_rank, args.distributed_init_method)
    print(banner, flush=True)

    env_summary = (
        "| distributed env init. MASTER_ADDR: " + env['MASTER_ADDR']
        + ", MASTER_PORT: " + env['MASTER_PORT']
        + ", WORLD_SIZE: " + env['WORLD_SIZE']
        + ", RANK: " + env['RANK']
    )
    print(env_summary, flush=True)

    torch.distributed.init_process_group(
        backend=args.distributed_backend,
        init_method='env://',
    )
    print("| distributed init done!", flush=True)

    # The environment, not the command line, is authoritative from here on.
    args.distributed_world_size = int(env['WORLD_SIZE'])
    args.distributed_rank = torch.distributed.get_rank()
    args.device_id = int(env.get('LOCAL_RANK', args.local_rank))

    suppress_output(args)
    host = socket.gethostname()
    print('| initialized host {} as rank {} and device id {}'
          .format(host, args.distributed_rank, args.device_id))

    return args.distributed_rank
def suppress_output(main_args):
    """Suppress printing on the current device. Force printing with `force=True`.

    Replaces ``builtins.print`` process-wide. On the master worker the
    replacement prints everything and merely discards the ``force`` keyword;
    on every other worker it prints only calls made with a truthy ``force``.
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print_master(*args, **kwargs):
        # 'force' is not a keyword of the real print, so strip it first
        kwargs.pop('force', None)
        builtin_print(*args, **kwargs)

    def print(*args, **kwargs):
        # silent unless the caller explicitly forces the output
        if kwargs.pop('force', False):
            builtin_print(*args, **kwargs)

    __builtin__.print = print_master if is_master(main_args) else print
def all_gather_list(data, max_size=16384):
    """Gathers arbitrary data from all nodes into a list.

    ``data`` is pickled into a CUDA byte buffer of ``max_size`` bytes whose
    first two bytes hold the payload length in base 255. The buffers are
    cached as attributes on this function and re-allocated only when
    ``max_size`` changes.

    Returns:
        A list with one unpickled object per rank.

    Raises:
        ValueError: if the pickled payload plus the 2-byte header exceeds
            ``max_size``.
    """
    world_size = torch.distributed.get_world_size()

    cached = getattr(all_gather_list, '_in_buffer', None)
    if cached is None or len(cached) != max_size:
        all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size)
        all_gather_list._out_buffers = [
            torch.cuda.ByteTensor(max_size) for _ in range(world_size)
        ]
    send_buf = all_gather_list._in_buffer
    recv_bufs = all_gather_list._out_buffers

    payload = pickle.dumps(data)
    payload_len = len(payload)
    framed_len = payload_len + 2
    if framed_len > max_size:
        raise ValueError('encoded data exceeds max_size: {}'.format(framed_len))
    assert max_size < 255*256
    # this encoding works for max_size < 65k
    high, low = divmod(payload_len, 255)
    send_buf[0] = high
    send_buf[1] = low
    send_buf[2:framed_len] = torch.ByteTensor(list(payload))

    torch.distributed.all_gather(recv_bufs, send_buf.cuda())

    def _decode(buf):
        size = (255 * utils.item(buf[0])) + utils.item(buf[1])
        # NOTE(review): unpickling trusts every peer in the process group
        return pickle.loads(bytes(buf[2:size+2].tolist()))

    return [_decode(recv_bufs[rank]) for rank in range(world_size)]
PyTorch/NLP/Transformer/fairseq/log_helper.py
deleted
100644 → 0
View file @
c056df78
import
os
import
atexit
import
time
import
itertools
from
collections
import
OrderedDict
import
dllogger
from
dllogger
import
Backend
,
JSONStreamBackend
from
tensorboardX
import
SummaryWriter
class AverageMeter():
    """Running weighted mean of the values passed to ``update``."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything recorded so far."""
        self.updated = False
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, value):
        """Record a sample.

        ``value`` is either a bare number (weight 1) or a tuple / list whose
        first two items are ``(number, weight)``.
        """
        self.updated = True
        if isinstance(value, (tuple, list)):
            sample, weight = value[0], value[1]
        else:
            sample, weight = value, 1
        self.sum += sample * weight
        self.count += weight
        self.avg = self.sum / self.count

    @property
    def value(self):
        """The current mean."""
        return self.avg
class PerformanceMeter():
    """Counts events and reports them per second of wall-clock time."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the counter and restart the clock."""
        self.updated = False
        self.start = time.time()
        self.n = 0

    def update(self, val=1):
        """Add ``val`` events to the counter."""
        self.updated = True
        self.n += val

    @property
    def value(self):
        """Events counted since the last reset divided by the elapsed time."""
        return self.n / self.elapsed_time

    @property
    def elapsed_time(self):
        """Seconds elapsed since the last reset."""
        now = time.time()
        return now - self.start
# Aggregation kind -> meter class; AggregatorBackend instantiates these by name.
METRIC = dict(
    average=AverageMeter,
    performance=PerformanceMeter,
)
class AggregatorBackend(Backend):
    """Logging backend that aggregates metrics and prints one line per flush."""

    def __init__(self, verbosity, agg_dict):
        """Create the meters.

        Args:
            verbosity: forwarded to ``Backend.__init__``.
            agg_dict: mapping of metric name to one ``METRIC`` key or a
                tuple / list of them.
        """
        super().__init__(verbosity=verbosity)
        meters = OrderedDict()
        for metric_name, kinds in agg_dict.items():
            if not isinstance(kinds, (tuple, list)):
                kinds = (kinds,)
            meters[metric_name] = [METRIC[kind]() for kind in kinds]
        self.metrics = meters
        # True while nothing has been logged since the last flush
        self.metrics.flushed = True
        self.step = 0
        self.epoch = 0
        self.start_time = time.time()

    @property
    def log_level(self):
        return self._log_level

    def metadata(self, timestamp, elapsedtime, metric, metadata):
        """Metadata records are ignored by this backend."""
        pass

    def _reset_perf_meter(self, name):
        """Reset the ``PerformanceMeter`` instances tracked under ``name``."""
        for meter in self.metrics[name]:
            if isinstance(meter, PerformanceMeter):
                meter.reset()

    def reset_perf_meters(self):
        """Reset every ``PerformanceMeter`` of every metric."""
        for name in self.metrics:
            self._reset_perf_meter(name)

    def log(self, timestamp, elapsedtime, step, data):
        """Feed the tracked entries of ``data`` into their meters."""
        self.step = step
        if 'epoch' in data:
            self.epoch = data['epoch']
        for key, val in data.items():
            meters = self.metrics.get(key)
            if meters is None:
                continue
            self.metrics.flushed = False
            for meter in meters:
                meter.update(val)

    def flush(self):
        """Print the aggregated values and reset every meter that was updated."""
        if self.metrics.flushed:
            return
        pieces = ['Transformer | epoch {} | step {} |'.format(self.epoch, self.step)]
        for name, aggregators in self.metrics.items():
            for agg in aggregators:
                if not agg.updated:
                    continue
                if isinstance(agg, AverageMeter):
                    label = 'avg ' + name
                elif isinstance(agg, PerformanceMeter):
                    label = name + '/s'

                pieces.append(label + ' {:.3f} |'.format(agg.value))
                agg.reset()

        pieces.append('walltime {:.3f} |'.format(time.time() - self.start_time))
        self.metrics.flushed = True
        print(''.join(pieces))
class TensorBoardBackend(Backend):
    """Logging backend that forwards scalar metrics to a ``SummaryWriter``."""

    def __init__(self, verbosity, log_dir):
        """Open a writer in ``<log_dir>/TB_summary`` and close it at exit."""
        super().__init__(verbosity=verbosity)
        summary_dir = os.path.join(log_dir, 'TB_summary')
        self.summary_writer = SummaryWriter(
            log_dir=summary_dir,
            flush_secs=120,
            max_queue=200,
        )
        atexit.register(self.summary_writer.close)

    @property
    def log_level(self):
        return self._log_level

    def metadata(self, timestamp, elapsedtime, metric, metadata):
        """Metadata records are ignored by this backend."""
        pass

    def log(self, timestamp, elapsedtime, step, data):
        """Write every item of ``data`` as a scalar; non-integer steps are skipped."""
        if isinstance(step, int):
            for tag, scalar in data.items():
                self.summary_writer.add_scalar(tag, scalar, step)

    def flush(self):
        """Nothing to do here."""
        pass
def setup_logger(args):
    """Initialise ``dllogger`` and log the run parameters.

    Creates ``args.save_dir``, picks a log file name derived from
    ``args.stat_file`` that does not exist yet, and registers the JSON,
    aggregator and TensorBoard backends when the run is not distributed or
    this is rank 0 (no backends otherwise). Every attribute of ``args`` and
    the framework environment variables are then logged under the
    ``'PARAMETER'`` step, followed by the metric metadata.
    """
    aggregator_dict = OrderedDict([
        ('loss', 'average'),
        ('weighted_loss', 'average'),
        ('tokens', ('average', 'performance')),
        ('updates', 'performance'),
        ('gnorm', 'average')
    ])
    os.makedirs(args.save_dir, exist_ok=True)
    log_path = os.path.join(args.save_dir, args.stat_file)

    if os.path.exists(log_path):
        # Try <stem>_<i>.<ext> (or <name>.<i> without an extension) until a
        # free name is found.
        s_fname = args.stat_file.split('.')
        has_extension = len(s_fname) > 1
        for i in itertools.count():
            if has_extension:
                fname = '.'.join(s_fname[:-1]) + f'_{i}.' + s_fname[-1]
            else:
                fname = args.stat_file + f'.{i}'
            log_path = os.path.join(args.save_dir, fname)
            if not os.path.exists(log_path):
                break

    if not args.distributed_world_size > 1 or args.distributed_rank == 0:
        backends = [
            JSONStreamBackend(verbosity=1, filename=log_path),
            AggregatorBackend(verbosity=0, agg_dict=aggregator_dict),
            TensorBoardBackend(verbosity=1, log_dir=args.save_dir),
        ]
    else:
        backends = []
    dllogger.init(backends=backends)

    for k, v in vars(args).items():
        dllogger.log(step='PARAMETER', data={k: v}, verbosity=0)

    container_setup_info = get_framework_env_vars()
    dllogger.log(step='PARAMETER', data=container_setup_info, verbosity=0)

    metric_metadata = (
        ('loss', {'unit': 'nat', 'GOAL': 'MINIMIZE', 'STAGE': 'TRAIN'}),
        ('val_loss', {'unit': 'nat', 'GOAL': 'MINIMIZE', 'STAGE': 'VAL'}),
        ('speed', {'unit': 'tokens/s', 'format': ':.3f', 'GOAL': 'MAXIMIZE', 'STAGE': 'TRAIN'}),
        ('accuracy', {'unit': 'bleu', 'format': ':.2f', 'GOAL': 'MAXIMIZE', 'STAGE': 'VAL'}),
    )
    for metric_name, meta in metric_metadata:
        dllogger.metadata(metric_name, meta)
def get_framework_env_vars():
    """Return a dict of framework-related environment variables.

    Each key is the variable name; the value is ``os.environ.get(name)``,
    i.e. ``None`` when the variable is not set.
    """
    names = (
        'NVIDIA_PYTORCH_VERSION',
        'PYTORCH_VERSION',
        'CUBLAS_VERSION',
        'NCCL_VERSION',
        'CUDA_DRIVER_VERSION',
        'CUDNN_VERSION',
        'CUDA_VERSION',
        'NVIDIA_PIPELINE_ID',
        'NVIDIA_BUILD_ID',
        'NVIDIA_TF32_OVERRIDE',
    )
    return {name: os.environ.get(name) for name in names}
def reset_perf_meters():
    """Reset the performance meters of every registered ``AggregatorBackend``."""
    aggregating = (
        backend for backend in dllogger.GLOBAL_LOGGER.backends
        if isinstance(backend, AggregatorBackend)
    )
    for backend in aggregating:
        backend.reset_perf_meters()
PyTorch/NLP/Transformer/fairseq/meters.py
deleted
100644 → 0
View file @
c056df78
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import
time
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the last value, the weighted sum, the count and the average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running average."""
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count
class TimeMeter(object):
    """Computes the average occurrence of some event per second"""

    def __init__(self, init=0):
        self.reset(init)

    def reset(self, init=0):
        """Restart the clock; ``init`` is added to every ``elapsed_time`` reading."""
        self.init = init
        self.start = time.time()
        self.n = 0
        self.last_update = time.time()

    def update(self, val=1):
        """Count ``val`` more events and remember when this happened."""
        self.n += val
        self.last_update = time.time()

    @property
    def avg(self):
        """Events per second of ``elapsed_time``."""
        seconds = self.elapsed_time
        return self.n / seconds

    @property
    def elapsed_time(self):
        """``init`` plus the seconds passed since the last reset."""
        running = time.time() - self.start
        return self.init + running

    @property
    def u_avg(self):
        """Events per second between the last reset and the last update."""
        active = self.last_update - self.start
        return self.n / active
class StopwatchMeter(object):
    """Computes the sum/avg duration of some event in seconds"""

    def __init__(self):
        # reset() initialises every attribute, including ``intervals``
        self.reset()

    def start(self):
        """Start timing an interval."""
        self.start_time = time.time()

    def stop(self, n=1):
        """Finish the running interval, if any.

        The interval length is appended to ``intervals`` and added to
        ``sum``; ``n`` is added to the event count. Calling ``stop`` without
        a preceding ``start`` does nothing.
        """
        if self.start_time is not None:
            delta = time.time() - self.start_time
            self.intervals.append(delta)
            self.sum += delta
            self.n += n
            self.start_time = None

    def reset(self):
        """Discard all recorded intervals and any running one."""
        self.sum = 0
        self.n = 0
        self.start_time = None
        self.intervals = []

    @property
    def avg(self):
        """Mean duration per counted event; 0 while nothing has been counted."""
        if not self.n:
            return 0
        return self.sum / self.n

    def p(self, i):
        """Return the ``i``-th percentile (``i <= 100``) of the recorded intervals.

        ``p(100)`` yields the longest interval: the computed index is clamped
        to the last element instead of running one past the end.

        Raises:
            IndexError: if no interval has been recorded.
        """
        assert i <= 100
        ordered = sorted(self.intervals)
        idx = min(int(len(ordered) * i / 100), len(ordered) - 1)
        return ordered[idx]
PyTorch/NLP/Transformer/fairseq/models/__init__.py
deleted
100644 → 0
View file @
c056df78
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import
importlib
import
os
from
.fairseq_incremental_decoder
import
FairseqIncrementalDecoder
# noqa: F401
# model name -> model class (filled by ``register_model``)
MODEL_REGISTRY = dict()
# architecture name -> model class (filled by ``register_model_architecture``)
ARCH_MODEL_REGISTRY = dict()
# architecture name -> registered configuration callable
ARCH_CONFIG_REGISTRY = dict()
def build_model(args):
    """Build the model registered for ``args.arch``.

    Looks up the model class in ``ARCH_MODEL_REGISTRY`` and returns the
    result of its ``build_model(args)``.

    Raises:
        KeyError: if ``args.arch`` has not been registered.
    """
    model_cls = ARCH_MODEL_REGISTRY[args.arch]
    return model_cls.build_model(args)
def register_model(name):
    """Decorator to register a new model (e.g., LSTM).

    The decorated class is stored in ``MODEL_REGISTRY`` under ``name`` and
    returned unchanged.

    Raises:
        ValueError: (when the decorator is applied) if ``name`` is already
            registered.
    """

    def register_model_cls(cls):
        already_known = name in MODEL_REGISTRY
        if already_known:
            raise ValueError('Cannot register duplicate model ({})'.format(name))
        MODEL_REGISTRY[name] = cls
        return cls

    return register_model_cls
def register_model_architecture(model_name, arch_name):
    """Decorator to register a new model architecture (e.g., lstm_luong_wmt_en_de).

    On success the model class registered as ``model_name`` is stored in
    ``ARCH_MODEL_REGISTRY`` and the decorated callable in
    ``ARCH_CONFIG_REGISTRY``, both under ``arch_name``.

    Raises:
        ValueError: (when the decorator is applied) if ``model_name`` is
            unknown, ``arch_name`` is already registered, or the decorated
            object is not callable.
    """

    def register_model_arch_fn(fn):
        if model_name not in MODEL_REGISTRY:
            message = 'Cannot register model architecture for unknown model type ({})'
            raise ValueError(message.format(model_name))
        if arch_name in ARCH_MODEL_REGISTRY:
            message = 'Cannot register duplicate model architecture ({})'
            raise ValueError(message.format(arch_name))
        if not callable(fn):
            message = 'Model architecture must be callable ({})'
            raise ValueError(message.format(arch_name))

        model_cls = MODEL_REGISTRY[model_name]
        ARCH_MODEL_REGISTRY[arch_name] = model_cls
        ARCH_CONFIG_REGISTRY[arch_name] = fn
        return fn

    return register_model_arch_fn
# Import every public Python file that sits next to this one, so that the
# modules are loaded whenever the package is imported.
for file in os.listdir(os.path.dirname(__file__)):
    # skip non-Python entries and private / dunder files
    if file.startswith('_') or not file.endswith('.py'):
        continue
    module = file[:file.find('.py')]
    importlib.import_module('fairseq.models.' + module)
PyTorch/NLP/Transformer/fairseq/models/fairseq_incremental_decoder.py
deleted
100644 → 0
View file @
c056df78
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import
torch.nn
as
nn
class FairseqIncrementalDecoder(nn.Module):
    """Base class for incremental decoders."""

    def __init__(self):
        super().__init__()

    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
        """Must be provided by subclasses."""
        raise NotImplementedError

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder incremental state.

        This should be called when the order of the input has changed from the
        previous time step. A typical use case is beam search, where the input
        order changes between time steps based on the selection of beams.

        The call is forwarded to every sub-module (other than this one) that
        defines ``reorder_incremental_state``.
        """

        def apply_reorder_incremental_state(module):
            if module == self:
                return
            if not hasattr(module, 'reorder_incremental_state'):
                return
            module.reorder_incremental_state(incremental_state, new_order)

        self.apply(apply_reorder_incremental_state)

    def set_beam_size(self, beam_size):
        """Sets the beam size in the decoder and all children.

        Nothing happens when ``beam_size`` equals the value set last time.
        """
        if getattr(self, '_beam_size', -1) == beam_size:
            return

        def apply_set_beam_size(module):
            if module == self:
                return
            if hasattr(module, 'set_beam_size'):
                module.set_beam_size(beam_size)

        self.apply(apply_set_beam_size)
        self._beam_size = beam_size
PyTorch/NLP/Transformer/fairseq/models/fused_layer_norm.py
deleted
100644 → 0
View file @
c056df78
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
torch
import
numbers
from
torch.nn.parameter
import
Parameter
from
torch.nn
import
init
import
fused_layer_norm_cuda
class FusedLayerNormAffineFunction(torch.autograd.Function):
    """Autograd function that delegates affine layer norm to ``fused_layer_norm_cuda``.

    NOTE(review): this is the instance-style ``autograd.Function`` API
    (state kept on ``self``, non-static ``forward`` / ``backward``) --
    confirm that the targeted PyTorch version still accepts it.
    """

    def __init__(self, normalized_shape, eps=1e-6):
        self.normalized_shape = normalized_shape
        self.eps = eps

    def forward(self, input, weight, bias):
        """Run ``forward_affine`` on contiguous copies and save them for backward."""
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()
        output, mean, invvar = fused_layer_norm_cuda.forward_affine(
            input_, self.normalized_shape, weight_, bias_, self.eps)
        self.save_for_backward(input_, weight_, bias_, mean, invvar)
        return output

    def backward(self, grad_output):
        """Return the gradients for ``input``, ``weight`` and ``bias``."""
        input_, weight_, bias_, mean, invvar = self.saved_tensors
        grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine(
            grad_output.contiguous(), mean, invvar,
            input_, self.normalized_shape,
            weight_, bias_, self.eps)
        return grad_input, grad_weight, grad_bias
class FusedLayerNormFunction(torch.autograd.Function):
    """Autograd function that delegates non-affine layer norm to ``fused_layer_norm_cuda``.

    NOTE(review): this is the instance-style ``autograd.Function`` API
    (state kept on ``self``, non-static ``forward`` / ``backward``) --
    confirm that the targeted PyTorch version still accepts it.
    """

    def __init__(self, normalized_shape, eps=1e-6):
        self.normalized_shape = normalized_shape
        self.eps = eps

    def forward(self, input):
        """Run the fused forward on a contiguous copy and save it for backward."""
        input_ = input.contiguous()
        output, mean, invvar = fused_layer_norm_cuda.forward(
            input_, self.normalized_shape, self.eps)
        self.save_for_backward(input_, mean, invvar)
        return output

    def backward(self, grad_output):
        """Return the gradient with respect to ``input``."""
        input_, mean, invvar = self.saved_tensors
        grad_input = fused_layer_norm_cuda.backward(
            grad_output.contiguous(), mean, invvar,
            input_, self.normalized_shape,
            self.eps)
        return grad_input
def fused_layer_norm_affine(input, normalized_shape, weight, bias, eps=1e-6):
    """Apply ``FusedLayerNormAffineFunction`` to ``input`` with ``weight`` and ``bias``."""
    layer_norm_fn = FusedLayerNormAffineFunction(normalized_shape, eps)
    return layer_norm_fn(input, weight, bias)
def fused_layer_norm(input, normalized_shape, eps=1e-6):
    """Apply ``FusedLayerNormFunction`` (no affine parameters) to ``input``."""
    layer_norm_fn = FusedLayerNormFunction(normalized_shape, eps)
    return layer_norm_fn(input)
class FusedLayerNorm(torch.nn.Module):
    r"""Layer Normalization module backed by the fused layer-norm functions.

    Implements the transformation described in the paper
    `Layer Normalization`_:

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    :math:`\gamma` and :math:`\beta` are the learnable ``weight`` and ``bias``
    parameters of shape :attr:`normalized_shape`; they only exist when
    :attr:`elementwise_affine` is ``True``.

    Args:
        normalized_shape (int or list or torch.Size): shape of the trailing
            dimensions to normalize over. A single integer is treated as a
            one-element shape.
        eps: value added to the denominator for numerical stability.
            Default: 1e-5
        elementwise_affine: when ``True`` the module owns per-element
            ``weight`` (initialized to ones) and ``bias`` (initialized to
            zeros) parameters. Default: ``True``.

    Examples::

        >>> input = torch.randn(20, 5, 10, 10)
        >>> m = FusedLayerNorm(input.size()[1:])
        >>> m = FusedLayerNorm(input.size()[1:], elementwise_affine=False)
        >>> m = FusedLayerNorm([10, 10])
        >>> m = FusedLayerNorm(10)
        >>> output = m(input)

    .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
    """

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super().__init__()
        # A bare integer means "normalize over one trailing dimension".
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if not self.elementwise_affine:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        else:
            # Allocated uninitialized; reset_parameters() fills them below.
            self.weight = Parameter(torch.Tensor(*normalized_shape))
            self.bias = Parameter(torch.Tensor(*normalized_shape))
        self.reset_parameters()

    def reset_parameters(self):
        """Set ``weight`` to ones and ``bias`` to zeros when they exist."""
        if not self.elementwise_affine:
            return
        init.ones_(self.weight)
        init.zeros_(self.bias)

    def forward(self, input):
        """Normalize ``input`` with the affine or plain fused function."""
        if not self.elementwise_affine:
            plain = FusedLayerNormFunction(self.normalized_shape, self.eps)
            return plain(input)
        affine = FusedLayerNormAffineFunction(self.normalized_shape, self.eps)
        return affine(input, self.weight, self.bias)

    def extra_repr(self):
        """Return the shape / eps / affine summary used in the module repr."""
        template = ('{normalized_shape}, eps={eps}, '
                    'elementwise_affine={elementwise_affine}')
        return template.format(**self.__dict__)
PyTorch/NLP/Transformer/fairseq/models/transformer.py
deleted
100644 → 0
View file @
c056df78
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
math
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
torch
import
Tensor
from
typing
import
Optional
,
Dict
from
fairseq.modules
import
(
LearnedPositionalEmbedding
,
MultiheadAttention
,
SinusoidalPositionalEmbedding
)
from
.
import
(
FairseqIncrementalDecoder
,
register_model
,
register_model_architecture
,
)
from
apex.normalization.fused_layer_norm
import
FusedLayerNorm
# Module-level side effect: configure tensor printing with a summarisation
# threshold of 500000000 elements and a line width of 1024 characters.
torch.set_printoptions(linewidth=1024, threshold=500000000)
@torch.jit.script
def jit_dropout_add(x: Tensor, residual: Tensor, prob: float, is_training: bool) -> Tensor:
    """Scripted ``residual + dropout(x)``.

    ``x`` goes through ``F.dropout`` with probability ``prob`` (active only
    when ``is_training`` is true) and is then added to ``residual``.
    """
    dropped = F.dropout(x, p=prob, training=is_training)
    return residual + dropped
@torch.jit.script
def jit_relu_dropout(x: Tensor, prob: float, is_training: bool) -> Tensor:
    """Scripted ReLU-style threshold followed by dropout.

    Values of ``x`` that are not greater than 0 are replaced by 0 via
    ``F.threshold``; the result goes through ``F.dropout`` with probability
    ``prob`` (active only when ``is_training`` is true).
    """
    activated = F.threshold(x, 0., 0.)
    return F.dropout(activated, p=prob, training=is_training)
@register_model('transformer')
class TransformerModel(nn.Module):
    """Encoder-decoder transformer registered under the name ``'transformer'``.

    The model simply chains ``encoder`` and ``decoder``; see :meth:`forward`.
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--attention-dropout', type=float, metavar='D',
                            help='dropout probability for attention weights')
        parser.add_argument('--relu-dropout', type=float, metavar='D',
                            help='dropout probability after ReLU in FFN')
        parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained encoder embedding')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension for FFN')
        parser.add_argument('--encoder-layers', type=int, metavar='N',
                            help='num encoder layers')
        parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
                            help='num encoder attention heads')
        parser.add_argument('--encoder-normalize-before', action='store_true',
                            help='apply layernorm before each encoder block')
        parser.add_argument('--encoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the encoder')
        parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained decoder embedding')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension for FFN')
        parser.add_argument('--decoder-layers', type=int, metavar='N',
                            help='num decoder layers')
        parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
                            help='num decoder attention heads')
        parser.add_argument('--decoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the decoder')
        parser.add_argument('--decoder-normalize-before', action='store_true',
                            help='apply layernorm before each decoder block')
        parser.add_argument('--share-decoder-input-output-embed', action='store_true',
                            help='share decoder input and output embeddings')
        parser.add_argument('--share-all-embeddings', action='store_true',
                            help='share encoder, decoder and output embeddings'
                                 ' (requires shared dictionary and embed dim)')

    def __init__(self, encoder, decoder):
        super().__init__()
        self._is_generation_fast = False
        self.encoder = encoder
        self.decoder = decoder

    @classmethod
    def build_model(cls, args):
        """Build a :class:`TransformerModel` from parsed arguments.

        Missing hyper-parameters are filled in by ``base_architecture`` and
        the maximum source/target positions default to 1024.

        Raises:
            RuntimeError: if ``--share-all-embeddings`` is combined with
                different vocabulary sizes, different embedding dimensions
                or a decoder embedding path that differs from the encoder's.
        """
        # make sure all arguments are present in older models
        base_architecture(args)

        if not hasattr(args, 'max_source_positions'):
            args.max_source_positions = 1024
        if not hasattr(args, 'max_target_positions'):
            args.max_target_positions = 1024

        if args.share_all_embeddings:
            if args.src_vocab_size != args.tgt_vocab_size:
                raise RuntimeError('--share-all-embeddings requires a joined dictionary')
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise RuntimeError(
                    '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
            if args.decoder_embed_path and (
                    args.decoder_embed_path != args.encoder_embed_path):
                raise RuntimeError('--share-all-embeddings not compatible with --decoder-embed-path')
            encoder_embed_tokens = Embedding(
                args.src_vocab_size, args.encoder_embed_dim, args.padding_idx)
            # One embedding table serves encoder input, decoder input and output.
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = Embedding(
                args.src_vocab_size, args.encoder_embed_dim, args.padding_idx)
            decoder_embed_tokens = Embedding(
                args.tgt_vocab_size, args.decoder_embed_dim, args.padding_idx)

        encoder = TransformerEncoder(args, encoder_embed_tokens)
        decoder = TransformerDecoder(args, decoder_embed_tokens)
        return TransformerModel(encoder, decoder)

    def make_generation_fast_(self, **kwargs):
        """Optimize model for faster generation.

        Removes weight norm from every sub-module, forwards ``kwargs`` to each
        sub-module's own ``make_generation_fast_`` and puts the model in eval
        mode. Afterwards ``train()`` / ``train(True)`` raise ``RuntimeError``.
        Calling this method a second time is a no-op.
        """
        if self._is_generation_fast:
            return  # only apply once
        self._is_generation_fast = True

        # remove weight norm from all modules in the network
        def apply_remove_weight_norm(module):
            try:
                nn.utils.remove_weight_norm(module)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(apply_remove_weight_norm)

        def apply_make_generation_fast_(module):
            # Skip the model itself, otherwise this method would recurse.
            if module is not self and hasattr(module, 'make_generation_fast_'):
                module.make_generation_fast_(**kwargs)

        self.apply(apply_make_generation_fast_)

        # ``mode`` defaults to True like ``nn.Module.train`` so that a bare
        # ``model.train()`` raises the intended RuntimeError, not a TypeError.
        def train(mode=True):
            if mode:
                raise RuntimeError('cannot train after make_generation_fast')

        # this model should no longer be used for training
        self.eval()
        self.train = train

    def forward(self, src_tokens, src_lengths, prev_output_tokens):
        """Encode ``src_tokens`` and return the decoder output for ``prev_output_tokens``."""
        encoder_out, padding_mask = self.encoder(src_tokens, src_lengths)
        decoder_out = self.decoder(prev_output_tokens, encoder_out, padding_mask)
        return decoder_out
class TransformerEncoder(nn.Module):
    """Transformer encoder: token + positional embeddings followed by a stack
    of ``TransformerEncoderLayer`` modules."""

    def __init__(self, args, embed_tokens, left_pad=True):
        super().__init__()
        self.dropout = args.dropout
        self.fuse_dropout_add = args.fuse_dropout_add
        self.fuse_relu_dropout = args.fuse_relu_dropout

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)
        if args.no_token_positional_embeddings:
            self.embed_positions = None
        else:
            self.embed_positions = PositionalEmbedding(
                args.max_source_positions, embed_dim, self.padding_idx,
                left_pad=left_pad,
                learned=args.encoder_learned_pos,
            )

        self.layers = nn.ModuleList(
            [TransformerEncoderLayer(args) for _ in range(args.encoder_layers)])

        # A final layer norm exists only in the normalize-before configuration.
        self.normalize = args.encoder_normalize_before
        if self.normalize:
            if args.fuse_layer_norm:
                self.layer_norm = FusedLayerNorm(embed_dim)
            else:
                self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, src_tokens, src_lengths):
        """Encode ``src_tokens``.

        Returns:
            tuple: the encoder output (time-major after the transpose below)
            and the padding mask ``src_tokens.eq(self.padding_idx)``.
            ``src_lengths`` is accepted but not used here.
        """
        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            x += self.embed_positions(src_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        # The transposed tensor is copied in the fused case because
        # fused dropout is not capable of handling strided data
        x = x.transpose(0, 1)
        if self.fuse_dropout_add:
            x = x.contiguous()

        # compute padding mask; the layers receive None when nothing is padded
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        layer_mask = encoder_padding_mask if encoder_padding_mask.any() else None

        # encoder layers
        for layer in self.layers:
            x = layer(x, layer_mask)

        if self.normalize:
            x = self.layer_norm(x)

        # x.shape == T x B x C, encoder_padding_mask.shape == B x T
        return x, encoder_padding_mask

    def reorder_encoder_out(self, encoder_out, encoder_padding_mask, new_order):
        """Re-index both encoder results by ``new_order``.

        ``encoder_out`` is indexed along dimension 1 and
        ``encoder_padding_mask`` along dimension 0; ``None`` values are
        passed through untouched.
        """
        reordered_out = encoder_out
        if reordered_out is not None:
            reordered_out = reordered_out.index_select(1, new_order)
        reordered_mask = encoder_padding_mask
        if reordered_mask is not None:
            reordered_mask = reordered_mask.index_select(0, new_order)
        return reordered_out, reordered_mask
class TransformerDecoder(FairseqIncrementalDecoder):
    """Transformer decoder: token + positional embeddings, a stack of
    ``TransformerDecoderLayer`` modules and a projection to the vocabulary."""

    def __init__(self, args, embed_tokens, no_encoder_attn=False, left_pad=False):
        super().__init__()
        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed
        self.fuse_dropout_add = args.fuse_dropout_add
        self.fuse_relu_dropout = args.fuse_relu_dropout

        embed_dim = embed_tokens.embedding_dim
        padding_idx = embed_tokens.padding_idx
        self.max_target_positions = args.max_target_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = PositionalEmbedding(
            args.max_target_positions, embed_dim, padding_idx,
            left_pad=left_pad,
            learned=args.decoder_learned_pos,
        ) if not args.no_token_positional_embeddings else None

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerDecoderLayer(args, no_encoder_attn)
            for _ in range(args.decoder_layers)
        ])

        if not self.share_input_output_embed:
            self.embed_out = nn.Parameter(torch.Tensor(args.tgt_vocab_size, embed_dim))
            nn.init.normal_(self.embed_out, mean=0, std=embed_dim ** -0.5)
        else:
            # Output projection reuses the input embedding matrix.
            self.embed_out = self.embed_tokens.weight

        # A final layer norm exists only in the normalize-before configuration.
        self.normalize = args.decoder_normalize_before
        if self.normalize:
            self.layer_norm = FusedLayerNorm(embed_dim) if args.fuse_layer_norm else nn.LayerNorm(embed_dim)

    def forward(self,
                prev_output_tokens: Tensor,
                encoder_out: Tensor,
                encoder_padding_mask: Tensor,
                incremental_state: Optional[Dict[str, Dict[str, Tensor]]] = None):
        """Decode ``prev_output_tokens`` against ``encoder_out``.

        When ``incremental_state`` is given, only the last column of
        ``prev_output_tokens`` (and of the positions) is fed through the
        layers.

        Returns:
            tuple: the result of projecting the layer output with
            ``self.embed_out`` and the attention value returned by the last
            decoder layer (``None`` when there are no layers).
        """
        # embed positions
        positions = self.embed_positions(
            prev_output_tokens,
            incremental_state=incremental_state,
        ) if self.embed_positions is not None else None

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)
        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        # The tensor needs to copy transposed because
        # fused dropout is not capable of handling strided data
        if self.fuse_dropout_add:
            x = x.transpose(0, 1).contiguous()
        else:
            x = x.transpose(0, 1)
        attn = None

        # The mask is identical for every layer, so test it once here instead
        # of calling ``.any()`` on every loop iteration.
        layer_padding_mask = encoder_padding_mask if encoder_padding_mask.any() else None

        # decoder layers
        for layer in self.layers:
            x, attn = layer(
                x,
                encoder_out,
                layer_padding_mask,
                incremental_state,
            )

        if self.normalize:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        # project back to size of vocabulary
        x = F.linear(x, self.embed_out)

        return x, attn
class TransformerEncoderLayer(nn.Module):
    """Encoder layer block: self-attention followed by a two-layer FFN.

    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: dropout -> add residual -> layernorm.
    In the tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    dropout -> add residual.
    Which of the two orders is used is decided by
    ``args.encoder_normalize_before``.
    """

    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.encoder_embed_dim
        self.self_attn = MultiheadAttention(
            self.embed_dim, args.encoder_attention_heads,
            dropout=args.attention_dropout,
        )
        self.dropout = args.dropout
        self.relu_dropout = args.relu_dropout
        self.fuse_dropout_add = args.fuse_dropout_add
        self.fuse_relu_dropout = args.fuse_relu_dropout
        self.normalize_before = args.encoder_normalize_before
        self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
        self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
        self.maybe_ln1 = MaybeLayerNorm(
            self.embed_dim, self.normalize_before, fuse=args.fuse_layer_norm)
        self.maybe_ln2 = MaybeLayerNorm(
            self.embed_dim, self.normalize_before, fuse=args.fuse_layer_norm)

    def _dropout_add(self, x: Tensor, residual: Tensor) -> Tensor:
        """Return ``residual + dropout(x)``, using the scripted kernel when
        fusion is enabled and the module is training."""
        if self.fuse_dropout_add and self.training:
            return jit_dropout_add(x, residual, self.dropout, self.training)
        dropped = F.dropout(x, p=self.dropout, training=self.training)
        return residual + dropped

    def forward(self, x: Tensor, encoder_padding_mask: Optional[Tensor]):
        """Apply self-attention and the feed-forward block to ``x``."""
        # --- self-attention sub-block ---
        shortcut = x
        x = self.maybe_ln1(x, before=True)
        x, _ = self.self_attn(
            query=x, key=x, value=x,
            mask_future_timesteps=False,
            key_padding_mask=encoder_padding_mask,
            incremental_state=None,
            need_weights=False,
            static_kv=False,
        )
        x = self._dropout_add(x, shortcut)
        x = self.maybe_ln1(x, after=True)

        # --- feed-forward sub-block ---
        shortcut = x
        x = self.maybe_ln2(x, before=True)
        hidden = self.fc1(x)
        if self.fuse_relu_dropout:
            hidden = jit_relu_dropout(hidden, self.relu_dropout, self.training)
        else:
            hidden = F.threshold(hidden, 0.0, 0.0)
            hidden = F.dropout(hidden, p=self.relu_dropout, training=self.training)
        x = self.fc2(hidden)
        x = self._dropout_add(x, shortcut)
        return self.maybe_ln2(x, after=True)
class TransformerDecoderLayer(nn.Module):
    """Decoder layer block: masked self-attention, optional encoder
    attention and a two-layer FFN, each wrapped in dropout + residual."""

    def __init__(self, args, no_encoder_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.self_attn = MultiheadAttention(
            self.embed_dim, args.decoder_attention_heads,
            dropout=args.attention_dropout,
        )
        self.dropout = args.dropout
        self.relu_dropout = args.relu_dropout
        self.normalize_before = args.decoder_normalize_before
        self.fuse_dropout_add = args.fuse_dropout_add
        self.fuse_relu_dropout = args.fuse_relu_dropout

        self.self_attn_layer_norm = MaybeLayerNorm(
            self.embed_dim, self.normalize_before, fuse=args.fuse_layer_norm)

        if not no_encoder_attn:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim, args.decoder_attention_heads,
                dropout=args.attention_dropout,
            )
            self.encoder_attn_layer_norm = MaybeLayerNorm(
                self.embed_dim, self.normalize_before, fuse=args.fuse_layer_norm)
        else:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = MaybeLayerNorm(
            self.embed_dim, self.normalize_before, fuse=args.fuse_layer_norm)
        self.need_attn = True

    def _dropout_add(self, x: Tensor, residual: Tensor) -> Tensor:
        """Return ``residual + dropout(x)``, using the scripted kernel when
        fusion is enabled and the module is training."""
        if self.fuse_dropout_add and self.training:
            return jit_dropout_add(x, residual, self.dropout, self.training)
        dropped = F.dropout(x, p=self.dropout, training=self.training)
        return residual + dropped

    def forward(self,
                x: Tensor,
                encoder_out: Tensor,
                encoder_padding_mask: Optional[Tensor],
                incremental_state: Optional[Dict[str, Dict[str, Tensor]]]):
        """Run the layer and return ``(x, attn)``.

        ``attn`` is whatever the encoder-attention module returns as its
        second value, or ``None`` when the layer has no encoder attention.
        """
        # --- masked self-attention sub-block ---
        shortcut = x
        x = self.self_attn_layer_norm(x, before=True)
        x, _ = self.self_attn(
            query=x, key=x, value=x,
            mask_future_timesteps=True,
            key_padding_mask=None,
            incremental_state=incremental_state,
            need_weights=False,
            static_kv=False,
        )
        x = self._dropout_add(x, shortcut)
        x = self.self_attn_layer_norm(x, after=True)

        # --- encoder attention sub-block (optional) ---
        attn = None
        if self.encoder_attn is not None:
            shortcut = x
            x = self.encoder_attn_layer_norm(x, before=True)
            # Attention weights are requested only outside of training.
            want_weights = (not self.training and self.need_attn)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                mask_future_timesteps=False,
                need_weights=want_weights,
            )
            x = self._dropout_add(x, shortcut)
            x = self.encoder_attn_layer_norm(x, after=True)

        # --- feed-forward sub-block ---
        shortcut = x
        x = self.final_layer_norm(x, before=True)
        hidden = self.fc1(x)
        if self.fuse_relu_dropout:
            hidden = jit_relu_dropout(hidden, self.relu_dropout, self.training)
        else:
            hidden = F.threshold(hidden, 0.0, 0.0)
            hidden = F.dropout(hidden, p=self.relu_dropout, training=self.training)
        x = self.fc2(hidden)
        x = self._dropout_add(x, shortcut)
        x = self.final_layer_norm(x, after=True)
        return x, attn

    def make_generation_fast_(self, need_attn=False, **kwargs):
        """Record whether attention weights should be produced at inference."""
        self.need_attn = need_attn
def Embedding(num_embeddings, embedding_dim, padding_idx):
    """Create an ``nn.Embedding`` with fairseq-style initialization.

    Weights are drawn from a normal distribution with mean 0 and standard
    deviation ``embedding_dim ** -0.5``; the row at ``padding_idx`` is then
    set to zero.

    Args:
        num_embeddings: number of rows in the embedding table.
        embedding_dim: size of each embedding vector.
        padding_idx: index of the padding row, or ``None`` for no padding row.

    Returns:
        The initialized ``nn.Embedding`` module.
    """
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    # ``m.weight[None]`` is a view of the *whole* matrix, so without this
    # guard a missing padding index would zero every embedding.
    if padding_idx is not None:
        nn.init.constant_(m.weight[padding_idx], 0)
    return m
class MaybeLayerNorm(nn.Module):
    """Layer norm that is applied either before or after a sub-block.

    The module is called twice per sub-block, once with ``before=True`` and
    once with ``after=True``. It normalizes on the ``before`` call when
    ``normalize_before`` is true and on the ``after`` call otherwise; the
    other call returns ``x`` unchanged.
    """

    def __init__(self, embed_dim, normalize_before, fuse=True):
        super().__init__()
        self.embed_dim = embed_dim
        self.normalize_before = normalize_before
        if fuse:
            self.ln = FusedLayerNorm(embed_dim)
        else:
            self.ln = nn.LayerNorm(embed_dim)

    def forward(self, x: Tensor, before: bool = False, after: bool = False):
        """Normalize ``x`` or pass it through; exactly one flag must be set."""
        assert before ^ after
        if not (after ^ self.normalize_before):
            return x
        return self.ln(x)
def Linear(in_features, out_features, bias=True):
    """Create an ``nn.Linear`` with Xavier-uniform weights and a zero bias.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        bias: whether the layer has a bias term.

    Returns:
        The initialized ``nn.Linear`` module.
    """
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    # With ``bias=False`` the layer has ``m.bias is None`` and there is
    # nothing to initialize.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.)
    return m
def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False):
    """Build a learned or sinusoidal positional embedding module.

    Both variants are sized ``num_embeddings + padding_idx + 1``. The learned
    variant is initialized from a normal distribution with standard deviation
    ``embedding_dim ** -0.5`` and has its ``padding_idx`` row set to zero.
    """
    table_size = num_embeddings + padding_idx + 1
    if not learned:
        return SinusoidalPositionalEmbedding(
            embedding_dim, padding_idx, left_pad, table_size)

    m = LearnedPositionalEmbedding(table_size, embedding_dim, padding_idx, left_pad)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m
@register_model_architecture('transformer', 'transformer')
def base_architecture(args):
    """Fill in every transformer hyper-parameter missing from ``args``.

    Attributes already present on ``args`` are kept; absent ones receive the
    defaults below. The decoder embedding sizes default to the (possibly just
    defaulted) encoder values.
    """
    def fill(name, default):
        # Keep an existing value, otherwise install the default.
        setattr(args, name, getattr(args, name, default))

    fill('encoder_embed_path', None)
    fill('encoder_embed_dim', 512)
    fill('encoder_ffn_embed_dim', 2048)
    fill('encoder_layers', 6)
    fill('encoder_attention_heads', 8)
    fill('encoder_normalize_before', False)
    fill('encoder_learned_pos', False)
    fill('decoder_embed_path', None)
    fill('decoder_embed_dim', args.encoder_embed_dim)
    fill('decoder_ffn_embed_dim', args.encoder_ffn_embed_dim)
    fill('decoder_layers', 6)
    fill('decoder_attention_heads', 8)
    fill('decoder_normalize_before', False)
    fill('decoder_learned_pos', False)
    fill('attention_dropout', 0.)
    fill('relu_dropout', 0.)
    fill('dropout', 0.1)
    fill('share_decoder_input_output_embed', False)
    fill('share_all_embeddings', False)
    fill('no_token_positional_embeddings', False)
@register_model_architecture('transformer', 'transformer_iwslt_de_en')
def transformer_iwslt_de_en(args):
    """Apply the ``transformer_iwslt_de_en`` defaults, then the base ones.

    Values already set on ``args`` take precedence over these defaults.
    """
    defaults = (
        ('encoder_embed_dim', 512),
        ('encoder_ffn_embed_dim', 1024),
        ('encoder_attention_heads', 4),
        ('encoder_layers', 6),
        ('decoder_embed_dim', 512),
        ('decoder_ffn_embed_dim', 1024),
        ('decoder_attention_heads', 4),
        ('decoder_layers', 6),
    )
    for name, value in defaults:
        setattr(args, name, getattr(args, name, value))
    base_architecture(args)
@register_model_architecture('transformer', 'transformer_wmt_en_de')
def transformer_wmt_en_de(args):
    """Architecture alias that applies only the ``base_architecture`` defaults."""
    base_architecture(args)
# parameters used in the "Attention Is All You Need" paper (Vaswani, et al, 2017)
@register_model_architecture('transformer', 'transformer_vaswani_wmt_en_de_big')
def transformer_vaswani_wmt_en_de_big(args):
    """Apply the "big" transformer defaults, then the base ones.

    Values already set on ``args`` take precedence over these defaults.
    """
    defaults = (
        ('encoder_embed_dim', 1024),
        ('encoder_ffn_embed_dim', 4096),
        ('encoder_attention_heads', 16),
        ('encoder_normalize_before', False),
        ('decoder_embed_dim', 1024),
        ('decoder_ffn_embed_dim', 4096),
        ('decoder_attention_heads', 16),
        ('dropout', 0.3),
    )
    for name, value in defaults:
        setattr(args, name, getattr(args, name, value))
    base_architecture(args)
@register_model_architecture('transformer', 'transformer_vaswani_wmt_en_fr_big')
def transformer_vaswani_wmt_en_fr_big(args):
    """Big en-de architecture with a default dropout of 0.1."""
    if not hasattr(args, 'dropout'):
        args.dropout = 0.1
    transformer_vaswani_wmt_en_de_big(args)
@register_model_architecture('transformer', 'transformer_wmt_en_de_big')
def transformer_wmt_en_de_big(args):
    """Big en-de architecture with a default attention dropout of 0.1."""
    if not hasattr(args, 'attention_dropout'):
        args.attention_dropout = 0.1
    transformer_vaswani_wmt_en_de_big(args)
# default parameters used in tensor2tensor implementation
@register_model_architecture('transformer', 'transformer_wmt_en_de_big_t2t')
def transformer_wmt_en_de_big_t2t(args):
    """Big en-de architecture with normalize-before and 0.1 attention/ReLU dropout.

    Values already set on ``args`` take precedence over these defaults.
    """
    defaults = (
        ('encoder_normalize_before', True),
        ('decoder_normalize_before', True),
        ('attention_dropout', 0.1),
        ('relu_dropout', 0.1),
    )
    for name, value in defaults:
        setattr(args, name, getattr(args, name, value))
    transformer_vaswani_wmt_en_de_big(args)
PyTorch/NLP/Transformer/fairseq/modules/__init__.py
deleted
100644 → 0
View file @
c056df78
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from
.beamable_mm
import
BeamableMM
from
.learned_positional_embedding
import
LearnedPositionalEmbedding
from
.multihead_attention
import
MultiheadAttention
from
.sinusoidal_positional_embedding
import
SinusoidalPositionalEmbedding
__all__
=
[
'BeamableMM'
,
'LearnedPositionalEmbedding'
,
'MultiheadAttention'
,
'SinusoidalPositionalEmbedding'
,
]
PyTorch/NLP/Transformer/fairseq/modules/learned_positional_embedding.py
deleted
100644 → 0
View file @
c056df78
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import
torch.nn
as
nn
from
fairseq
import
utils
class LearnedPositionalEmbedding(nn.Embedding):
    """Embedding table indexed by token position.

    Position indices are derived from the input with
    ``utils.make_positions``, which is told the padding index and whether
    padding sits on the left (``left_pad=True``) or on the right
    (``left_pad=False``).
    """

    def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.left_pad = left_pad

    def forward(self, input, incremental_state=None):
        """Input is expected to be of size [bsz x seqlen]."""
        if incremental_state is None:
            positions = utils.make_positions(
                input.data, self.padding_idx, self.left_pad)
        else:
            # Single-step decoding: one position, derived from the current
            # sequence length, is shared by every sequence in the batch.
            step_position = self.padding_idx + input.size(1)
            positions = input.data.new(1, 1).fill_(step_position)
        return super().forward(positions)
PyTorch/NLP/Transformer/fairseq/modules/multihead_attention.py
deleted
100644 → 0
View file @
c056df78
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
typing
import
Dict
,
Optional
import
torch
from
torch
import
nn
,
Tensor
from
torch.nn
import
Parameter
import
torch.nn.functional
as
F
from
torch.cuda
import
amp
from
torch.autograd.variable
import
Variable
import
strided_batched_gemm
from
fairseq
import
utils
class QueryLinear(torch.autograd.Function):
    """Scaled linear projection ``scale * (input @ weights_q)`` with a
    hand-written backward pass.

    ``input`` is treated as a 3-D tensor whose first two dimensions are
    flattened for the matrix product; the result is viewed back to the
    input's shape.
    """

    @staticmethod
    @amp.custom_fwd(cast_inputs=torch.half)
    def forward(ctx, input, weights_q, scale):
        s = Variable(torch.tensor([scale]))
        ctx.save_for_backward(input, weights_q, s)
        dim0, dim1, dim2 = input.size(0), input.size(1), input.size(2)
        flat_input = input.view(dim0 * dim1, dim2)
        # beta=0.0: the first addmm argument contributes nothing to the
        # result, only alpha * (flat_input @ weights_q) remains.
        q = torch.addmm(flat_input, flat_input, weights_q, beta=0.0, alpha=s[0])
        return q.view(dim0, dim1, dim2).detach()

    @staticmethod
    @amp.custom_bwd
    def backward(ctx, q_grad):
        input, weights_q, s = ctx.saved_tensors
        input_t = input.view(
            input.size(0) * input.size(1), input.size(2)).transpose(0, 1)
        g0, g1, g2 = q_grad.size(0), q_grad.size(1), q_grad.size(2)
        flat_grad = q_grad.view(g0 * g1, g2)
        # Gradient w.r.t. the input: scale * (grad @ weights_q^T).
        input_grad = torch.addmm(
            flat_grad, flat_grad, weights_q.transpose(0, 1),
            beta=0.0, alpha=s[0])
        input_grad = input_grad.view(g0, g1, g2)
        # Gradient w.r.t. the weights: scale * (input^T @ grad).
        weights_q_grad = torch.addmm(
            weights_q, input_t, flat_grad, beta=0.0, alpha=s[0])
        # ``scale`` receives no gradient.
        return input_grad, weights_q_grad, None
class KeyValueLinears(torch.autograd.Function):
    """Key and value projections ``input @ weights_k`` / ``input @ weights_v``
    sharing one input, with a hand-written backward pass.

    ``input`` is treated as a 3-D tensor whose first two dimensions are
    flattened for the matrix products; both results are viewed back to the
    input's shape.
    """

    @staticmethod
    @amp.custom_fwd(cast_inputs=torch.half)
    def forward(ctx, input, weights_k, weights_v):
        ctx.save_for_backward(input, weights_k, weights_v)
        dim0, dim1, dim2 = input.size(0), input.size(1), input.size(2)
        flat_input = input.view(dim0 * dim1, dim2)
        # beta=0.0: the first addmm argument contributes nothing to the
        # result, only the matrix product remains.
        k = torch.addmm(flat_input, flat_input, weights_k, beta=0.0, alpha=1.0)
        k = k.view(dim0, dim1, dim2)
        v = torch.addmm(flat_input, flat_input, weights_v, beta=0.0, alpha=1.0)
        v = v.view(dim0, dim1, dim2)
        return k.detach(), v.detach()

    @staticmethod
    @amp.custom_bwd
    def backward(ctx, k_grad, v_grad):
        input, weights_k, weights_v = ctx.saved_tensors
        input_t = input.view(
            input.size(0) * input.size(1), input.size(2)).transpose(0, 1)
        flat_k_grad = k_grad.view(
            k_grad.size(0) * k_grad.size(1), k_grad.size(2))
        flat_v_grad = v_grad.view(
            v_grad.size(0) * v_grad.size(1), v_grad.size(2))

        # Key contribution to the input gradient: k_grad @ weights_k^T.
        input_grad = torch.addmm(
            flat_k_grad, flat_k_grad, weights_k.transpose(0, 1), beta=0.0)
        weights_k_grad = torch.mm(input_t, flat_k_grad)

        # Accumulate the value contribution in place (beta=1.0 keeps the
        # key contribution already stored in input_grad).
        input_grad = input_grad.addmm_(
            flat_v_grad, weights_v.transpose(0, 1), beta=1.0)
        input_grad = input_grad.view(
            v_grad.size(0), v_grad.size(1), v_grad.size(2))
        weights_v_grad = torch.mm(input_t, flat_v_grad)
        return input_grad, weights_k_grad, weights_v_grad
class SelfAttentionLinears(torch.autograd.Function):
    """Query, key and value projections of one input with a shared backward.

    Computes ``scale * (input @ weights_q)``, ``input @ weights_k`` and
    ``input @ weights_v`` after flattening the first two dimensions of the
    3-D ``input``; each result is reshaped back to the input's shape.
    """

    @staticmethod
    @amp.custom_fwd(cast_inputs=torch.half)
    def forward(ctx, input, weights_q, weights_k, weights_v, scale):
        """Return the ``(query, key, value)`` projections; ``scale`` is a float."""
        # The scale travels to backward() as a one-element tensor.
        scale_t = Variable(torch.tensor([scale]))
        ctx.save_for_backward(input, weights_q, weights_k, weights_v, scale_t)

        dim0, dim1, dim2 = input.size(0), input.size(1), input.size(2)
        flat = input.view(dim0 * dim1, dim2)
        # beta=0.0 makes addmm ignore its first argument; only the query
        # projection is scaled.
        queries = torch.addmm(flat, flat, weights_q, beta=0.0, alpha=scale_t[0])
        keys = torch.addmm(flat, flat, weights_k, beta=0.0, alpha=1.0)
        values = torch.addmm(flat, flat, weights_v, beta=0.0, alpha=1.0)
        return (
            queries.view(dim0, dim1, dim2).detach(),
            keys.view(dim0, dim1, dim2).detach(),
            values.view(dim0, dim1, dim2).detach(),
        )

    @staticmethod
    @amp.custom_bwd
    def backward(ctx, q_grad, k_grad, v_grad):
        """Return gradients for ``(input, weights_q, weights_k, weights_v, scale)``."""
        saved_input, weights_q, weights_k, weights_v, scale_t = ctx.saved_tensors
        alpha = scale_t[0]
        input_t = saved_input.view(
            saved_input.size(0) * saved_input.size(1), saved_input.size(2)
        ).transpose(0, 1)

        flat_q_grad = q_grad.view(q_grad.size(0) * q_grad.size(1), q_grad.size(2))
        flat_k_grad = k_grad.view(k_grad.size(0) * k_grad.size(1), k_grad.size(2))
        flat_v_grad = v_grad.view(v_grad.size(0) * v_grad.size(1), v_grad.size(2))

        # Input gradient: start with the (scaled) query branch, then accumulate
        # the key and value branches in place (beta=1.0).
        input_grad = torch.addmm(
            flat_q_grad, flat_q_grad, weights_q.transpose(0, 1), beta=0.0, alpha=alpha
        )
        input_grad.addmm_(flat_k_grad, weights_k.transpose(0, 1), beta=1.0)
        input_grad.addmm_(flat_v_grad, weights_v.transpose(0, 1), beta=1.0)
        input_grad = input_grad.view(v_grad.size(0), v_grad.size(1), v_grad.size(2))

        # weights_q is only the ignored (beta=0.0) first operand here.
        weights_q_grad = torch.addmm(
            weights_q, input_t, flat_q_grad, beta=0.0, alpha=alpha
        )
        weights_k_grad = torch.mm(input_t, flat_k_grad)
        weights_v_grad = torch.mm(input_t, flat_v_grad)
        return input_grad, weights_q_grad, weights_k_grad, weights_v_grad, None
class StridedBmm1Func(torch.autograd.Function):
    """Batched matrix product ``input1 @ input2`` with a manual backward.

    CUDA float16 operands are dispatched to the ``strided_batched_gemm``
    extension; every other case falls back to ``torch.bmm``.  All work buffers
    are allocated on the operands' own device, so the fallback also works for
    CPU tensors and for tensors living on a non-current GPU.
    """

    @staticmethod
    @amp.custom_fwd(cast_inputs=torch.half)
    def forward(ctx, input1, input2):
        """Return the batched product of two 3-D tensors."""
        ctx.save_for_backward(input1, input2)
        output = torch.empty(
            (input1.size(0), input1.size(1), input2.size(2)),
            dtype=input1.dtype,
            device=input1.device,
        )
        use_extension = (
            input1.is_cuda
            and input1.dtype == torch.float16
            and input2.dtype == torch.float16
        )
        if use_extension:
            output = strided_batched_gemm.strided_batched_gemm(
                0.0, output, 1.0, input1, input2
            )
        else:
            output = torch.bmm(input1, input2, out=output)
        return output.detach()

    @staticmethod
    @amp.custom_bwd
    def backward(ctx, grad_output):
        """Return gradients for ``(input1, input2)``."""
        input1, input2 = ctx.saved_tensors
        # Both gradient buffers are allocated with their first two dimensions
        # swapped and then viewed back through transpose(1, 0); this keeps the
        # memory layout of the original implementation.
        grad_input1 = torch.empty(
            (input1.size(1), input2.size(0), input1.size(2)),
            dtype=input1.dtype,
            device=input1.device,
        ).transpose(1, 0)
        grad_input2 = torch.empty(
            (input2.size(2), input2.size(0), input2.size(1)),
            dtype=input2.dtype,
            device=input2.device,
        ).transpose(1, 0)

        use_extension = (
            grad_output.is_cuda
            and grad_output.dtype == torch.float16
            and input1.dtype == torch.float16
            and input2.dtype == torch.float16
        )
        if use_extension:
            grad_input1 = strided_batched_gemm.strided_batched_gemm(
                0.0, grad_input1, 1.0, grad_output, input2.transpose(1, 2)
            )
            grad_input2 = strided_batched_gemm.strided_batched_gemm(
                0.0, grad_input2, 1.0, grad_output.transpose(1, 2), input1
            )
            grad_input2 = grad_input2.transpose(1, 2)
        else:
            grad_input1 = torch.bmm(
                grad_output, input2.transpose(1, 2), out=grad_input1
            )
            # Computed as (grad_output^T @ input1) and transposed back to
            # input2's shape.
            grad_input2 = torch.bmm(
                grad_output.transpose(1, 2), input1, out=grad_input2
            ).transpose(1, 2)
        return grad_input1, grad_input2
class StridedBmm2Func(torch.autograd.Function):
    """Batched matrix product ``input1 @ input2`` whose output is laid out transposed.

    The result has shape ``(input1.size(0), input1.size(1), input2.size(2))``
    but is stored in a buffer allocated with the first two dimensions swapped,
    so ``output.transpose(0, 1)`` is contiguous.  CUDA float16 operands are
    dispatched to the ``strided_batched_gemm`` extension; every other case
    falls back to ``torch.bmm``.  Buffers are allocated on the operands' own
    device, so the fallback also works for CPU tensors and non-current GPUs.
    """

    @staticmethod
    @amp.custom_fwd(cast_inputs=torch.half)
    def forward(ctx, input1, input2):
        """Return the batched product of two 3-D tensors."""
        ctx.save_for_backward(input1, input2)
        output = torch.empty(
            (input1.size(1), input1.size(0), input2.size(2)),
            dtype=input1.dtype,
            device=input1.device,
        ).transpose(1, 0)
        use_extension = (
            input1.is_cuda
            and input1.dtype == torch.float16
            and input2.dtype == torch.float16
        )
        if use_extension:
            output = strided_batched_gemm.strided_batched_gemm(
                0.0, output, 1.0, input1, input2
            )
        else:
            output = torch.bmm(input1, input2, out=output)
        return output.detach()

    @staticmethod
    @amp.custom_bwd
    def backward(ctx, grad_output):
        """Return gradients for ``(input1, input2)``."""
        input1, input2 = ctx.saved_tensors
        # Same swapped-dimension layout as the forward output.
        grad_input2 = torch.empty(
            (input2.size(1), input2.size(0), input2.size(2)),
            dtype=input2.dtype,
            device=input2.device,
        ).transpose(1, 0)

        use_extension = (
            grad_output.is_cuda
            and grad_output.dtype == torch.float16
            and input1.dtype == torch.float16
            and input2.dtype == torch.float16
        )
        if use_extension:
            # Only the extension needs a pre-allocated destination for
            # grad_input1; the fallback lets torch.bmm allocate its result.
            grad_input1 = torch.empty(
                (input1.size(0), input1.size(1), input1.size(2)),
                dtype=input1.dtype,
                device=input1.device,
            )
            grad_input1 = strided_batched_gemm.strided_batched_gemm(
                0.0, grad_input1, 1.0, grad_output, input2.transpose(1, 2)
            )
            grad_input2 = strided_batched_gemm.strided_batched_gemm(
                0.0, grad_input2, 1.0, input1.transpose(1, 2), grad_output
            )
        else:
            grad_input1 = torch.bmm(grad_output, input2.transpose(1, 2))
            grad_input2 = torch.bmm(
                input1.transpose(1, 2), grad_output, out=grad_input2
            )
        return grad_input1, grad_input2
def query_linear(input: Tensor, weights_q: Tensor, scale: float):
    """Return ``scale * (input @ weights_q)`` reshaped to ``input``'s shape.

    In eager mode the work is delegated to the ``QueryLinear`` autograd
    function; under TorchScript an equivalent einsum expression is used.
    """
    if torch.jit.is_scripting():
        flat = input.view(input.size(0) * input.size(1), -1)
        projected = scale * torch.einsum('ij,jk->ik', flat, weights_q)
        return projected.view(input.shape)
    else:
        return QueryLinear.apply(input, weights_q, scale)
def key_value_linears(input: Tensor, weights_k: Tensor, weights_v: Tensor):
    """Return ``(input @ weights_k, input @ weights_v)``, each shaped like ``input``.

    In eager mode the work is delegated to the ``KeyValueLinears`` autograd
    function; under TorchScript equivalent einsum expressions are used.
    """
    if torch.jit.is_scripting():
        flat = input.view(input.size(0) * input.size(1), -1)
        keys = torch.einsum('ij,jk->ik', flat, weights_k).view(input.shape)
        values = torch.einsum('ij,jk->ik', flat, weights_v).view(input.shape)
        return keys, values
    else:
        return KeyValueLinears.apply(input, weights_k, weights_v)
def self_attn_linears(input: Tensor, weights_q: Tensor, weights_k: Tensor, weights_v: Tensor, scale: float):
    """Return the ``(query, key, value)`` projections of ``input``.

    The query is ``scale * (input @ weights_q)``; key and value are the plain
    products with ``weights_k`` and ``weights_v``.  Each result is shaped like
    ``input``.  In eager mode the work is delegated to the
    ``SelfAttentionLinears`` autograd function; under TorchScript equivalent
    einsum expressions are used.
    """
    if torch.jit.is_scripting():
        flat = input.view(input.size(0) * input.size(1), -1)
        queries = scale * torch.einsum('ij,jk->ik', flat, weights_q)
        queries = queries.view(input.shape)
        keys = torch.einsum('ij,jk->ik', flat, weights_k).view(input.shape)
        values = torch.einsum('ij,jk->ik', flat, weights_v).view(input.shape)
        return queries, keys, values
    else:
        return SelfAttentionLinears.apply(input, weights_q, weights_k, weights_v, scale)
def strided_bmm1(input1: Tensor, input2: Tensor):
    """Return the batched matrix product of two 3-D tensors.

    Eager mode goes through the ``StridedBmm1Func`` autograd function;
    TorchScript uses the equivalent einsum contraction.
    """
    if torch.jit.is_scripting():
        return torch.einsum('ijk,ikn->ijn', input1, input2)
    else:
        return StridedBmm1Func.apply(input1, input2)
def strided_bmm2(input1: Tensor, input2: Tensor):
    """Return the batched matrix product of two 3-D tensors.

    Eager mode goes through the ``StridedBmm2Func`` autograd function;
    TorchScript uses the equivalent einsum contraction.
    """
    if torch.jit.is_scripting():
        return torch.einsum('ijk,ikn->ijn', input1, input2)
    else:
        return StridedBmm2Func.apply(input1, input2)
class MultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.

    The input projection is stored as three separate square parameters
    (``in_proj_weight_q/k/v``).  The fast paths in :meth:`forward` apply them
    as ``x @ W`` (input on the left).
    """

    def __init__(self, embed_dim, num_heads, dropout=0., bias=False):
        """Create the projections.

        Args:
            embed_dim: size of the channel dimension; must be divisible by
                ``num_heads``.
            num_heads: number of attention heads.
            dropout: dropout probability applied to the attention weights.
            bias: whether to create input-projection biases and an
                output-projection bias.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim**-0.5
        # Lazily grown causal mask, see buffered_mask().
        self._mask = torch.empty(0)
        # One weight per projection instead of a fused (3 * embed_dim, embed_dim) one.
        self.in_proj_weight_q = Parameter(torch.Tensor(embed_dim, embed_dim))
        self.in_proj_weight_k = Parameter(torch.Tensor(embed_dim, embed_dim))
        self.in_proj_weight_v = Parameter(torch.Tensor(embed_dim, embed_dim))
        if bias:
            self.in_proj_bias_q = Parameter(torch.Tensor(embed_dim))
            self.in_proj_bias_k = Parameter(torch.Tensor(embed_dim))
            self.in_proj_bias_v = Parameter(torch.Tensor(embed_dim))
        else:
            self.register_parameter('in_proj_bias_k', None)
            self.register_parameter('in_proj_bias_q', None)
            self.register_parameter('in_proj_bias_v', None)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        # Key under which this module stores its state in `incremental_state`.
        self.cache_id = str(id(self))
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise all weights and zero all biases (when present)."""
        nn.init.xavier_uniform_(self.in_proj_weight_q)
        nn.init.xavier_uniform_(self.in_proj_weight_k)
        nn.init.xavier_uniform_(self.in_proj_weight_v)
        nn.init.xavier_uniform_(self.out_proj.weight)
        # The three input biases and the output bias are created together, so
        # checking one of them is enough.
        if self.in_proj_bias_k is not None:
            nn.init.constant_(self.in_proj_bias_q, 0.)
            nn.init.constant_(self.in_proj_bias_k, 0.)
            nn.init.constant_(self.in_proj_bias_v, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)

    def forward(self, query: Tensor, key: Tensor, value: Tensor, mask_future_timesteps: bool,
                key_padding_mask: Optional[Tensor],
                incremental_state: Optional[Dict[str, Dict[str, Tensor]]],
                need_weights: bool, static_kv: bool):
        """Input shape: Time x Batch x Channel

        Self-attention can be implemented by passing in the same arguments for
        query, key and value. Future timesteps can be masked with the
        `mask_future_timesteps` argument. Padding elements can be excluded from
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.

        Returns a tuple ``(attn, attn_weights)``; ``attn_weights`` is averaged
        over heads when ``need_weights`` is true and an empty tensor otherwise.
        """
        if torch.jit.is_scripting():
            kv_same = torch.equal(key, value)
            qkv_same = torch.equal(query, value) and kv_same
        else:
            qkv_same, kv_same = self._fast_same_check(query, key, value)

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        assert key.size() == value.size()

        # Empty placeholders: they stay empty when the cached static key/value
        # are reused below.
        k = v = query.new_empty(0)
        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
        else:
            saved_state = None

        if qkv_same:
            # self-attention
            q, k, v = self_attn_linears(query, self.in_proj_weight_q,
                                        self.in_proj_weight_k, self.in_proj_weight_v, self.scaling)
        elif kv_same:
            # encoder-decoder attention
            q = query_linear(query, self.in_proj_weight_q, self.scaling)
            if not (saved_state is not None and 'prev_key' in saved_state and static_kv):
                k, v = key_value_linears(key, self.in_proj_weight_k, self.in_proj_weight_v)
        else:
            q = torch.addmm(query.view(query.size(0) * query.size(1),
                                       query.size(2)), query.view(query.size(0) * query.size(1),
                                                                  query.size(2)), self.in_proj_weight_q, beta=0.0, alpha=self.scaling)
            if not (saved_state is not None and 'prev_key' in saved_state and static_kv):
                # NOTE(review): F.linear multiplies by the transposed weight,
                # whereas the branches above compute `x @ W` with the same
                # parameters, and only this branch applies the biases.
                # Confirm whether this is intended.
                k = F.linear(key, self.in_proj_weight_k, self.in_proj_bias_k)
                v = F.linear(value, self.in_proj_weight_v, self.in_proj_bias_v)

        if saved_state is not None:
            if 'prev_key' in saved_state:
                k = torch.cat((saved_state['prev_key'], k), dim=0)
            if 'prev_value' in saved_state:
                v = torch.cat((saved_state['prev_value'], v), dim=0)
            saved_state['prev_key'] = k
            saved_state['prev_value'] = v
            self._set_input_buffer(incremental_state, saved_state)

        src_len = k.size(0)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        # Split channels into heads and move the (batch * heads) axis first.
        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        k = k.contiguous().view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        v = v.contiguous().view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        attn_weights = strided_bmm1(q, k.transpose(1, 2))
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        # only apply masking at training time (when incremental state is None)
        if mask_future_timesteps and incremental_state is None:
            assert query.size() == key.size(), \
                'mask_future_timesteps only applies to self-attention'
            attn_weights += self.buffered_mask(attn_weights).unsqueeze(0)
        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.float().masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            ).type_as(attn_weights)  # FP16 support: cast to float and back
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

        attn = strided_bmm2(attn_weights, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        if need_weights:
            # average attention weights over heads
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.sum(dim=1) / self.num_heads
        else:
            attn_weights = attn_weights.new_empty(0)  # Can't set to None because jit script reasons

        return attn, attn_weights

    def in_proj_qkv(self, query):
        """Return the unscaled (q, k, v) projections of ``query``."""
        return self._in_proj(query).chunk(3, dim=-1)

    def in_proj_kv(self, key):
        """Return the (k, v) projections of ``key``."""
        return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)

    def in_proj_q(self, query):
        """Return the unscaled query projection of ``query``."""
        return self._in_proj(query, end=self.embed_dim)

    def in_proj_k(self, key):
        """Return the key projection of ``key``."""
        return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)

    def in_proj_v(self, value):
        """Return the value projection of ``value``."""
        return self._in_proj(value, start=2 * self.embed_dim)

    def _in_proj(self, input, start=None, end=None):
        """Project ``input`` with rows ``[start:end]`` of the fused q/k/v weight.

        The fused ``(3 * embed_dim, embed_dim)`` weight no longer exists as a
        parameter, so it is assembled on the fly from the three split weights.
        Each one is transposed so that ``F.linear`` yields ``input @ W``, the
        same orientation the fast paths of ``forward`` use.  Biases are added
        when the module was built with ``bias=True``; ``self.scaling`` is not
        applied.
        """
        weight = torch.cat(
            (
                self.in_proj_weight_q.t(),
                self.in_proj_weight_k.t(),
                self.in_proj_weight_v.t(),
            ),
            dim=0,
        )
        biases = (self.in_proj_bias_q, self.in_proj_bias_k, self.in_proj_bias_v)
        bias = None
        if all(b is not None for b in biases):
            bias = torch.cat(biases, dim=0)
        if end is not None:
            weight = weight[:end, :]
            if bias is not None:
                bias = bias[:end]
        if start is not None:
            weight = weight[start:, :]
            if bias is not None:
                bias = bias[start:]
        return F.linear(input, weight, bias)

    def buffered_mask(self, tensor):
        """Return a square mask of size ``tensor.size(-1)`` with -inf above the diagonal.

        The mask is cached on the module and rebuilt only when a larger one is
        requested.
        """
        dim = tensor.size(-1)
        if self._mask.size(0) == 0:
            #TODO: try torch.new_full instead
            self._mask = torch.triu(utils.fill_with_neg_inf(tensor.new_empty(dim, dim)), 1)
        if self._mask.size(0) < dim:
            self._mask = torch.triu(utils.fill_with_neg_inf(self._mask.resize_(dim, dim)), 1)
        return self._mask[:dim, :dim]

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                # Dimension 1 of the cached tensors is re-indexed by new_order.
                input_buffer[k] = input_buffer[k].index_select(1, new_order)
            self._set_input_buffer(incremental_state, input_buffer)

    def _get_input_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Tensor]]]):
        """Return this module's cached state, or an empty dict if there is none."""
        if incremental_state is None or self.cache_id not in incremental_state:
            return {}
        return incremental_state[self.cache_id]

    def _set_input_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Tensor]]], buffer: Dict[str, Tensor]):
        """Store ``buffer`` as this module's cached state (no-op without a state dict)."""
        if incremental_state is not None:
            incremental_state[self.cache_id] = buffer

    @torch.jit.unused
    def _fast_same_check(self, q, k, v):
        """Return ``(qkv_same, kv_same)`` by comparing storage pointers only."""
        qkv_same = q.data_ptr() == k.data_ptr() == v.data_ptr()
        kv_same = k.data_ptr() == v.data_ptr()
        return qkv_same, kv_same
PyTorch/NLP/Transformer/fairseq/modules/strided_batched_gemm/strided_batched_gemm.cpp
deleted
100644 → 0
View file @
c056df78
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <torch/torch.h>
#include <vector>
// Forward declaration of the GEMM implementation that the wrapper below
// forwards to after validating its arguments. It is not defined in this
// translation unit (presumably it lives in the companion CUDA source; confirm
// against the build).
at::Tensor strided_batched_gemm_cuda(
    float beta,
    at::Tensor in_result,
    float alpha,
    at::Tensor batch1,
    at::Tensor batch2
);
// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
// Validates the operands of a half-precision batched GEMM and forwards them
// to strided_batched_gemm_cuda().
//
// Requirements enforced here:
//   * all three tensors are 3-D and share the same size in dimension 0;
//   * in_result is (batch, batch1.size(1), batch2.size(2));
//   * batch1.size(2) == batch2.size(1);
//   * all three tensors have scalar type Half.
// beta and alpha are passed through unchanged.
at::Tensor strided_batched_gemm(
    float beta,
    at::Tensor in_result,
    float alpha,
    at::Tensor batch1,
    at::Tensor batch2
) {
  // NOTE(review): the CUDA/contiguity checks below are disabled, so neither
  // property is verified here. Presumably callers pass non-contiguous
  // (transposed) views; confirm before re-enabling.
  //CHECK_INPUT(in_result);
  //CHECK_INPUT(batch1);
  //CHECK_INPUT(batch2);

  // Rank checks.
  AT_ASSERTM(in_result.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(batch1.dim()    == 3, "expected 3D tensor");
  AT_ASSERTM(batch2.dim()    == 3, "expected 3D tensor");

  // Batch-count checks.
  AT_ASSERTM(in_result.size(0) == batch1.size(0), "equal number of batches expected");
  AT_ASSERTM(in_result.size(0) == batch2.size(0), "equal number of batches expected");

  // Matrix-shape checks: result rows/cols and the shared inner dimension.
  AT_ASSERTM(in_result.size(1) == batch1.size(1), "wrong matrix size");
  AT_ASSERTM(in_result.size(2) == batch2.size(2), "wrong matrix size");
  AT_ASSERTM(batch1.size(2)    == batch2.size(1), "wrong matrix size");

  // Dtype checks.
  AT_ASSERTM(batch1.type().scalarType()    == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(batch2.type().scalarType()    == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(in_result.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");

  return strided_batched_gemm_cuda(beta, in_result, alpha, batch1, batch2);
}
// Python bindings: exposes the validating wrapper above to Python under the
// name "strided_batched_gemm" in the extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("strided_batched_gemm", &strided_batched_gemm, "Special strided batched gemm.");
}
PyTorch/NLP/Transformer/fairseq/modules/strided_batched_gemm/strided_batched_gemm_cuda.cu
deleted
100644 → 0
View file @
c056df78
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include <iostream>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "THC/THC.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/wmma_gemm_traits.h"
// symbol to be automatically resolved by PyTorch libs
extern
THCState
*
state
;
// Maps a BLAS-style transpose flag ('t', 'n' or 'c') to the corresponding
// cublasOperation_t. Any other character is reported through THError; the
// trailing return only exists to give the function a value on that path.
cublasOperation_t convertTransToCublasOperation(char trans) {
  switch (trans) {
    case 't':
      return CUBLAS_OP_T;
    case 'n':
      return CUBLAS_OP_N;
    case 'c':
      return CUBLAS_OP_C;
    default:
      THError("trans must be one of: t, n, c");
      return CUBLAS_OP_T;
  }
}
// Runs a half-precision strided batched GEMM through cuBLAS:
// A, B and C are CUDA_R_16F, the compute type is CUDA_R_32F and the
// CUBLAS_GEMM_DEFAULT_TENSOR_OP algorithm is requested.
//
// Tensor-op math mode is enabled on the handle for the duration of the call
// and reset to CUBLAS_DEFAULT_MATH afterwards. Sizes, leading dimensions and
// the batch count are narrowed from long to int for the cuBLAS interface.
// `state` is not used by this function.
void CublasGemm(THCState *state, char transa, char transb, long m, long n, long k,
                float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
                float beta, half *c, long ldc, long strideC, long batchCount) {
    const cublasOperation_t op_a = convertTransToCublasOperation(transa);
    const cublasOperation_t op_b = convertTransToCublasOperation(transb);

    cublasHandle_t blas = at::cuda::getCurrentCUDABlasHandle();

    // cuBLAS takes the dimensions as int.
    const int rows    = static_cast<int>(m);
    const int cols    = static_cast<int>(n);
    const int inner   = static_cast<int>(k);
    const int ld_a    = static_cast<int>(lda);
    const int ld_b    = static_cast<int>(ldb);
    const int ld_c    = static_cast<int>(ldc);
    const int batches = static_cast<int>(batchCount);

    // The scaling factors are passed by address as float, matching the
    // CUDA_R_32F compute type.
    float alpha_f = alpha;
    float beta_f  = beta;

    THCublasCheck(cublasSetMathMode(blas, CUBLAS_TENSOR_OP_MATH));
    THCublasCheck(cublasGemmStridedBatchedEx(
        blas,
        op_a, op_b,
        rows, cols, inner,
        static_cast<void*>(&alpha_f),
        a, CUDA_R_16F, ld_a, strideA,
        b, CUDA_R_16F, ld_b, strideB,
        static_cast<void*>(&beta_f),
        c, CUDA_R_16F, ld_c, strideC,
        batches,
        CUDA_R_32F,
        CUBLAS_GEMM_DEFAULT_TENSOR_OP));
    THCublasCheck(cublasSetMathMode(blas, CUBLAS_DEFAULT_MATH));
}
// Runs a half-precision strided batched GEMM through a CUTLASS WMMA kernel
// with float accumulation and a LinearScaling<float> epilogue
// (alpha / beta scaling). C is used both as source and as destination.
//
// Template parameters:
//   A_LAYOUT, B_LAYOUT  - CUTLASS matrix layouts of A and B.
//   SRC_A, SRC_B, DST_C - access widths forwarded to the kScalarsPer* trait
//                         parameters annotated below.
//
// The kernel is launched on `stream`. Previously the parameter was accepted
// but never used, so the launch fell back to the default stream instead of
// the stream the caller selected.
template<cutlass::MatrixLayout::Kind A_LAYOUT, cutlass::MatrixLayout::Kind B_LAYOUT, int SRC_A, int SRC_B, int DST_C>
void CutlassGemm_FP32Accum(cudaStream_t stream, long m, long n, long k,
                           float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
                           float beta, half *c, long ldc, long strideC, long batchCount) {
  typedef cutlass::gemm::WmmaGemmTraits<
    A_LAYOUT,
    B_LAYOUT,
    cutlass::Shape<32, 16, 16>,
    half,
    half,
    half,
    cutlass::gemm::LinearScaling<float>,
    float,
    typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
    typename cutlass::Shape<16, 16, 16>,
    SRC_A,         //kScalarsPerLdgA_
    SRC_B,         //kScalarsPerLdgB_
    SRC_A,         //KScalarsPerLdsA_
    SRC_B,         //KScalarsPerLdsB_
    DST_C,         //kScalarsPerLdgCAndStgD_
    DST_C/2,       //kScalarsPerStsD_
    DST_C/2        //kScalarsPerLdsD_
  >
    WmmaGemmTraits;

  typedef cutlass::gemm::Gemm<WmmaGemmTraits> Gemm;
  typename Gemm::Params params;

  int result = params.initialize(
    m,                  // M dimension for each batch
    n,                  // N dimension for each batch
    k,                  // K dimension for each batch
    alpha,              // scalar alpha
    a,
    lda,
    strideA,            // distance in memory between the first element of neighboring batch
    b,
    ldb,
    strideB,            // distance in memory between the first element of neighboring batch
    beta,               // scalar beta
    c,                  // source matrix C
    ldc,
    strideC,            // distance in memory between the first element of neighboring batch
    c,                  // destination matrix C (may be different memory than source C matrix)
    ldc,
    strideC,            // distance in memory between the first element of neighboring batch
    batchCount
  );

  AT_ASSERTM(result == 0, "Failed to initialize CUTLASS Gemm::Params object.");

  // Launch the CUTLASS GEMM kernel on the caller's stream.
  THCudaCheck(Gemm::launch(params, stream));
}
void
gemm_switch_fp32accum
(
THCState
*
state
,
char
transa
,
char
transb
,
long
m
,
long
n
,
long
k
,
float
alpha
,
const
half
*
a
,
long
lda
,
long
strideA
,
const
half
*
b
,
long
ldb
,
long
strideB
,
float
beta
,
half
*
c
,
long
ldc
,
long
strideC
,
long
batchCount
)
{
//cudaStream_t stream = THCState_getCurrentStream(state);
//printf("GEMM -> %c%c M: %i N: %i K: %i Alpha: %f Beta: %f\n", (transa == 't' ? 'T' : 'N'), (transb =='t' ? 'T' : 'N'), m, n, k, alpha, beta);
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
();
if
(
(
transa
==
't'
)
&&
(
transb
==
'n'
)
)
{
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
CublasGemm
(
state
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
8
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
8
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
4
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
4
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
4
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
2
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
2
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
2
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
8
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
8
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
8
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
4
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
4
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
4
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
2
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
2
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
2
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
8
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
8
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
8
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
4
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
4
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
4
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
2
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
2
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kRowMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
2
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
{
CublasGemm
(
state
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
}
else
if
(
(
transa
==
'n'
)
&&
(
transb
==
'n'
)
)
{
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
CublasGemm
(
state
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
8
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
8
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
4
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
4
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
4
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
2
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
2
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
8
,
2
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
8
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
8
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
8
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
4
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
4
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
4
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
2
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
2
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
4
,
2
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
8
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
8
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
8
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
4
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
4
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
4
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
2
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
2
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kColumnMajor
,
2
,
2
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
{
CublasGemm
(
state
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
}
else
if
(
(
transa
==
'n'
)
&&
(
transb
==
't'
)
)
{
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
CublasGemm
(
state
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
8
,
8
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
8
,
8
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
8
,
4
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
8
,
4
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
8
,
4
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
8
,
2
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
8
,
2
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x7
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
8
,
2
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
4
,
8
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
4
,
8
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
4
,
8
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
4
,
4
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
4
,
4
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
4
,
2
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
4
,
2
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x3
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
4
,
2
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
8
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
8
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x7
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
8
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
4
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
4
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x3
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
4
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x7
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
2
,
8
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x3
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
2
,
4
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
if
(
!
(
lda
&
0x1
)
&&
!
(
ldb
&
0x1
)
&&
!
(
ldc
&
0x1
))
{
CutlassGemm_FP32Accum
<
cutlass
::
MatrixLayout
::
kColumnMajor
,
cutlass
::
MatrixLayout
::
kRowMajor
,
2
,
2
,
2
>
(
stream
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
else
{
CublasGemm
(
state
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
}
else
{
AT_ASSERTM
(
false
,
"TransA and TransB are invalid"
);
}
}
// Patch up leading dimensions for degenerate GEMM shapes.
//
// transa / transb: 't' or 'T' selects the transposed handling for that
//                  operand; any other character selects the other branch.
// m, n, k:         GEMM extents as supplied by the caller.
// lda, ldb, ldc:   in/out leading dimensions; each is overwritten only when
//                  its controlling extent is <= 1, and is then set to
//                  max(other extent, 1).
//
// NOTE(review): the choice of controlling extent presumably follows the
// column-major BLAS convention (a matrix with at most one column never uses
// its leading dimension) -- confirm against the BLAS backend in use.
void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k,
                    int64_t *lda, int64_t *ldb, int64_t *ldc)
{
  const bool a_is_trans = (transa == 't') || (transa == 'T');
  const bool b_is_trans = (transb == 't') || (transb == 'T');

  // Leading dimensions are generally validated to be > 0 and at least as
  // large as the result requires, even when the value ends up unused, so
  // every replacement below is clamped to a minimum of 1.
  if (n <= 1) {
    *ldc = std::max<int64_t>(m, 1);
  }

  // Operand A: which extent triggers the fix-up, and which one supplies the
  // new leading dimension, swap with the transpose flag.
  const int64_t a_trigger = a_is_trans ? m : k;
  const int64_t a_extent  = a_is_trans ? k : m;
  if (a_trigger <= 1) {
    *lda = std::max<int64_t>(a_extent, 1);
  }

  // Operand B: same scheme with (k, n) in place of (m, k).
  const int64_t b_trigger = b_is_trans ? k : n;
  const int64_t b_extent  = b_is_trans ? n : k;
  if (b_trigger <= 1) {
    *ldb = std::max<int64_t>(b_extent, 1);
  }
}
// Half-precision strided-batched GEMM entry point.
//
// Rejects any extent, leading dimension or batch count that is >= INT_MAX,
// fixes up the leading dimensions for degenerate shapes via adjustLdLevel3,
// and then forwards every argument to gemm_switch_fp32accum.
//
// state:            THC state handle, passed through untouched.
// transa / transb:  transpose flags, passed through untouched.
// m, n, k:          GEMM extents.
// alpha, beta:      float scaling factors, passed through untouched.
// a, b, c:          half-precision operand/result pointers.
// lda, ldb, ldc:    leading dimensions (may be adjusted locally before use).
// strideA/B/C:      per-batch strides, passed through untouched.
// batchCount:       number of batched problems.
void HgemmStridedBatched(THCState *state, char transa, char transb,
                         long m, long n, long k,
                         float alpha, const half *a, long lda, long strideA,
                         const half *b, long ldb, long strideB,
                         float beta, half *c, long ldc, long strideC,
                         long batchCount)
{
  if ((m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) ||
      (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) ||
      (batchCount >= INT_MAX))
  {
    // Adjacent string literals are concatenated with no separator, so the
    // first literal must end with a space to keep "batchCount with" readable.
    THError("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount "
            "with the bound [val] <= %d",
            INT_MAX);
  }

  // lda/ldb/ldc are by-value parameters; the adjustment only affects the
  // values forwarded below, not the caller's variables.
  adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);

  //gemm_switch(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
  gemm_switch_fp32accum(state, transa, transb, m, n, k,
                        alpha, a, lda, strideA,
                        b, ldb, strideB,
                        beta, c, ldc, strideC,
                        batchCount);
}
// Batched half-precision GEMM over 3-D tensors, written into `in_result`.
//
// The function inspects the strides of dims 1 and 2 of each tensor to pick
// the transpose flags and leading dimensions, then calls HgemmStridedBatched
// with dim 0 as the batch dimension. When `in_result` has unit stride along
// dim 2 instead of dim 1, the two inputs are exchanged and the roles of dims
// 1 and 2 are flipped for every subsequent stride/size lookup.
//
// Returns `in_result` (the tensor whose storage received the product).
// Raises via AT_ASSERTM when a tensor has no unit stride in dim 1 or dim 2
// (or, for the inputs, when the paired stride is 0).
//
// NOTE(review): data pointers are cast to `half*` without a dtype check --
// presumably the caller guarantees fp16 tensors; confirm.
at::Tensor strided_batched_gemm_cuda(float beta, at::Tensor in_result,
                                     float alpha, at::Tensor batch1,
                                     at::Tensor batch2)
{
  bool transpose_result;
  char transpose_batch1, transpose_batch2;
  int64_t lda, ldb, ldc;

  // --- result layout -------------------------------------------------------
  if (in_result.stride(1) == 1) {
    transpose_result = false;
    ldc = in_result.stride(2);
  } else if (in_result.stride(2) == 1) {
    transpose_result = true;
    // Exchange the operands so the product lands in the flipped layout.
    at::Tensor held = batch2;
    batch2 = batch1;
    batch1 = held;
    ldc = in_result.stride(1);
  } else {
    AT_ASSERTM(false, "result should be contiguous");
  }

  // Dimension that must carry unit stride for the 'n' case, and the
  // dimension that then supplies the leading dimension.
  const int64_t unit_dim = transpose_result ? 2 : 1;
  const int64_t lead_dim = transpose_result ? 1 : 2;

  // --- first operand -------------------------------------------------------
  if (batch1.stride(unit_dim) == 1 && batch1.stride(lead_dim) != 0) {
    transpose_batch1 = 'n';
    lda = batch1.stride(lead_dim);
  } else if (batch1.stride(lead_dim) == 1 && batch1.stride(unit_dim) != 0) {
    transpose_batch1 = 't';
    lda = batch1.stride(unit_dim);
  } else {
    AT_ASSERTM(false, "input1 should be contiguous");
  }

  // --- second operand ------------------------------------------------------
  if (batch2.stride(unit_dim) == 1 && batch2.stride(lead_dim) != 0) {
    transpose_batch2 = 'n';
    ldb = batch2.stride(lead_dim);
  } else if (batch2.stride(lead_dim) == 1 && batch2.stride(unit_dim) != 0) {
    transpose_batch2 = 't';
    ldb = batch2.stride(unit_dim);
  } else {
    AT_ASSERTM(false, "input2 should be contiguous");
  }

  const int64_t num_batches = in_result.size(0);

  HgemmStridedBatched(state,
                      transpose_batch1,
                      transpose_batch2,
                      in_result.size(unit_dim),
                      in_result.size(lead_dim),
                      batch1.size(lead_dim),
                      alpha,
                      static_cast<const half*>(batch1.data_ptr()), lda, batch1.stride(0),
                      static_cast<const half*>(batch2.data_ptr()), ldb, batch2.stride(0),
                      beta,
                      static_cast<half*>(in_result.data_ptr()), ldc, in_result.stride(0),
                      num_batches);

  return in_result;
}
PyTorch/NLP/Transformer/fairseq/optim/__init__.py
deleted
100644 → 0
View file @
c056df78
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import
importlib
import
os
from
.fairseq_optimizer
import
FairseqOptimizer
# Registered optimizer name -> optimizer class; filled in by the
# ``register_optimizer`` decorator and read by ``build_optimizer``.
OPTIMIZER_REGISTRY = dict()
# ``__name__`` of every registered optimizer class, used to reject two
# registrations that share a class name.
OPTIMIZER_CLASS_NAMES = set()
def build_optimizer(args, params):
    """Instantiate the optimizer registered under ``args.optimizer``.

    Args:
        args: namespace whose ``optimizer`` attribute is a key of
            ``OPTIMIZER_REGISTRY``; it is also passed to the optimizer class.
        params: iterable of parameters; only those whose ``requires_grad``
            is truthy are handed to the optimizer.

    Returns:
        The object produced by calling the registered class with
        ``(args, trainable_params)``.

    Raises:
        KeyError: if ``args.optimizer`` has not been registered.
    """
    # Materialize the selection instead of passing a one-shot ``filter``
    # iterator: the optimizer wrappers both keep ``params`` and forward it to
    # the wrapped optimizer, so an iterator would be exhausted after its
    # first consumer and the stored copy would be empty.
    trainable = [p for p in params if p.requires_grad]
    return OPTIMIZER_REGISTRY[args.optimizer](args, trainable)
def register_optimizer(name):
    """Return a class decorator that registers an optimizer under *name*.

    The decorated class is stored in ``OPTIMIZER_REGISTRY[name]`` and its
    ``__name__`` is added to ``OPTIMIZER_CLASS_NAMES``; the class itself is
    returned unchanged.

    Raises (from the decorator, at decoration time):
        ValueError: if *name* is already registered, if the class does not
            subclass ``FairseqOptimizer``, or if another registered class
            already uses the same ``__name__``.
    """

    def register_optimizer_cls(cls):
        # Checks run in this fixed order so the first failing rule decides
        # which message the caller sees.
        if name in OPTIMIZER_REGISTRY:
            message = 'Cannot register duplicate optimizer ({})'.format(name)
            raise ValueError(message)

        if not issubclass(cls, FairseqOptimizer):
            message = 'Optimizer ({}: {}) must extend FairseqOptimizer'.format(
                name, cls.__name__)
            raise ValueError(message)

        class_name = cls.__name__
        if class_name in OPTIMIZER_CLASS_NAMES:
            # The optimizer class name serves as a unique identifier in
            # checkpoints, so every optimizer needs a distinct class name.
            message = 'Cannot register optimizer with duplicate class name ({})'.format(
                class_name)
            raise ValueError(message)

        OPTIMIZER_REGISTRY[name] = cls
        OPTIMIZER_CLASS_NAMES.add(class_name)
        return cls

    return register_optimizer_cls
# Import every public ``*.py`` module that sits next to this file so that
# importing the package triggers their ``@register_optimizer`` decorators.
for file in os.listdir(os.path.dirname(__file__)):
    # Skip non-Python entries and private/dunder files such as __init__.py.
    if not file.endswith('.py') or file.startswith('_'):
        continue
    # Module name is everything before the first '.py' in the file name.
    module = file.partition('.py')[0]
    importlib.import_module('fairseq.optim.' + module)
PyTorch/NLP/Transformer/fairseq/optim/adam.py
deleted
100644 → 0
View file @
c056df78
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.
import
FairseqOptimizer
,
register_optimizer
from
apex.optimizers.fused_adam
import
FusedAdam
@register_optimizer('adam')
class FairseqAdam(FairseqOptimizer):
    """Optimizer wrapper registered as ``'adam'`` that delegates to ``FusedAdam``."""

    def __init__(self, args, params):
        super().__init__(args, params)
        # Hyper-parameters come from the same property that is later used to
        # override values restored from a checkpoint.
        hyper_params = self.optimizer_config
        self._optimizer = FusedAdam(params, **hyper_params)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        parser.add_argument(
            '--adam-betas',
            nargs=2,
            type=float,
            default=(0.9, 0.999),
            metavar='B1 B2',
            help='betas for Adam optimizer',
        )
        parser.add_argument(
            '--adam-eps',
            type=float,
            default=1e-8,
            metavar='D',
            help='epsilon for Adam optimizer',
        )

    @property
    def optimizer_config(self):
        """Keyword arguments that take precedence over checkpointed values.

        Built from the current ``args`` so that training can resume from a
        checkpoint with different optimizer settings (for example another
        learning rate). Only the first entry of ``args.lr`` is used.
        """
        opts = self.args
        config = {
            'lr': opts.lr[0],
            'betas': opts.adam_betas,
            'eps': opts.adam_eps,
            'weight_decay': opts.weight_decay,
        }
        return config
PyTorch/NLP/Transformer/fairseq/optim/fairseq_optimizer.py
deleted
100644 → 0
View file @
c056df78
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch.optim
class FairseqOptimizer(object):
    """Base wrapper around a ``torch.optim.Optimizer``.

    Subclasses are expected to set ``self._optimizer`` and to implement the
    ``optimizer_config`` property; every other method delegates to the
    wrapped optimizer through the ``optimizer`` property.
    """

    def __init__(self, args, params):
        super().__init__()
        self.args, self.params = args, params

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        return None

    @property
    def optimizer(self):
        """The wrapped ``torch.optim.Optimizer``.

        Raises:
            NotImplementedError: if ``_optimizer`` has not been set.
            ValueError: if ``_optimizer`` is not a ``torch.optim.Optimizer``.
        """
        if not hasattr(self, '_optimizer'):
            raise NotImplementedError
        wrapped = self._optimizer
        if isinstance(wrapped, torch.optim.Optimizer):
            return wrapped
        raise ValueError('_optimizer must be an instance of torch.optim.Optimizer')

    @property
    def optimizer_config(self):
        """Keyword arguments that override optimizer settings from checkpoints.

        Lets training resume from a checkpoint with a different set of
        optimizer arguments (for example another learning rate). Must be
        provided by subclasses.
        """
        raise NotImplementedError

    def get_lr(self):
        """Return the learning rate of the first parameter group."""
        first_group = self.optimizer.param_groups[0]
        return first_group['lr']

    def set_lr(self, lr):
        """Assign *lr* to every parameter group."""
        for group in self.optimizer.param_groups:
            group['lr'] = lr

    def state_dict(self):
        """Return the wrapped optimizer's state dict."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        """Restore optimizer state, then re-apply the current configuration.

        The settings of this instance (learning rate, etc.) win over the
        ones stored in *state_dict*, so a run can be resumed with new
        optimizer arguments.
        """
        self.optimizer.load_state_dict(state_dict)

        # Overwrite restored hyper-parameters with the latest values.
        for param_group in self.optimizer.param_groups:
            param_group.update(self.optimizer_config)

    def step(self, closure=None):
        """Run one optimization step and return the wrapped optimizer's result."""
        return self.optimizer.step(closure)

    def zero_grad(self):
        """Drop all gradients, then call the wrapped optimizer's ``zero_grad``."""
        # Gradients are set to None (not zero-filled) before delegating.
        for param_group in self.optimizer.param_groups:
            for param in param_group['params']:
                param.grad = None
        return self.optimizer.zero_grad()
PyTorch/NLP/Transformer/fairseq/optim/lr_scheduler/__init__.py
deleted
100644 → 0
View file @
c056df78
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import
importlib
import
os
from
.fairseq_lr_scheduler
import
FairseqLRScheduler
# Registered scheduler name -> scheduler class; filled in by the
# ``register_lr_scheduler`` decorator and read by ``build_lr_scheduler``.
LR_SCHEDULER_REGISTRY = dict()
def build_lr_scheduler(args, optimizer):
    """Instantiate the LR scheduler registered under ``args.lr_scheduler``.

    Args:
        args: namespace whose ``lr_scheduler`` attribute is a key of
            ``LR_SCHEDULER_REGISTRY``; also passed to the scheduler class.
        optimizer: forwarded unchanged as the second constructor argument.

    Raises:
        KeyError: if ``args.lr_scheduler`` has not been registered.
    """
    scheduler_cls = LR_SCHEDULER_REGISTRY[args.lr_scheduler]
    return scheduler_cls(args, optimizer)
def register_lr_scheduler(name):
    """Return a class decorator that registers an LR scheduler under *name*.

    The decorated class is stored in ``LR_SCHEDULER_REGISTRY[name]`` and
    returned unchanged.

    Raises (from the decorator, at decoration time):
        ValueError: if *name* is already registered or the class does not
            subclass ``FairseqLRScheduler``.
    """

    def register_lr_scheduler_cls(cls):
        # The duplicate-name rule is checked before the base-class rule.
        if name in LR_SCHEDULER_REGISTRY:
            message = 'Cannot register duplicate LR scheduler ({})'.format(name)
            raise ValueError(message)

        if not issubclass(cls, FairseqLRScheduler):
            message = 'LR Scheduler ({}: {}) must extend FairseqLRScheduler'.format(
                name, cls.__name__)
            raise ValueError(message)

        LR_SCHEDULER_REGISTRY[name] = cls
        return cls

    return register_lr_scheduler_cls
# Import every public ``*.py`` module that sits next to this file so that
# importing the package triggers their ``@register_lr_scheduler`` decorators.
for file in os.listdir(os.path.dirname(__file__)):
    # Skip non-Python entries and private/dunder files such as __init__.py.
    if not file.endswith('.py') or file.startswith('_'):
        continue
    # Module name is everything before the first '.py' in the file name.
    module = file.partition('.py')[0]
    importlib.import_module('fairseq.optim.lr_scheduler.' + module)
PyTorch/NLP/Transformer/fairseq/optim/lr_scheduler/fixed_schedule.py
deleted
100644 → 0
View file @
c056df78
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from
.
import
FairseqLRScheduler
,
register_lr_scheduler
@register_lr_scheduler('fixed')
class FixedSchedule(FairseqLRScheduler):
    """Decay the LR on a fixed schedule."""

    def __init__(self, args, optimizer):
        super().__init__(args, optimizer)

        # Normalise a missing, None or zero ``warmup_updates`` to 0 and
        # write it back so later reads of ``args`` see the default.
        warmup = getattr(args, 'warmup_updates', 0) or 0
        args.warmup_updates = warmup

        self.lr = args.lr[0]
        # With warmup enabled the factor starts at 1/N; otherwise the LR is
        # used unscaled.
        self.warmup_factor = 1. / warmup if warmup > 0 else 1

    @staticmethod
    def add_args(parser):
        """Add arguments to the parser for this LR scheduler."""
        parser.add_argument(
            '--force-anneal', '--fa',
            type=int,
            metavar='N',
            help='force annealing at specified epoch',
        )
        parser.add_argument(
            '--warmup-updates',
            type=int,
            default=0,
            metavar='N',
            help='warmup the learning rate linearly for the first N updates',
        )

    def get_next_lr(self, epoch):
        """Return the un-warmed learning rate to use for *epoch*."""
        schedule = self.args.lr
        anneal_from = self.args.force_anneal

        if anneal_from is None or epoch < anneal_from:
            # Fixed schedule: index by epoch, clamped to the last entry.
            return schedule[min(epoch, len(schedule) - 1)]

        # Annealing: shrink the last listed LR once more for every epoch at
        # or past ``force_anneal``.
        exponent = epoch + 1 - anneal_from
        return schedule[-1] * self.args.lr_shrink ** exponent

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        super().step(epoch, val_loss)
        self.lr = self.get_next_lr(epoch)
        scaled_lr = self.warmup_factor * self.lr
        self.optimizer.set_lr(scaled_lr)
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        warmup = self.args.warmup_updates
        # The LR is only touched while still inside the warmup window.
        if warmup > 0 and num_updates <= warmup:
            self.warmup_factor = num_updates / float(warmup)
            self.optimizer.set_lr(self.warmup_factor * self.lr)
        return self.optimizer.get_lr()
PyTorch/NLP/Transformer/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py
deleted
100644 → 0
View file @
c056df78
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import
torch.optim.lr_scheduler
from
.
import
FairseqLRScheduler
,
register_lr_scheduler
@register_lr_scheduler('reduce_lr_on_plateau')
class ReduceLROnPlateau(FairseqLRScheduler):
    """Shrink the LR by ``args.lr_shrink`` whenever validation loss plateaus.

    Delegates to ``torch.optim.lr_scheduler.ReduceLROnPlateau`` (created
    with ``patience=0``) wrapped around ``self.optimizer.optimizer``.
    """

    def __init__(self, args, optimizer):
        super().__init__(args, optimizer)

        # A multi-value --lr only makes sense for the fixed schedule.
        if len(args.lr) > 1:
            raise ValueError(
                'Cannot use a fixed learning rate schedule with reduce_lr_on_plateau.'
                ' Consider --lr-scheduler=fixed instead.'
            )

        plateau_cls = torch.optim.lr_scheduler.ReduceLROnPlateau
        self.lr_scheduler = plateau_cls(
            self.optimizer.optimizer,
            patience=0,
            factor=args.lr_shrink,
        )

    def state_dict(self):
        """Return the wrapped scheduler's ``best`` and ``last_epoch``."""
        wrapped = self.lr_scheduler
        return dict(best=wrapped.best, last_epoch=wrapped.last_epoch)

    def load_state_dict(self, state_dict):
        """Restore ``best`` (required) and ``last_epoch`` (if present)."""
        wrapped = self.lr_scheduler
        wrapped.best = state_dict['best']
        if 'last_epoch' in state_dict:
            wrapped.last_epoch = state_dict['last_epoch']

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch.

        Without a validation loss only the wrapped scheduler's epoch counter
        is advanced. Returns the learning rate reported by the optimizer.
        """
        if val_loss is None:
            self.lr_scheduler.last_epoch = epoch
        else:
            self.lr_scheduler.step(val_loss, epoch)
        return self.optimizer.get_lr()
PyTorch/NLP/Transformer/fairseq/options.py
deleted
100644 → 0
View file @
c056df78
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
argparse
import
os
import
torch
from
fairseq.models
import
ARCH_MODEL_REGISTRY
,
ARCH_CONFIG_REGISTRY
from
fairseq.criterions
import
CRITERION_REGISTRY
from
fairseq.optim
import
OPTIMIZER_REGISTRY
from
fairseq.optim.lr_scheduler
import
LR_SCHEDULER_REGISTRY
def get_training_parser():
    """Build the 'Trainer' parser with every training-related option group.

    Dataset options are added with both ``train`` and ``gen`` enabled,
    followed by the distributed, model, optimization, checkpoint, inference
    and performance groups, in that order.
    """
    parser = get_parser('Trainer')
    add_dataset_args(parser, train=True, gen=True)

    remaining_adders = (
        add_distributed_training_args,
        add_model_args,
        add_optimization_args,
        add_checkpoint_args,
        add_inference_args,
        add_perf_args,
    )
    for add_args in remaining_adders:
        add_args(parser)
    return parser
def get_inference_parser():
    """Build the 'Generation' parser: dataset (gen), inference, perf options."""
    parser = get_parser('Generation')
    add_dataset_args(parser, gen=True)
    for add_args in (add_inference_args, add_perf_args):
        add_args(parser)
    return parser
def parse_args_and_arch(parser, input_args=None, parse_known=False):
    """Parse arguments in two passes and post-process the result.

    The parser initially knows nothing about model/optimizer/LR-scheduler
    specific options, so a first tolerant pass discovers which of those
    components were selected, their options are added, and the arguments are
    parsed again.

    Args:
        parser: the ``argparse`` parser to extend and parse with.
        input_args: argument list to parse instead of ``sys.argv``.
        parse_known: when true, unknown arguments are tolerated on the
            second pass and returned alongside the namespace.

    Returns:
        ``args`` or, when ``parse_known`` is true, ``(args, extra)``.
    """
    # First pass: only used to learn which components were requested.
    preliminary, _ = parser.parse_known_args(input_args)

    if hasattr(preliminary, 'arch'):
        # argparse.SUPPRESS: only include attributes that are given explicitly
        # on the command line or that have default values.
        arch_group = parser.add_argument_group(
            'Model-specific configuration',
            argument_default=argparse.SUPPRESS,
        )
        ARCH_MODEL_REGISTRY[preliminary.arch].add_args(arch_group)

    # Optimizer- and scheduler-specific options go on the parser itself.
    if hasattr(preliminary, 'optimizer'):
        OPTIMIZER_REGISTRY[preliminary.optimizer].add_args(parser)
    if hasattr(preliminary, 'lr_scheduler'):
        LR_SCHEDULER_REGISTRY[preliminary.lr_scheduler].add_args(parser)

    # Second pass, now that every component-specific option is registered.
    if parse_known:
        args, extra = parser.parse_known_args(input_args)
    else:
        args, extra = parser.parse_args(input_args), None

    # Post-processing: validation batch size falls back to --max-sentences.
    if hasattr(args, 'max_sentences_valid') and args.max_sentences_valid is None:
        args.max_sentences_valid = args.max_sentences
    args.max_positions = (args.max_source_positions, args.max_target_positions)

    # When BLEU is scored and no BPE marker was given, use the default marker.
    if hasattr(args, 'target_bleu'):
        scoring_requested = args.online_eval or args.target_bleu
        if scoring_requested and not args.remove_bpe:
            args.remove_bpe = '@@ '

    # Apply the architecture configuration function to the namespace.
    if hasattr(args, 'arch'):
        ARCH_CONFIG_REGISTRY[args.arch](args)

    return (args, extra) if parse_known else args
def get_parser(desc):
    """Create the base parser with the options shared by every entry point.

    ``desc`` is appended to the toolkit banner to form the parser
    description.
    """
    parser = argparse.ArgumentParser(
        description='Facebook AI Research Sequence-to-Sequence Toolkit -- ' + desc)

    # (flags, keyword arguments) for each common option, in help order.
    common_options = (
        (('--log-interval',), dict(
            type=int, default=500, metavar='N',
            help='print aggregated stats and flush json log every N iteration')),
        (('--seed',), dict(
            default=1, type=int, metavar='N',
            help='pseudo random number generator seed')),
        (('--amp',), dict(
            action='store_true',
            help='use Automatic Mixed Precision')),
        (('--stat-file',), dict(
            type=str, default='run_log.json',
            help='Name of the file containing DLLogger output')),
        (('--save-dir',), dict(
            metavar='DIR', default='results',
            help='path to save checkpoints and logs')),
        (('--do-sanity-check',), dict(
            action='store_true',
            help='Perform evaluation on test set before running the training')),
    )
    for flags, settings in common_options:
        parser.add_argument(*flags, **settings)
    return parser
def _str_to_bool(value):
    """Convert a command-line token such as 'true' or 'False' to a bool."""
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


def add_dataset_args(parser, train=False, gen=False):
    """Add dataset and data-loading options to ``parser``.

    Args:
        parser: the ``argparse`` parser to extend.
        train: also add the positional ``data`` directory and the
            train/validation subset options.
        gen: also add the generation subset and sharding options.

    Returns:
        The 'Dataset and data loading' argument group.
    """
    group = parser.add_argument_group('Dataset and data loading')
    group.add_argument('--max-tokens', type=int, metavar='N',
                       help='maximum number of tokens in a batch')
    group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
                       help='maximum number of sentences in a batch')
    parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
                        help='source language')
    parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
                        help='target language')
    parser.add_argument('--raw-text', action='store_true',
                        help='load raw text dataset')
    # type=bool would treat every non-empty string (including "False") as
    # True, so the value is parsed explicitly instead.
    parser.add_argument('--left-pad-source', default=True, type=_str_to_bool,
                        metavar='BOOL',
                        help='pad the source on the left (default: True)')
    parser.add_argument('--left-pad-target', default=False, type=_str_to_bool,
                        metavar='BOOL',
                        help='pad the target on the left (default: False)')
    parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
                        help='max number of tokens in the source sequence')
    parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
                        help='max number of tokens in the target sequence')
    parser.add_argument('--pad-sequence', default=1, type=int, metavar='N',
                        help='Pad sequences to a multiple of N')
    if train:
        parser.add_argument('data', metavar='DIR', help='path to data directory')
        group.add_argument('--train-subset', default='train', metavar='SPLIT',
                           choices=['train', 'valid', 'test'],
                           help='data subset to use for training (train, valid, test)')
        group.add_argument('--valid-subset', default='valid', metavar='SPLIT',
                           help='comma separated list of data subsets to use for validation'
                                ' (train, valid, valid1, test, test1)')
        group.add_argument('--max-sentences-valid', type=int, metavar='N',
                           help='maximum number of sentences in a validation batch'
                                ' (defaults to --max-sentences)')
    if gen:
        group.add_argument('--gen-subset', default='test', metavar='SPLIT',
                           help='data subset to generate (train, valid, test)')
        group.add_argument('--num-shards', default=1, type=int, metavar='N',
                           help='shard generation over N shards')
        group.add_argument('--shard-id', default=0, type=int, metavar='ID',
                           help='id of the shard to generate (id < num_shards)')
    return group
def add_distributed_training_args(parser):
    """Add the 'Distributed training' option group and return it.

    The world-size default is taken from ``torch.cuda.device_count()`` and
    the rank default from the ``LOCAL_RANK`` environment variable (0 when
    unset), both evaluated when this function is called.
    """
    group = parser.add_argument_group('Distributed training')
    add = group.add_argument

    visible_gpus = torch.cuda.device_count()
    add('--distributed-world-size', type=int, metavar='N', default=visible_gpus,
        help='total number of GPUs across all nodes (default: all visible GPUs)')

    # A string taken from the environment is converted by argparse via type=int.
    rank_default = os.getenv('LOCAL_RANK', 0)
    add('--distributed-rank', default=rank_default, type=int,
        help='rank of the current worker')
    add('--local_rank', default=0, type=int,
        help='rank of the current worker')
    add('--distributed-backend', default='nccl', type=str,
        help='distributed backend')
    add('--distributed-init-method', default=None, type=str,
        help='typically tcp://hostname:port that will be used to '
             'establish initial connetion')
    add('--distributed-port', default=-1, type=int,
        help='port number (not required if using --distributed-init-method)')
    add('--device-id', default=0, type=int,
        help='which GPU to use (usually configured automatically)')
    return group
def add_optimization_args(parser):
    """Add the 'Optimization' option group and return it.

    Optimizer choices and the names listed in the help texts come from
    ``OPTIMIZER_REGISTRY`` and ``LR_SCHEDULER_REGISTRY`` as populated at call
    time. ``--label-smoothing`` is added to ``parser`` itself, not the group.
    """
    group = parser.add_argument_group('Optimization')
    add = group.add_argument

    # Stopping criteria and gradient handling.
    add('--max-epoch', '--me', default=0, type=int, metavar='N',
        help='force stop training at specified epoch')
    add('--max-update', '--mu', default=0, type=int, metavar='N',
        help='force stop training at specified update')
    add('--target-bleu', default=0.0, type=float, metavar='TARGET',
        help='force stop training after reaching target bleu')
    add('--clip-norm', default=25, type=float, metavar='NORM',
        help='clip threshold of gradients')
    add('--update-freq', default=[1], nargs='+', type=int,
        help='update parameters every N_i batches, when in epoch i')

    # Optimizer definitions can be found under fairseq/optim/
    optimizer_names = OPTIMIZER_REGISTRY.keys()
    optimizer_help = 'optimizer: {} (default: nag)'.format(', '.join(optimizer_names))
    add('--optimizer', default='nag', metavar='OPT', choices=optimizer_names,
        help=optimizer_help)
    add('--lr', '--learning-rate', default=[0.25], nargs='+', type=float,
        help='learning rate for the first N epochs; all epochs >N using LR_N'
             ' (note: this may be interpreted differently depending on --lr-scheduler)')
    add('--momentum', default=0.99, type=float, metavar='M',
        help='momentum factor')
    add('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
        help='weight decay')

    # Learning rate schedulers can be found under fairseq/optim/lr_scheduler/
    scheduler_help = 'learning rate scheduler: {} (default: reduce_lr_on_plateau)'.format(
        ', '.join(LR_SCHEDULER_REGISTRY.keys()))
    add('--lr-scheduler', default='reduce_lr_on_plateau', help=scheduler_help)
    add('--lr-shrink', default=0.1, type=float, metavar='LS',
        help='learning rate shrink factor for annealing, lr_new = (lr * lr_shrink)')
    add('--min-lr', default=1e-5, type=float, metavar='LR',
        help='minimum learning rate')

    # Criterion args
    parser.add_argument('--label-smoothing', default=0., type=float, metavar='D',
                        help='epsilon for label smoothing, 0 means no label smoothing')
    return group
def add_checkpoint_args(parser):
    """Add the 'Checkpointing' option group to ``parser`` and return it."""
    group = parser.add_argument_group('Checkpointing')

    # (flag, keyword arguments) for each option, in help order.
    checkpoint_options = (
        ('--restore-file', dict(
            default='checkpoint_last.pt',
            help='filename in save-dir from which to load checkpoint')),
        ('--save-interval', dict(
            type=int, default=1, metavar='N',
            help='save a checkpoint every N epochs')),
        ('--no-save', dict(
            action='store_true',
            help='don\'t save models or checkpoints')),
        ('--no-epoch-checkpoints', dict(
            action='store_true',
            help='only store last and best checkpoints')),
        ('--validate-interval', dict(
            type=int, default=1, metavar='N',
            help='validate every N epochs')),
    )
    for flag, settings in checkpoint_options:
        group.add_argument(flag, **settings)
    return group
def add_common_eval_args(group):
    """Add the evaluation options shared by inference parsers to ``group``.

    ``group`` is anything exposing ``add_argument`` (an argument group or a
    parser). Nothing is returned.
    """
    add = group.add_argument
    add('--path', metavar='FILE',
        help='path(s) to model file(s), colon separated')
    add('--file', metavar='FILE', default=None, type=str,
        help='path to a file with input data for inference')
    # Given without a value, --remove-bpe takes the const marker '@@ '.
    add('--remove-bpe', nargs='?', const='@@ ', default=None,
        help='remove BPE tokens before scoring')
    for flag, description in (
        ('--cpu', 'generate on CPU'),
        ('--quiet', 'only print final scores'),
    ):
        add(flag, action='store_true', help=description)
def add_inference_args(parser):
    """Add the 'Generation' option group to ``parser`` and return it.

    The group first receives the shared evaluation options from
    ``add_common_eval_args`` and then the generation-specific options.
    """
    group = parser.add_argument_group('Generation')
    add_common_eval_args(group)
    add = group.add_argument

    # Both length coefficients share the same help text.
    max_len_help = ('generate sequences of maximum length ax + b, '
                    'where x is the source length')

    # Beam search and output length control.
    add('--beam', default=4, type=int, metavar='N',
        help='beam size')
    add('--nbest', default=1, type=int, metavar='N',
        help='number of hypotheses to output')
    add('--max-len-a', default=0, type=float, metavar='N', help=max_len_help)
    add('--max-len-b', default=200, type=int, metavar='N', help=max_len_help)
    add('--min-len', default=1, type=float, metavar='N',
        help='minimum generation length')
    add('--no-early-stop', action='store_true',
        help=('continue searching even after finalizing k=beam '
              'hypotheses; this is more correct, but increases '
              'generation time by 50%%'))
    add('--unnormalized', action='store_true',
        help='compare unnormalized hypothesis scores')
    add('--no-beamable-mm', action='store_true',
        help='don\'t use BeamableMM in attention layers')

    # Scoring penalties and unknown-word handling.
    add('--lenpen', default=1, type=float,
        help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
    add('--unkpen', default=0, type=float,
        help='unknown word penalty: <0 produces more unks, >0 produces fewer')
    add('--replace-unk', nargs='?', const=True, default=None,
        help='perform unknown replacement (optionally with alignment dictionary)')
    add('--prefix-size', default=0, type=int, metavar='PS',
        help='initialize generation by target prefix of given length')

    # Sampling instead of beam search.
    add('--sampling', action='store_true',
        help='sample hypotheses instead of using beam search')
    add('--sampling-topk', default=-1, type=int, metavar='PS',
        help='sample from top K likely next words instead of all words')
    add('--sampling-temperature', default=1, type=float, metavar='N',
        help='temperature for random sampling')

    # Output, online evaluation and miscellaneous options.
    add('--print-alignment', action='store_true',
        help='if set, uses attention feedback to compute and print alignment to source tokens')
    add('--online-eval', action='store_true',
        help='score model at the end of epoch')
    add('--save-predictions', action='store_true',
        help='Save predictions produced with online evaluation')
    add('--test-cased-bleu', action='store_true',
        help='Use cased bleu for online eval')
    add('--bpe-codes', default=None, type=str, metavar='CODES',
        help='file with bpe codes')
    add('--buffer-size', default=64, type=int, metavar='N',
        help='read this many sentences into a buffer before processing them')
    add('--fp16', action='store_true',
        help='use fp16 precision')
    return group
def add_model_args(parser):
    """Add the 'Model configuration' option group and return it.

    Valid ``--arch`` and ``--criterion`` values are the keys of
    ``ARCH_MODEL_REGISTRY`` and ``CRITERION_REGISTRY`` respectively.
    """
    group = parser.add_argument_group('Model configuration')

    # Model definitions can be found under fairseq/models/
    #
    # The model architecture can be specified in several ways.
    # In increasing order of priority:
    # 1) model defaults (lowest priority)
    # 2) --arch argument
    arch_names = ARCH_MODEL_REGISTRY.keys()
    arch_help = 'model architecture: {} (default: fconv)'.format(', '.join(arch_names))
    group.add_argument(
        '--arch', '-a',
        default='fconv',
        metavar='ARCH',
        required=True,
        choices=arch_names,
        help=arch_help,
    )

    # Criterion definitions can be found under fairseq/criterions/
    criterion_names = CRITERION_REGISTRY.keys()
    criterion_help = 'training criterion: {} (default: cross_entropy)'.format(
        ', '.join(criterion_names))
    group.add_argument(
        '--criterion',
        default='cross_entropy',
        metavar='CRIT',
        choices=criterion_names,
        help=criterion_help,
    )
    return group
def add_perf_args(parser):
    """Add the 'Performance' group of boolean fusion flags and return it."""
    group = parser.add_argument_group('Performance')

    # Every option here is a store_true switch; only flag and help differ.
    fusion_switches = (
        ('--fuse-dropout-add', 'Fuse dropout and residual adds.'),
        ('--fuse-relu-dropout', 'Fuse Relu and Dropout.'),
        ('--fuse-layer-norm',
         'Use APEX\'s FusedLayerNorm instead of torch.nn.LayerNorm'),
    )
    for flag, description in fusion_switches:
        group.add_argument(flag, action='store_true', help=description)
    return group
Prev
1
2
3
4
5
6
…
17
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment