dcuai / dlexamples · Commits · c0f05c10

Commit c0f05c10, authored Nov 29, 2022 by hepj
Update transformer code
Parent: c056df78
Changes: 321
Showing 20 changed files with 0 additions and 3177 deletions (+0 -3177)
PyTorch/NLP/Transformer/fairseq/ddp_trainer.py  +0 -305
PyTorch/NLP/Transformer/fairseq/distributed_utils.py  +0 -111
PyTorch/NLP/Transformer/fairseq/log_helper.py  +0 -204
PyTorch/NLP/Transformer/fairseq/meters.py  +0 -87
PyTorch/NLP/Transformer/fairseq/models/__init__.py  +0 -55
PyTorch/NLP/Transformer/fairseq/models/fairseq_incremental_decoder.py  +0 -42
PyTorch/NLP/Transformer/fairseq/models/fused_layer_norm.py  +0 -159
PyTorch/NLP/Transformer/fairseq/models/transformer.py  +0 -621
PyTorch/NLP/Transformer/fairseq/modules/__init__.py  +0 -18
PyTorch/NLP/Transformer/fairseq/modules/learned_positional_embedding.py  +0 -31
PyTorch/NLP/Transformer/fairseq/modules/multihead_attention.py  +0 -460
PyTorch/NLP/Transformer/fairseq/modules/strided_batched_gemm/strided_batched_gemm.cpp  +0 -61
PyTorch/NLP/Transformer/fairseq/modules/strided_batched_gemm/strided_batched_gemm_cuda.cu  +0 -345
PyTorch/NLP/Transformer/fairseq/optim/__init__.py  +0 -46
PyTorch/NLP/Transformer/fairseq/optim/adam.py  +0 -54
PyTorch/NLP/Transformer/fairseq/optim/fairseq_optimizer.py  +0 -94
PyTorch/NLP/Transformer/fairseq/optim/lr_scheduler/__init__.py  +0 -39
PyTorch/NLP/Transformer/fairseq/optim/lr_scheduler/fixed_schedule.py  +0 -57
PyTorch/NLP/Transformer/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py  +0 -46
PyTorch/NLP/Transformer/fairseq/options.py  +0 -342
PyTorch/NLP/Transformer/fairseq/ddp_trainer.py  (deleted, 100644 → 0)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Train a network across multiple GPUs.
"""

import math
from collections import defaultdict
from itertools import chain

import torch
import torch.nn.functional as F
from torch.cuda import amp
from apex.parallel import DistributedDataParallel as DDP

from fairseq import distributed_utils, optim, utils
from fairseq.optim import lr_scheduler
from fairseq.meters import TimeMeter, AverageMeter
from fairseq.criterions import CRITERION_REGISTRY

import dllogger as DLLogger


class DDPTrainer():
    """Main class for data parallel training.

    This class supports data parallel training, where multiple workers each
    have a full model replica and gradients are accumulated synchronously via
    torch.distributed.all_reduce.
    """

    def __init__(self, args, model):
        if not torch.cuda.is_available():
            raise NotImplementedError('Training on CPU is not supported')

        self.args = args

        self.model = model.cuda()
        self.criterion = CRITERION_REGISTRY[args.criterion](args).cuda()
        self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
        self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)

        self.scaler = amp.GradScaler(enabled=self.args.amp, init_scale=2**15)

        if self.args.distributed_world_size > 1:
            self.model = DDP(model)

        self._buffered_stats = defaultdict(lambda: [])
        self._num_updates = 0
        self._optim_history = None

        self.throughput_meter = TimeMeter()
        self.avg_loss_meter = AverageMeter()

    def save_checkpoint(self, filename, extra_state):
        """Save all training state in a checkpoint file."""
        if distributed_utils.is_master(self.args):
            # only save one checkpoint
            utils.save_state(
                filename, self.args, self.get_model(), self.criterion, self.optimizer,
                self.lr_scheduler, self._num_updates, self._optim_history, extra_state,
            )

    def load_checkpoint(self, filename, load_optim=True):
        """Load all training state from a checkpoint file."""
        extra_state, optim_history, last_optim_state = \
            utils.load_model_state(filename, self.get_model())

        if last_optim_state is not None:
            # rebuild optimizer after loading model, since params may have changed
            #self.optimizer = optim.build_optimizer(self.args, self.model.parameters())
            self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)

            if load_optim:
                self._optim_history = optim_history
                # only reload optimizer and lr_scheduler if they match
                last_optim = self._optim_history[-1]
                if last_optim['criterion_name'] == self.criterion.__class__.__name__:
                    self.lr_scheduler.load_state_dict(last_optim['lr_scheduler_state'])
                    if last_optim['optimizer_name'] == self.optimizer.__class__.__name__:
                        self.optimizer.load_state_dict(last_optim_state)

                self._num_updates = last_optim['num_updates']

        return extra_state

    def train_step(self, sample, update_params=True, last_step=False):
        """Do forward, backward and parameter update."""
        # Set seed based on args.seed and the update number so that we get
        # reproducible results when resuming from checkpoints
        seed = self.args.seed + self.get_num_updates()
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        self.model.train()
        if isinstance(self.model, DDP):
            if last_step:
                self.model.disable_allreduce()
            else:
                self.model.enable_allreduce()

        # forward and backward pass
        sample = self._prepare_sample(sample)
        loss, oom_fwd = self._forward(sample)

        # If this is a last batch forward pass is skipped on some workers
        # Batch with sample_size 0 is not accounted for in weighted loss
        logging_output = {
            'ntokens': sample['ntokens'] if sample is not None else 0,
            'nsentences': sample['target'].size(0) if sample is not None else 0,
            'loss': utils.item(loss.data) if loss is not None else 0,
        }
        sample_size = sample['ntokens'] if sample is not None else 0
        oom_bwd = self._backward(loss)

        # buffer stats and logging outputs
        self._buffered_stats['sample_sizes'].append(sample_size)
        self._buffered_stats['logging_outputs'].append(logging_output)
        self._buffered_stats['ooms_fwd'].append(oom_fwd)
        self._buffered_stats['ooms_bwd'].append(oom_bwd)

        # update parameters
        if update_params and not last_step:
            # gather logging outputs from all replicas
            sample_sizes = self._buffered_stats['sample_sizes']
            logging_outputs = self._buffered_stats['logging_outputs']
            ooms_fwd = self._buffered_stats['ooms_fwd']
            ooms_bwd = self._buffered_stats['ooms_bwd']
            if self.args.distributed_world_size > 1:
                sample_sizes, logging_outputs, ooms_fwd, ooms_bwd = map(
                    lambda l: list(chain.from_iterable(l)),
                    zip(*distributed_utils.all_gather_list(
                        (sample_sizes, logging_outputs, ooms_fwd, ooms_bwd)
                    ))
                )
            ooms_fwd = sum(ooms_fwd)
            ooms_bwd = sum(ooms_bwd)
            ooms = ooms_fwd + ooms_bwd  # this is always <= distributed_world_size

            if ooms == self.args.distributed_world_size:
                print('| WARNING: OOM in all workers, skipping batch')
                self.zero_grad()
                return

            # aggregate stats and logging outputs
            grad_denom = sum(sample_sizes)
            for p in self.model.parameters():
                if p.requires_grad and p.grad is not None:
                    p.grad /= grad_denom

            self._opt()

            # Handle logging
            ntokens = sum(log.get('ntokens', 0) for log in logging_outputs)
            self.throughput_meter.update(ntokens)
            info_log_data = {
                'tokens/s': self.throughput_meter.avg,
                'tokens': ntokens,
                'loss': sum(log.get('loss', 0) for log in logging_outputs) / ntokens / math.log(2)
            }
            self.avg_loss_meter.update(info_log_data['loss'])
            debug_log_data = {
                'batch_size': sum(log.get('nsentences', 0) for log in logging_outputs),
                'lr': self.get_lr(),
                'grad_denom': grad_denom,
                'updates': 1
            }

            DLLogger.log(step=self._num_updates, data=info_log_data, verbosity=0)
            DLLogger.log(step=self._num_updates, data=debug_log_data, verbosity=1)

            self.clear_buffered_stats()

    def _forward(self, sample):
        loss = None
        oom = 0
        try:
            if sample is not None:
                with amp.autocast(enabled=self.args.amp):
                    # calculate loss and sample size
                    logits, _ = self.model(**sample['net_input'])
                    target = sample['target']
                    probs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
                    loss = self.criterion(probs, target)
        except RuntimeError as e:
            if 'out of memory' in str(e):
                print('| WARNING: ran out of memory in worker {}, skipping batch'.format(
                    self.args.distributed_rank), force=True)
                oom = 1
                loss = None
            else:
                raise e
        return loss, oom

    def _backward(self, loss):
        oom = 0
        if loss is not None:
            try:
                self.scaler.scale(loss).backward()
            except RuntimeError as e:
                if 'out of memory' in str(e):
                    print('| WARNING: ran out of memory in worker {}, skipping batch'.format(
                        self.args.distributed_rank), force=True)
                    oom = 1
                    self.zero_grad()
                else:
                    raise e
        return oom

    def _opt(self):
        # take an optimization step
        self.scaler.step(self.optimizer.optimizer)
        self.scaler.update()
        self.zero_grad()

        self._num_updates += 1

        # update learning rate
        self.lr_scheduler.step_update(self._num_updates)

    def valid_step(self, sample):
        """Do forward pass in evaluation mode."""
        self.model.eval()
        # forward pass
        sample = self._prepare_sample(sample)
        with torch.no_grad():
            loss, oom_fwd = self._forward(sample)
        logging_output = {
            'ntokens': sample['ntokens'] if sample is not None else 0,
            'nsentences': sample['target'].size(0) if sample is not None else 0,
        }
        loss = loss.item() if loss is not None else 0
        assert not oom_fwd, 'Ran out of memory during validation'

        # gather logging outputs from all GPUs
        if self.args.distributed_world_size > 1:
            losses, logging_outputs = zip(*distributed_utils.all_gather_list(
                (loss, logging_output)
            ))
        else:
            losses = [loss]
            logging_outputs = [logging_output]

        weight = sum(log.get('ntokens', 0) for log in logging_outputs)
        scaled_loss = sum(losses) / weight / math.log(2)

        return scaled_loss

    def dummy_train_step(self, dummy_batch):
        """Dummy training step for warming caching allocator."""
        self.train_step(dummy_batch, update_params=False)
        self.zero_grad()
        self.clear_buffered_stats()

    def zero_grad(self):
        self.optimizer.zero_grad()

    def clear_buffered_stats(self):
        self._buffered_stats.clear()

    def lr_step(self, epoch, val_loss=None):
        """Adjust the learning rate based on the validation loss."""
        return self.lr_scheduler.step(epoch, val_loss)

    def lr_step_update(self, num_updates):
        """Update the learning rate after each update."""
        return self.lr_scheduler.step_update(num_updates)

    def get_lr(self):
        """Get the current learning rate."""
        return self.optimizer.get_lr()

    def get_throughput_meter(self):
        """Get the throughput meter"""
        return self.throughput_meter

    def get_model(self):
        """Get the model replica."""
        return self.model.module if isinstance(self.model, DDP) else self.model

    def get_num_updates(self):
        """Get the number of parameters updates."""
        return self._num_updates

    def _prepare_sample(self, sample):
        if not sample:
            return None
        return utils.move_to_cuda(sample)
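A side note on the logging arithmetic in train_step above: the summed criterion output is divided by the token count and by ln(2), so the logged 'loss' is per-token cross-entropy expressed in bits rather than nats. A tiny standalone check (not part of the deleted file; the numbers are hypothetical):

import math

total_loss_nats = 693.15   # hypothetical summed loss over one batch
ntokens = 100              # hypothetical token count for that batch
bits_per_token = total_loss_nats / ntokens / math.log(2)
print(round(bits_per_token, 3))   # ~10.0 bits per token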
PyTorch/NLP/Transformer/fairseq/distributed_utils.py  (deleted, 100644 → 0)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pickle
import os
import socket

import torch.distributed

from fairseq import utils


def is_master(args):
    return args.distributed_rank == 0


def distributed_init(args):
    if args.distributed_world_size == 1:
        raise ValueError('Cannot initialize distributed with distributed_world_size=1')

    print('| distributed init (rank {}): {}'.format(
        args.distributed_rank, args.distributed_init_method), flush=True)

    print("| distributed env init. MASTER_ADDR: " + os.environ['MASTER_ADDR'] +
          ", MASTER_PORT: " + os.environ['MASTER_PORT'] +
          ", WORLD_SIZE: " + os.environ['WORLD_SIZE'] +
          ", RANK: " + os.environ['RANK'], flush=True)
    torch.distributed.init_process_group(backend=args.distributed_backend, init_method='env://')
    print("| distributed init done!", flush=True)

    args.distributed_world_size = int(os.environ['WORLD_SIZE'])
    args.distributed_rank = torch.distributed.get_rank()
    args.device_id = int(os.environ.get('LOCAL_RANK', args.local_rank))

    suppress_output(args)

    print('| initialized host {} as rank {} and device id {}'.format(
        socket.gethostname(), args.distributed_rank, args.device_id))

    return args.distributed_rank


def suppress_output(main_args):
    """Suppress printing on the current device. Force printing with `force=True`."""
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print_master(*args, **kwargs):
        if 'force' in kwargs:
            kwargs.pop('force')
        builtin_print(*args, **kwargs)

    def print(*args, **kwargs):
        if 'force' in kwargs:
            force = kwargs.pop('force')
            if force:
                builtin_print(*args, **kwargs)

    if is_master(main_args):
        __builtin__.print = print_master
    else:
        __builtin__.print = print


def all_gather_list(data, max_size=16384):
    """Gathers arbitrary data from all nodes into a list."""
    world_size = torch.distributed.get_world_size()
    if not hasattr(all_gather_list, '_in_buffer') or \
            max_size != len(all_gather_list._in_buffer):
        all_gather_list._in_buffer = torch.cuda.ByteTensor(max_size)
        all_gather_list._out_buffers = [
            torch.cuda.ByteTensor(max_size)
            for i in range(world_size)
        ]
    in_buffer = all_gather_list._in_buffer
    out_buffers = all_gather_list._out_buffers

    enc = pickle.dumps(data)
    enc_size = len(enc)
    if enc_size + 2 > max_size:
        raise ValueError('encoded data exceeds max_size: {}'.format(enc_size + 2))
    assert max_size < 255*256
    in_buffer[0] = enc_size // 255  # this encoding works for max_size < 65k
    in_buffer[1] = enc_size % 255
    in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc))

    torch.distributed.all_gather(out_buffers, in_buffer.cuda())

    result = []
    for i in range(world_size):
        out_buffer = out_buffers[i]
        size = (255 * utils.item(out_buffer[0])) + utils.item(out_buffer[1])
        result.append(
            pickle.loads(bytes(out_buffer[2:size+2].tolist()))
        )
    return result
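The all_gather_list helper above frames each rank's pickled payload with a two-byte size header, (size // 255, size % 255), which is why max_size must stay below 255*256. A standalone sketch (not part of the deleted file) of that framing and its decoding, using only the standard library:

import pickle

data = {'ntokens': 4096, 'loss': 3.14}          # hypothetical per-rank payload
enc = pickle.dumps(data)
size = len(enc)
buf = bytes([size // 255, size % 255]) + enc    # what one rank would contribute to all_gather

decoded_size = 255 * buf[0] + buf[1]            # the receiving side recovers the length
assert decoded_size == size
assert pickle.loads(buf[2:decoded_size + 2]) == data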
PyTorch/NLP/Transformer/fairseq/log_helper.py  (deleted, 100644 → 0)

import os
import atexit
import time
import itertools
from collections import OrderedDict

import dllogger
from dllogger import Backend, JSONStreamBackend
from tensorboardX import SummaryWriter


class AverageMeter():
    def __init__(self):
        self.reset()

    def reset(self):
        self.updated = False
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, value):
        self.updated = True
        if isinstance(value, (tuple, list)):
            val = value[0]
            n = value[1]
        else:
            val = value
            n = 1
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    @property
    def value(self):
        return self.avg


class PerformanceMeter():
    def __init__(self):
        self.reset()

    def reset(self):
        self.updated = False
        self.start = time.time()
        self.n = 0

    def update(self, val=1):
        self.updated = True
        self.n += val

    @property
    def value(self):
        return self.n / self.elapsed_time

    @property
    def elapsed_time(self):
        return time.time() - self.start


METRIC = {'average': AverageMeter, 'performance': PerformanceMeter}


class AggregatorBackend(Backend):
    def __init__(self, verbosity, agg_dict):
        super().__init__(verbosity=verbosity)
        agg_dict = OrderedDict({k: v if isinstance(v, (tuple, list)) else (v,)
                                for k, v in agg_dict.items()})
        self.metrics = OrderedDict({k: [METRIC[x]() for x in v] for k, v in agg_dict.items()})
        self.metrics.flushed = True
        self.step = 0
        self.epoch = 0
        self.start_time = time.time()

    @property
    def log_level(self):
        return self._log_level

    def metadata(self, timestamp, elapsedtime, metric, metadata):
        pass

    def _reset_perf_meter(self, name):
        for agg in self.metrics[name]:
            if isinstance(agg, PerformanceMeter):
                agg.reset()

    def reset_perf_meters(self):
        for name in self.metrics.keys():
            self._reset_perf_meter(name)

    def log(self, timestamp, elapsedtime, step, data):
        self.step = step
        if 'epoch' in data.keys():
            self.epoch = data['epoch']
        for k, v in data.items():
            if k not in self.metrics.keys():
                continue
            self.metrics.flushed = False
            for ag in self.metrics[k]:
                ag.update(v)

    def flush(self):
        if self.metrics.flushed:
            return
        result_string = 'Transformer | epoch {} | step {} |'.format(self.epoch, self.step)

        for name, aggregators in self.metrics.items():
            for agg in aggregators:
                if not agg.updated:
                    continue
                if isinstance(agg, AverageMeter):
                    _name = 'avg ' + name
                elif isinstance(agg, PerformanceMeter):
                    _name = name + '/s'

                result_string += _name + ' {:.3f} |'.format(agg.value)
                agg.reset()

        result_string += 'walltime {:.3f} |'.format(time.time() - self.start_time)
        self.metrics.flushed = True
        print(result_string)


class TensorBoardBackend(Backend):
    def __init__(self, verbosity, log_dir):
        super().__init__(verbosity=verbosity)
        self.summary_writer = SummaryWriter(log_dir=os.path.join(log_dir, 'TB_summary'),
                                            flush_secs=120,
                                            max_queue=200)
        atexit.register(self.summary_writer.close)

    @property
    def log_level(self):
        return self._log_level

    def metadata(self, timestamp, elapsedtime, metric, metadata):
        pass

    def log(self, timestamp, elapsedtime, step, data):
        if not isinstance(step, int):
            return
        for k, v in data.items():
            self.summary_writer.add_scalar(k, v, step)

    def flush(self):
        pass


def setup_logger(args):
    aggregator_dict = OrderedDict([
        ('loss', 'average'),
        ('weighted_loss', 'average'),
        ('tokens', ('average', 'performance')),
        ('updates', 'performance'),
        ('gnorm', 'average')
    ])
    os.makedirs(args.save_dir, exist_ok=True)
    log_path = os.path.join(args.save_dir, args.stat_file)

    if os.path.exists(log_path):
        for i in itertools.count():
            s_fname = args.stat_file.split('.')
            fname = '.'.join(s_fname[:-1]) + f'_{i}.' + s_fname[-1] if len(s_fname) > 1 else args.stat_file + f'.{i}'
            log_path = os.path.join(args.save_dir, fname)
            if not os.path.exists(log_path):
                break

    if not args.distributed_world_size > 1 or args.distributed_rank == 0:
        dllogger.init(backends=[JSONStreamBackend(verbosity=1, filename=log_path),
                                AggregatorBackend(verbosity=0, agg_dict=aggregator_dict),
                                TensorBoardBackend(verbosity=1, log_dir=args.save_dir)])
    else:
        dllogger.init(backends=[])
    for k, v in vars(args).items():
        dllogger.log(step='PARAMETER', data={k: v}, verbosity=0)

    container_setup_info = get_framework_env_vars()
    dllogger.log(step='PARAMETER', data=container_setup_info, verbosity=0)

    dllogger.metadata('loss', {'unit': 'nat', 'GOAL': 'MINIMIZE', 'STAGE': 'TRAIN'})
    dllogger.metadata('val_loss', {'unit': 'nat', 'GOAL': 'MINIMIZE', 'STAGE': 'VAL'})
    dllogger.metadata('speed', {'unit': 'tokens/s', 'format': ':.3f', 'GOAL': 'MAXIMIZE', 'STAGE': 'TRAIN'})
    dllogger.metadata('accuracy', {'unit': 'bleu', 'format': ':.2f', 'GOAL': 'MAXIMIZE', 'STAGE': 'VAL'})


def get_framework_env_vars():
    return {
        'NVIDIA_PYTORCH_VERSION': os.environ.get('NVIDIA_PYTORCH_VERSION'),
        'PYTORCH_VERSION': os.environ.get('PYTORCH_VERSION'),
        'CUBLAS_VERSION': os.environ.get('CUBLAS_VERSION'),
        'NCCL_VERSION': os.environ.get('NCCL_VERSION'),
        'CUDA_DRIVER_VERSION': os.environ.get('CUDA_DRIVER_VERSION'),
        'CUDNN_VERSION': os.environ.get('CUDNN_VERSION'),
        'CUDA_VERSION': os.environ.get('CUDA_VERSION'),
        'NVIDIA_PIPELINE_ID': os.environ.get('NVIDIA_PIPELINE_ID'),
        'NVIDIA_BUILD_ID': os.environ.get('NVIDIA_BUILD_ID'),
        'NVIDIA_TF32_OVERRIDE': os.environ.get('NVIDIA_TF32_OVERRIDE'),
    }


def reset_perf_meters():
    for backend in dllogger.GLOBAL_LOGGER.backends:
        if isinstance(backend, AggregatorBackend):
            backend.reset_perf_meters()
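setup_logger above rotates the JSON stat file instead of overwriting an existing one. A small standalone illustration (not part of the deleted file) of that renaming rule, with a hypothetical file name standing in for args.stat_file:

stat_file = 'train_log.json'          # hypothetical value of args.stat_file
s_fname = stat_file.split('.')
for i in range(3):
    # same expression as in setup_logger: append _<i> before the extension
    fname = '.'.join(s_fname[:-1]) + f'_{i}.' + s_fname[-1] if len(s_fname) > 1 else stat_file + f'.{i}'
    print(fname)                      # train_log_0.json, train_log_1.json, train_log_2.json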
PyTorch/NLP/Transformer/fairseq/meters.py  (deleted, 100644 → 0)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import time


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class TimeMeter(object):
    """Computes the average occurrence of some event per second"""

    def __init__(self, init=0):
        self.reset(init)

    def reset(self, init=0):
        self.init = init
        self.start = time.time()
        self.n = 0
        self.last_update = time.time()

    def update(self, val=1):
        self.n += val
        self.last_update = time.time()

    @property
    def avg(self):
        return self.n / self.elapsed_time

    @property
    def elapsed_time(self):
        return self.init + (time.time() - self.start)

    @property
    def u_avg(self):
        return self.n / (self.last_update - self.start)


class StopwatchMeter(object):
    """Computes the sum/avg duration of some event in seconds"""

    def __init__(self):
        self.reset()
        self.intervals = []

    def start(self):
        self.start_time = time.time()

    def stop(self, n=1):
        if self.start_time is not None:
            delta = time.time() - self.start_time
            self.intervals.append(delta)
            self.sum += delta
            self.n += n
            self.start_time = None

    def reset(self):
        self.sum = 0
        self.n = 0
        self.start_time = None
        self.intervals = []

    @property
    def avg(self):
        return self.sum / self.n

    def p(self, i):
        assert i <= 100
        idx = int(len(self.intervals) * i / 100)
        return sorted(self.intervals)[idx]
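A quick usage sketch (not part of the deleted file) of the meters above, assuming TimeMeter and StopwatchMeter are importable from this module; TimeMeter tracks events per second while StopwatchMeter tracks durations and percentiles:

import time

tm = TimeMeter()
sw = StopwatchMeter()
for _ in range(3):
    sw.start()
    time.sleep(0.01)      # stand-in for one training step
    sw.stop(n=1)
    tm.update(1000)       # e.g. 1000 tokens processed in this step
print(f'{tm.avg:.0f} tokens/s, avg step {sw.avg * 1000:.1f} ms, p50 {sw.p(50) * 1000:.1f} ms')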
PyTorch/NLP/Transformer/fairseq/models/__init__.py  (deleted, 100644 → 0)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import importlib
import os

from .fairseq_incremental_decoder import FairseqIncrementalDecoder  # noqa: F401

MODEL_REGISTRY = {}
ARCH_MODEL_REGISTRY = {}
ARCH_CONFIG_REGISTRY = {}


def build_model(args):
    return ARCH_MODEL_REGISTRY[args.arch].build_model(args)


def register_model(name):
    """Decorator to register a new model (e.g., LSTM)."""
    def register_model_cls(cls):
        if name in MODEL_REGISTRY:
            raise ValueError('Cannot register duplicate model ({})'.format(name))
        MODEL_REGISTRY[name] = cls
        return cls
    return register_model_cls


def register_model_architecture(model_name, arch_name):
    """Decorator to register a new model architecture (e.g., lstm_luong_wmt_en_de)."""
    def register_model_arch_fn(fn):
        if model_name not in MODEL_REGISTRY:
            raise ValueError('Cannot register model architecture for unknown model type ({})'.format(model_name))
        if arch_name in ARCH_MODEL_REGISTRY:
            raise ValueError('Cannot register duplicate model architecture ({})'.format(arch_name))
        if not callable(fn):
            raise ValueError('Model architecture must be callable ({})'.format(arch_name))
        ARCH_MODEL_REGISTRY[arch_name] = MODEL_REGISTRY[model_name]
        ARCH_CONFIG_REGISTRY[arch_name] = fn
        return fn
    return register_model_arch_fn


# automatically import any Python files in the models/ directory
for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith('.py') and not file.startswith('_'):
        module = file[:file.find('.py')]
        importlib.import_module('fairseq.models.' + module)
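For context, a hedged illustration (not part of the deleted file) of how the registries above are meant to be used; the model and architecture names here are invented for the example:

@register_model('toy_model')
class ToyModel:
    @classmethod
    def build_model(cls, args):
        return cls()

@register_model_architecture('toy_model', 'toy_model_base')
def toy_model_base(args):
    # fill in defaults without overriding user-provided values
    args.hidden_dim = getattr(args, 'hidden_dim', 8)

# ARCH_MODEL_REGISTRY['toy_model_base'] now resolves to ToyModel, and
# ARCH_CONFIG_REGISTRY['toy_model_base'] holds the defaults function,
# so build_model(args) can construct the model from args.arch alone.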
PyTorch/NLP/Transformer/fairseq/models/fairseq_incremental_decoder.py  (deleted, 100644 → 0)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import torch.nn as nn


class FairseqIncrementalDecoder(nn.Module):
    """Base class for incremental decoders."""

    def __init__(self):
        super().__init__()

    def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
        raise NotImplementedError

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder incremental state.

        This should be called when the order of the input has changed from the
        previous time step. A typical use case is beam search, where the input
        order changes between time steps based on the selection of beams.
        """
        def apply_reorder_incremental_state(module):
            if module != self and hasattr(module, 'reorder_incremental_state'):
                module.reorder_incremental_state(
                    incremental_state,
                    new_order,
                )
        self.apply(apply_reorder_incremental_state)

    def set_beam_size(self, beam_size):
        """Sets the beam size in the decoder and all children."""
        if getattr(self, '_beam_size', -1) != beam_size:
            def apply_set_beam_size(module):
                if module != self and hasattr(module, 'set_beam_size'):
                    module.set_beam_size(beam_size)
            self.apply(apply_set_beam_size)
            self._beam_size = beam_size
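The reorder_incremental_state docstring above refers to beam search: when beams are re-ranked, each cached tensor has to be gathered along its beam dimension with the new order. A standalone sketch (not part of the deleted file) of what that gather looks like:

import torch

cached_state = torch.tensor([[0., 0.], [1., 1.], [2., 2.]])   # one row per beam
new_order = torch.tensor([2, 0, 0])                           # beams kept after this step
print(cached_state.index_select(0, new_order))                # rows 2, 0, 0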
PyTorch/NLP/Transformer/fairseq/models/fused_layer_norm.py  (deleted, 100644 → 0)

# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import torch
import numbers
from torch.nn.parameter import Parameter
from torch.nn import init

import fused_layer_norm_cuda


class FusedLayerNormAffineFunction(torch.autograd.Function):
    def __init__(self, normalized_shape, eps=1e-6):
        self.normalized_shape = normalized_shape
        self.eps = eps

    def forward(self, input, weight, bias):
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        bias_ = bias.contiguous()
        output, mean, invvar = fused_layer_norm_cuda.forward_affine(
            input_, self.normalized_shape, weight_, bias_, self.eps)
        self.save_for_backward(input_, weight_, bias_, mean, invvar)
        return output

    def backward(self, grad_output):
        input_, weight_, bias_, mean, invvar = self.saved_tensors
        grad_input = grad_weight = grad_bias = None
        grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine(
            grad_output.contiguous(), mean, invvar,
            input_, self.normalized_shape,
            weight_, bias_, self.eps)
        return grad_input, grad_weight, grad_bias


class FusedLayerNormFunction(torch.autograd.Function):
    def __init__(self, normalized_shape, eps=1e-6):
        self.normalized_shape = normalized_shape
        self.eps = eps

    def forward(self, input):
        input_ = input.contiguous()
        output, mean, invvar = fused_layer_norm_cuda.forward(
            input_, self.normalized_shape, self.eps)
        self.save_for_backward(input_, mean, invvar)
        return output

    def backward(self, grad_output):
        input_, mean, invvar = self.saved_tensors
        grad_input = None
        grad_input = fused_layer_norm_cuda.backward(
            grad_output.contiguous(), mean, invvar,
            input_, self.normalized_shape,
            self.eps)
        return grad_input


def fused_layer_norm_affine(input, normalized_shape, weight, bias, eps=1e-6):
    return FusedLayerNormAffineFunction(normalized_shape, eps)(input, weight, bias)


def fused_layer_norm(input, normalized_shape, eps=1e-6):
    return FusedLayerNormFunction(normalized_shape, eps)(input)


class FusedLayerNorm(torch.nn.Module):
    r"""Applies Layer Normalization over a mini-batch of inputs as described in
    the paper `Layer Normalization`_ .

    .. math::
        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated separately over the last
    certain number dimensions which have to be of the shape specified by
    :attr:`normalized_shape`.
    :math:`\gamma` and :math:`\beta` are learnable affine transform parameters of
    :attr:`normalized_shape` if :attr:`elementwise_affine` is ``True``.

    .. note::
        Unlike Batch Normalization and Instance Normalization, which applies
        scalar scale and bias for each entire channel/plane with the
        :attr:`affine` option, Layer Normalization applies per-element scale and
        bias with :attr:`elementwise_affine`.

    This layer uses statistics computed from input data in both training and
    evaluation modes.

    Args:
        normalized_shape (int or list or torch.Size): input shape from an expected input
            of size

            .. math::
                [* \times \text{normalized\_shape}[0] \times \text{normalized\_shape}[1]
                    \times \ldots \times \text{normalized\_shape}[-1]]

            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        eps: a value added to the denominator for numerical stability. Default: 1e-5
        elementwise_affine: a boolean value that when set to ``True``, this module
            has learnable per-element affine parameters initialized to ones (for weights)
            and zeros (for biases). Default: ``True``.

    Shape:
        - Input: :math:`(N, *)`
        - Output: :math:`(N, *)` (same shape as input)

    Examples::

        >>> input = torch.randn(20, 5, 10, 10)
        >>> # With Learnable Parameters
        >>> m = nn.LayerNorm(input.size()[1:])
        >>> # Without Learnable Parameters
        >>> m = nn.LayerNorm(input.size()[1:], elementwise_affine=False)
        >>> # Normalize over last two dimensions
        >>> m = nn.LayerNorm([10, 10])
        >>> # Normalize over last dimension of size 10
        >>> m = nn.LayerNorm(10)
        >>> # Activating the module
        >>> output = m(input)

    .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
    """

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
        super(FusedLayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = torch.Size(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = Parameter(torch.Tensor(*normalized_shape))
            self.bias = Parameter(torch.Tensor(*normalized_shape))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        if self.elementwise_affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def forward(self, input):
        if self.elementwise_affine:
            return FusedLayerNormAffineFunction(self.normalized_shape, self.eps)(
                input, self.weight, self.bias)
        else:
            return FusedLayerNormFunction(self.normalized_shape, self.eps)(input)

    def extra_repr(self):
        return '{normalized_shape}, eps={eps}, ' \
            'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
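The fused kernel implements the formula quoted in the docstring above. A plain-PyTorch reference computation (not part of the deleted file, and not using the CUDA extension) that matches it for gamma=1, beta=0:

import torch
import torch.nn.functional as F

x = torch.randn(20, 5, 10, 10)
eps = 1e-5
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
y_manual = (x - mean) / torch.sqrt(var + eps)            # normalize over the last dimension
y_builtin = F.layer_norm(x, (10,), eps=eps)              # PyTorch's reference implementation
print(torch.allclose(y_manual, y_builtin, atol=1e-5))    # True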
PyTorch/NLP/Transformer/fairseq/models/transformer.py  (deleted, 100644 → 0)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Optional, Dict

from fairseq.modules import (
    LearnedPositionalEmbedding, MultiheadAttention,
    SinusoidalPositionalEmbedding
)

from . import (
    FairseqIncrementalDecoder, register_model,
    register_model_architecture,
)

from apex.normalization.fused_layer_norm import FusedLayerNorm

torch.set_printoptions(threshold=500000000, linewidth=1024)


@torch.jit.script
def jit_dropout_add(x, residual, prob, is_training):
    # type: (Tensor, Tensor, float, bool) -> Tensor
    out = F.dropout(x, p=prob, training=is_training)
    out = residual + out
    return out


@torch.jit.script
def jit_relu_dropout(x, prob, is_training):
    # type: (Tensor, float, bool) -> Tensor
    out = F.threshold(x, 0., 0.)
    out = F.dropout(out, p=prob, training=is_training)
    return out


@register_model('transformer')
class TransformerModel(nn.Module):

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        parser.add_argument('--dropout', type=float, metavar='D',
                            help='dropout probability')
        parser.add_argument('--attention-dropout', type=float, metavar='D',
                            help='dropout probability for attention weights')
        parser.add_argument('--relu-dropout', type=float, metavar='D',
                            help='dropout probability after ReLU in FFN')
        parser.add_argument('--encoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained encoder embedding')
        parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension')
        parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
                            help='encoder embedding dimension for FFN')
        parser.add_argument('--encoder-layers', type=int, metavar='N',
                            help='num encoder layers')
        parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
                            help='num encoder attention heads')
        parser.add_argument('--encoder-normalize-before', action='store_true',
                            help='apply layernorm before each encoder block')
        parser.add_argument('--encoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the encoder')
        parser.add_argument('--decoder-embed-path', type=str, metavar='STR',
                            help='path to pre-trained decoder embedding')
        parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension')
        parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
                            help='decoder embedding dimension for FFN')
        parser.add_argument('--decoder-layers', type=int, metavar='N',
                            help='num decoder layers')
        parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
                            help='num decoder attention heads')
        parser.add_argument('--decoder-learned-pos', action='store_true',
                            help='use learned positional embeddings in the decoder')
        parser.add_argument('--decoder-normalize-before', action='store_true',
                            help='apply layernorm before each decoder block')
        parser.add_argument('--share-decoder-input-output-embed', action='store_true',
                            help='share decoder input and output embeddings')
        parser.add_argument('--share-all-embeddings', action='store_true',
                            help='share encoder, decoder and output embeddings'
                                 ' (requires shared dictionary and embed dim)')

    def __init__(self, encoder, decoder):
        super().__init__()
        self._is_generation_fast = False
        self.encoder = encoder
        self.decoder = decoder

    @classmethod
    def build_model(cls, args):
        # make sure all arguments are present in older models
        base_architecture(args)

        if not hasattr(args, 'max_source_positions'):
            args.max_source_positions = 1024
        if not hasattr(args, 'max_target_positions'):
            args.max_target_positions = 1024

        if args.share_all_embeddings:
            if args.src_vocab_size != args.tgt_vocab_size:
                raise RuntimeError('--share-all-embeddings requires a joined dictionary')
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise RuntimeError(
                    '--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
            if args.decoder_embed_path and (args.decoder_embed_path != args.encoder_embed_path):
                raise RuntimeError('--share-all-embeddings not compatible with --decoder-embed-path')
            encoder_embed_tokens = Embedding(args.src_vocab_size, args.encoder_embed_dim,
                                             args.padding_idx)
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = Embedding(args.src_vocab_size, args.encoder_embed_dim,
                                             args.padding_idx)
            decoder_embed_tokens = Embedding(args.tgt_vocab_size, args.decoder_embed_dim,
                                             args.padding_idx)

        encoder = TransformerEncoder(args, encoder_embed_tokens)
        decoder = TransformerDecoder(args, decoder_embed_tokens)
        return TransformerModel(encoder, decoder)

    def make_generation_fast_(self, **kwargs):
        """Optimize model for faster generation."""
        if self._is_generation_fast:
            return  # only apply once
        self._is_generation_fast = True

        # remove weight norm from all modules in the network
        def apply_remove_weight_norm(module):
            try:
                nn.utils.remove_weight_norm(module)
            except ValueError:  # this module didn't have weight norm
                return
        self.apply(apply_remove_weight_norm)

        def apply_make_generation_fast_(module):
            if module != self and hasattr(module, 'make_generation_fast_'):
                module.make_generation_fast_(**kwargs)
        self.apply(apply_make_generation_fast_)

        def train(mode):
            if mode:
                raise RuntimeError('cannot train after make_generation_fast')

        # this model should no longer be used for training
        self.eval()
        self.train = train

    def forward(self, src_tokens, src_lengths, prev_output_tokens):
        encoder_out, padding_mask = self.encoder(src_tokens, src_lengths)
        decoder_out = self.decoder(prev_output_tokens, encoder_out, padding_mask)
        return decoder_out


class TransformerEncoder(nn.Module):
    """Transformer encoder."""

    def __init__(self, args, embed_tokens, left_pad=True):
        super().__init__()
        self.dropout = args.dropout
        self.fuse_dropout_add = args.fuse_dropout_add
        self.fuse_relu_dropout = args.fuse_relu_dropout

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = args.max_source_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = PositionalEmbedding(
            args.max_source_positions, embed_dim, self.padding_idx,
            left_pad=left_pad,
            learned=args.encoder_learned_pos,
        ) if not args.no_token_positional_embeddings else None

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerEncoderLayer(args)
            for i in range(args.encoder_layers)
        ])
        self.normalize = args.encoder_normalize_before
        if self.normalize:
            self.layer_norm = FusedLayerNorm(embed_dim) if args.fuse_layer_norm else nn.LayerNorm(embed_dim)

    def forward(self, src_tokens, src_lengths):
        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            x += self.embed_positions(src_tokens)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        # The tensor needs to copy transposed because
        # fused dropout is not capable of handing strided data
        if self.fuse_dropout_add:
            x = x.transpose(0, 1).contiguous()
        else:
            x = x.transpose(0, 1)

        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        if not encoder_padding_mask.any():
            _encoder_padding_mask = None
        else:
            _encoder_padding_mask = encoder_padding_mask

        # encoder layers
        for layer in self.layers:
            x = layer(x, _encoder_padding_mask)

        if self.normalize:
            x = self.layer_norm(x)

        return x, encoder_padding_mask  # x.shape == T x B x C, encoder_padding_mask.shape == B x T

    def reorder_encoder_out(self, encoder_out, encoder_padding_mask, new_order):
        if encoder_out is not None:
            encoder_out = encoder_out.index_select(1, new_order)
        if encoder_padding_mask is not None:
            encoder_padding_mask = encoder_padding_mask.index_select(0, new_order)
        return encoder_out, encoder_padding_mask


class TransformerDecoder(FairseqIncrementalDecoder):
    """Transformer decoder."""

    def __init__(self, args, embed_tokens, no_encoder_attn=False, left_pad=False):
        super().__init__()
        self.dropout = args.dropout
        self.share_input_output_embed = args.share_decoder_input_output_embed
        self.fuse_dropout_add = args.fuse_dropout_add
        self.fuse_relu_dropout = args.fuse_relu_dropout

        embed_dim = embed_tokens.embedding_dim
        padding_idx = embed_tokens.padding_idx
        self.max_target_positions = args.max_target_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)
        self.embed_positions = PositionalEmbedding(
            args.max_target_positions, embed_dim, padding_idx,
            left_pad=left_pad,
            learned=args.decoder_learned_pos,
        ) if not args.no_token_positional_embeddings else None

        self.layers = nn.ModuleList([])
        self.layers.extend([
            TransformerDecoderLayer(args, no_encoder_attn)
            for _ in range(args.decoder_layers)
        ])

        if not self.share_input_output_embed:
            self.embed_out = nn.Parameter(torch.Tensor(args.tgt_vocab_size, embed_dim))
            nn.init.normal_(self.embed_out, mean=0, std=embed_dim ** -0.5)
        else:
            self.embed_out = self.embed_tokens.weight

        self.normalize = args.decoder_normalize_before
        if self.normalize:
            self.layer_norm = FusedLayerNorm(embed_dim) if args.fuse_layer_norm else nn.LayerNorm(embed_dim)

    def forward(self, prev_output_tokens: Tensor, encoder_out: Tensor,
                encoder_padding_mask: Tensor,
                incremental_state: Optional[Dict[str, Dict[str, Tensor]]] = None):
        # embed positions
        positions = self.embed_positions(
            prev_output_tokens,
            incremental_state=incremental_state,
        ) if self.embed_positions is not None else None

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        # The tensor needs to copy transposed because
        # fused dropout is not capable of handing strided data
        if self.fuse_dropout_add:
            x = x.transpose(0, 1).contiguous()
        else:
            x = x.transpose(0, 1)
        attn = None

        # decoder layers
        for layer in self.layers:
            x, attn = layer(
                x,
                encoder_out,
                encoder_padding_mask if encoder_padding_mask.any() else None,
                incremental_state,
            )

        if self.normalize:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        # project back to size of vocabulary
        x = F.linear(x, self.embed_out)

        return x, attn


class TransformerEncoderLayer(nn.Module):
    """Encoder layer block.

    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: dropout -> add residual -> layernorm.
    In the tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    dropout -> add residual.
    We default to the approach in the paper, but the tensor2tensor approach can
    be enabled by setting `normalize_before=True`.
    """

    def __init__(self, args):
        super().__init__()
        self.embed_dim = args.encoder_embed_dim
        self.self_attn = MultiheadAttention(
            self.embed_dim, args.encoder_attention_heads,
            dropout=args.attention_dropout,
        )
        self.dropout = args.dropout
        self.relu_dropout = args.relu_dropout
        self.fuse_dropout_add = args.fuse_dropout_add
        self.fuse_relu_dropout = args.fuse_relu_dropout
        self.normalize_before = args.encoder_normalize_before
        self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
        self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
        self.maybe_ln1 = MaybeLayerNorm(self.embed_dim, self.normalize_before, fuse=args.fuse_layer_norm)
        self.maybe_ln2 = MaybeLayerNorm(self.embed_dim, self.normalize_before, fuse=args.fuse_layer_norm)

    def forward(self, x: Tensor, encoder_padding_mask: Optional[Tensor]):
        residual = x
        x = self.maybe_ln1(x, before=True)
        x, _ = self.self_attn(query=x, key=x, value=x,
                              mask_future_timesteps=False,
                              key_padding_mask=encoder_padding_mask,
                              incremental_state=None,
                              need_weights=False,
                              static_kv=False)
        if self.fuse_dropout_add and self.training:
            x = jit_dropout_add(x, residual, self.dropout, self.training)
        else:
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
        x = self.maybe_ln1(x, after=True)

        residual = x
        x = self.maybe_ln2(x, before=True)
        if self.fuse_relu_dropout:
            x = jit_relu_dropout(self.fc1(x), self.relu_dropout, self.training)
        else:
            x = F.threshold(self.fc1(x), 0.0, 0.0)
            x = F.dropout(x, p=self.relu_dropout, training=self.training)
        x = self.fc2(x)
        if self.fuse_dropout_add and self.training:
            x = jit_dropout_add(x, residual, self.dropout, self.training)
        else:
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
        x = self.maybe_ln2(x, after=True)
        return x


class TransformerDecoderLayer(nn.Module):
    """Decoder layer block."""

    def __init__(self, args, no_encoder_attn=False):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.self_attn = MultiheadAttention(
            self.embed_dim, args.decoder_attention_heads,
            dropout=args.attention_dropout,
        )
        self.dropout = args.dropout
        self.relu_dropout = args.relu_dropout
        self.normalize_before = args.decoder_normalize_before
        self.fuse_dropout_add = args.fuse_dropout_add
        self.fuse_relu_dropout = args.fuse_relu_dropout

        self.self_attn_layer_norm = MaybeLayerNorm(self.embed_dim, self.normalize_before, fuse=args.fuse_layer_norm)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = MultiheadAttention(
                self.embed_dim, args.decoder_attention_heads,
                dropout=args.attention_dropout,
            )
            self.encoder_attn_layer_norm = MaybeLayerNorm(self.embed_dim, self.normalize_before, fuse=args.fuse_layer_norm)

        self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
        self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)

        self.final_layer_norm = MaybeLayerNorm(self.embed_dim, self.normalize_before, fuse=args.fuse_layer_norm)
        self.need_attn = True

    def forward(self, x: Tensor, encoder_out: Tensor, encoder_padding_mask: Optional[Tensor],
                incremental_state: Optional[Dict[str, Dict[str, Tensor]]]):
        residual = x
        x = self.self_attn_layer_norm(x, before=True)
        x, _ = self.self_attn(
            query=x, key=x, value=x,
            mask_future_timesteps=True,
            key_padding_mask=None,
            incremental_state=incremental_state,
            need_weights=False,
            static_kv=False
        )
        if self.fuse_dropout_add and self.training:
            x = jit_dropout_add(x, residual, self.dropout, self.training)
        else:
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
        x = self.self_attn_layer_norm(x, after=True)
        attn = None
        if self.encoder_attn is not None:
            residual = x
            x = self.encoder_attn_layer_norm(x, before=True)
            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                mask_future_timesteps=False,
                need_weights=(not self.training and self.need_attn),
            )
            if self.fuse_dropout_add and self.training:
                x = jit_dropout_add(x, residual, self.dropout, self.training)
            else:
                x = F.dropout(x, p=self.dropout, training=self.training)
                x = residual + x
            x = self.encoder_attn_layer_norm(x, after=True)

        residual = x
        x = self.final_layer_norm(x, before=True)
        if self.fuse_relu_dropout:
            x = jit_relu_dropout(self.fc1(x), self.relu_dropout, self.training)
        else:
            x = F.threshold(self.fc1(x), 0.0, 0.0)
            x = F.dropout(x, p=self.relu_dropout, training=self.training)
        x = self.fc2(x)
        if self.fuse_dropout_add and self.training:
            x = jit_dropout_add(x, residual, self.dropout, self.training)
        else:
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = residual + x
        x = self.final_layer_norm(x, after=True)
        return x, attn

    def make_generation_fast_(self, need_attn=False, **kwargs):
        self.need_attn = need_attn


def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m


class MaybeLayerNorm(nn.Module):
    def __init__(self, embed_dim, normalize_before, fuse=True):
        super().__init__()
        self.embed_dim = embed_dim
        self.normalize_before = normalize_before
        self.ln = FusedLayerNorm(embed_dim) if fuse else nn.LayerNorm(embed_dim)

    def forward(self, x: Tensor, before: bool = False, after: bool = False):
        assert before ^ after
        if after ^ self.normalize_before:
            return self.ln(x)
        else:
            return x


def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    nn.init.constant_(m.bias, 0.)
    return m


def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False):
    if learned:
        m = LearnedPositionalEmbedding(num_embeddings + padding_idx + 1, embedding_dim, padding_idx, left_pad)
        nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
        nn.init.constant_(m.weight[padding_idx], 0)
    else:
        m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, num_embeddings + padding_idx + 1)
    return m


@register_model_architecture('transformer', 'transformer')
def base_architecture(args):
    args.encoder_embed_path = getattr(args, 'encoder_embed_path', None)
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 2048)
    args.encoder_layers = getattr(args, 'encoder_layers', 6)
    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8)
    args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
    args.encoder_learned_pos = getattr(args, 'encoder_learned_pos', False)
    args.decoder_embed_path = getattr(args, 'decoder_embed_path', None)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.encoder_ffn_embed_dim)
    args.decoder_layers = getattr(args, 'decoder_layers', 6)
    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
    args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', False)
    args.decoder_learned_pos = getattr(args, 'decoder_learned_pos', False)
    args.attention_dropout = getattr(args, 'attention_dropout', 0.)
    args.relu_dropout = getattr(args, 'relu_dropout', 0.)
    args.dropout = getattr(args, 'dropout', 0.1)
    args.share_decoder_input_output_embed = getattr(args, 'share_decoder_input_output_embed', False)
    args.share_all_embeddings = getattr(args, 'share_all_embeddings', False)
    args.no_token_positional_embeddings = getattr(args, 'no_token_positional_embeddings', False)


@register_model_architecture('transformer', 'transformer_iwslt_de_en')
def transformer_iwslt_de_en(args):
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 1024)
    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4)
    args.encoder_layers = getattr(args, 'encoder_layers', 6)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 1024)
    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
    args.decoder_layers = getattr(args, 'decoder_layers', 6)
    base_architecture(args)


@register_model_architecture('transformer', 'transformer_wmt_en_de')
def transformer_wmt_en_de(args):
    base_architecture(args)


# parameters used in the "Attention Is All You Need" paper (Vaswani, et al, 2017)
@register_model_architecture('transformer', 'transformer_vaswani_wmt_en_de_big')
def transformer_vaswani_wmt_en_de_big(args):
    args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
    args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
    args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
    args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
    args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024)
    args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 4096)
    args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16)
    args.dropout = getattr(args, 'dropout', 0.3)
    base_architecture(args)


@register_model_architecture('transformer', 'transformer_vaswani_wmt_en_fr_big')
def transformer_vaswani_wmt_en_fr_big(args):
    args.dropout = getattr(args, 'dropout', 0.1)
    transformer_vaswani_wmt_en_de_big(args)


@register_model_architecture('transformer', 'transformer_wmt_en_de_big')
def transformer_wmt_en_de_big(args):
    args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
    transformer_vaswani_wmt_en_de_big(args)


# default parameters used in tensor2tensor implementation
@register_model_architecture('transformer', 'transformer_wmt_en_de_big_t2t')
def transformer_wmt_en_de_big_t2t(args):
    args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', True)
    args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', True)
    args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
    args.relu_dropout = getattr(args, 'relu_dropout', 0.1)
    transformer_vaswani_wmt_en_de_big(args)
PyTorch/NLP/Transformer/fairseq/modules/__init__.py deleted 100644 → 0
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from .beamable_mm import BeamableMM
from .learned_positional_embedding import LearnedPositionalEmbedding
from .multihead_attention import MultiheadAttention
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding

__all__ = [
    'BeamableMM',
    'LearnedPositionalEmbedding',
    'MultiheadAttention',
    'SinusoidalPositionalEmbedding',
]
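For reference, __all__ above is what a star import from this package exposes; the usual explicit import (assuming this fairseq tree is on the Python path and its CUDA extension is built) is simply:

from fairseq.modules import MultiheadAttention, SinusoidalPositionalEmbedding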
PyTorch/NLP/Transformer/fairseq/modules/learned_positional_embedding.py deleted 100644 → 0
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.nn as nn

from fairseq import utils


class LearnedPositionalEmbedding(nn.Embedding):
    """This module learns positional embeddings up to a fixed maximum size.

    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).
    """

    def __init__(self, num_embeddings, embedding_dim, padding_idx, left_pad):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.left_pad = left_pad

    def forward(self, input, incremental_state=None):
        """Input is expected to be of size [bsz x seqlen]."""
        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            positions = input.data.new(1, 1).fill_(self.padding_idx + input.size(1))
        else:
            positions = utils.make_positions(input.data, self.padding_idx, self.left_pad)
        return super().forward(positions)
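The docstring above hinges on utils.make_positions, which lives elsewhere in fairseq. As an illustration of the expected semantics only (a sketch for the right-padded case, not the library's actual implementation), real tokens get positions counting up from padding_idx + 1 while pads keep padding_idx:

import torch

def make_positions_sketch(tokens, padding_idx):
    # Right-padding only; the real helper also handles left_pad=True.
    mask = tokens.ne(padding_idx).long()
    return torch.cumsum(mask, dim=1) * mask + padding_idx

tokens = torch.tensor([[5, 6, 7, 1, 1],   # 1 is the (made-up) padding index
                       [8, 9, 1, 1, 1]])
print(make_positions_sketch(tokens, padding_idx=1))
# tensor([[2, 3, 4, 1, 1],
#         [2, 3, 1, 1, 1]])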
PyTorch/NLP/Transformer/fairseq/modules/multihead_attention.py deleted 100644 → 0
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional

import torch
from torch import nn, Tensor
from torch.nn import Parameter
import torch.nn.functional as F
from torch.cuda import amp
from torch.autograd.variable import Variable

import strided_batched_gemm

from fairseq import utils
class QueryLinear(torch.autograd.Function):

    @staticmethod
    @amp.custom_fwd(cast_inputs=torch.half)
    def forward(ctx, input, weights_q, scale):
        s = Variable(torch.tensor([scale]))
        ctx.save_for_backward(input, weights_q, s)
        q = torch.addmm(input.view(input.size(0) * input.size(1), input.size(2)), input.view(input.size(0) * input.size(1), input.size(2)), weights_q, beta=0.0, alpha=s[0])
        q = q.view(input.size(0), input.size(1), input.size(2))
        return q.detach()

    @staticmethod
    @amp.custom_bwd
    def backward(ctx, q_grad):
        input, weights_q, s = ctx.saved_tensors
        input = input.view(input.size(0) * input.size(1), input.size(2)).transpose(0, 1)
        q = torch.addmm(q_grad.view(q_grad.size(0) * q_grad.size(1), q_grad.size(2)), q_grad.view(q_grad.size(0) * q_grad.size(1), q_grad.size(2)), weights_q.transpose(0, 1), beta=0.0, alpha=s[0])
        q = q.view(q_grad.size(0), q_grad.size(1), q_grad.size(2))
        q_grad = q_grad.view(q_grad.size(0) * q_grad.size(1), q_grad.size(2))
        weights_q_grad = torch.addmm(weights_q, input, q_grad, beta=0.0, alpha=s[0])
        return q, weights_q_grad, None
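QueryLinear above follows the standard torch.cuda.amp recipe for custom autograd functions: custom_fwd(cast_inputs=torch.half) casts the inputs to half inside autocast regions and custom_bwd restores the matching autocast state in backward. A stand-alone toy version of that recipe using a plain matmul (not this file's fused kernels; it needs a CUDA device for autocast to have any effect):

import torch
from torch.cuda import amp

class ScaledMatmul(torch.autograd.Function):
    """Toy y = scale * (x @ w) with the same amp decorators as above."""

    @staticmethod
    @amp.custom_fwd(cast_inputs=torch.half)
    def forward(ctx, x, w, scale):
        ctx.save_for_backward(x, w)
        ctx.scale = scale
        return scale * x.mm(w)

    @staticmethod
    @amp.custom_bwd
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        return ctx.scale * grad_out.mm(w.t()), ctx.scale * x.t().mm(grad_out), None

if torch.cuda.is_available():
    x = torch.randn(4, 8, device='cuda', requires_grad=True)
    w = torch.randn(8, 2, device='cuda', requires_grad=True)
    with amp.autocast():
        y = ScaledMatmul.apply(x, w, 0.5)
    print(y.dtype)       # torch.float16: custom_fwd cast the inputs
    y.float().sum().backward()
    print(x.grad.dtype)  # torch.float32: gradients come back in the leaf's dtype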
class KeyValueLinears(torch.autograd.Function):

    @staticmethod
    @amp.custom_fwd(cast_inputs=torch.half)
    def forward(ctx, input, weights_k, weights_v):
        ctx.save_for_backward(input, weights_k, weights_v)
        k = torch.addmm(input.view(input.size(0) * input.size(1), input.size(2)), input.view(input.size(0) * input.size(1), input.size(2)), weights_k, beta=0.0, alpha=1.0)
        k = k.view(input.size(0), input.size(1), input.size(2))
        v = torch.addmm(input.view(input.size(0) * input.size(1), input.size(2)), input.view(input.size(0) * input.size(1), input.size(2)), weights_v, beta=0.0, alpha=1.0)
        v = v.view(input.size(0), input.size(1), input.size(2))
        return k.detach(), v.detach()

    @staticmethod
    @amp.custom_bwd
    def backward(ctx, k_grad, v_grad):
        input, weights_k, weights_v = ctx.saved_tensors
        input = input.view(input.size(0) * input.size(1), input.size(2)).transpose(0, 1)
        k = torch.addmm(k_grad.view(k_grad.size(0) * k_grad.size(1), k_grad.size(2)), k_grad.view(k_grad.size(0) * k_grad.size(1), k_grad.size(2)), weights_k.transpose(0, 1), beta=0.0)
        k_grad = k_grad.view(k_grad.size(0) * k_grad.size(1), k_grad.size(2))
        weights_k_grad = torch.mm(input, k_grad)
        v = k.addmm_(v_grad.view(v_grad.size(0) * v_grad.size(1), v_grad.size(2)), weights_v.transpose(0, 1), beta=1.0)
        v = v.view(v_grad.size(0), v_grad.size(1), v_grad.size(2))
        v_grad = v_grad.view(v_grad.size(0) * v_grad.size(1), v_grad.size(2))
        weights_v_grad = torch.mm(input, v_grad)
        return v, weights_k_grad, weights_v_grad


class SelfAttentionLinears(torch.autograd.Function):

    @staticmethod
    @amp.custom_fwd(cast_inputs=torch.half)
    def forward(ctx, input, weights_q, weights_k, weights_v, scale):
        s = Variable(torch.tensor([scale]))
        ctx.save_for_backward(input, weights_q, weights_k, weights_v, s)
        q = torch.addmm(input.view(input.size(0) * input.size(1), input.size(2)), input.view(input.size(0) * input.size(1), input.size(2)), weights_q, beta=0.0, alpha=s[0])
        q = q.view(input.size(0), input.size(1), input.size(2))
        k = torch.addmm(input.view(input.size(0) * input.size(1), input.size(2)), input.view(input.size(0) * input.size(1), input.size(2)), weights_k, beta=0.0, alpha=1.0)
        k = k.view(input.size(0), input.size(1), input.size(2))
        v = torch.addmm(input.view(input.size(0) * input.size(1), input.size(2)), input.view(input.size(0) * input.size(1), input.size(2)), weights_v, beta=0.0, alpha=1.0)
        v = v.view(input.size(0), input.size(1), input.size(2))
        return q.detach(), k.detach(), v.detach()

    @staticmethod
    @amp.custom_bwd
    def backward(ctx, q_grad, k_grad, v_grad):
        input, weights_q, weights_k, weights_v, s = ctx.saved_tensors
        input = input.view(input.size(0) * input.size(1), input.size(2)).transpose(0, 1)
        q = torch.addmm(q_grad.view(q_grad.size(0) * q_grad.size(1), q_grad.size(2)), q_grad.view(q_grad.size(0) * q_grad.size(1), q_grad.size(2)), weights_q.transpose(0, 1), beta=0.0, alpha=s[0])
        q_grad = q_grad.view(q_grad.size(0) * q_grad.size(1), q_grad.size(2))
        weights_q_grad = torch.addmm(weights_q, input, q_grad, beta=0.0, alpha=s[0])
        k = q.addmm_(k_grad.view(k_grad.size(0) * k_grad.size(1), k_grad.size(2)), weights_k.transpose(0, 1), beta=1.0)
        k_grad = k_grad.view(k_grad.size(0) * k_grad.size(1), k_grad.size(2))
        weights_k_grad = torch.mm(input, k_grad)
        v = k.addmm_(v_grad.view(v_grad.size(0) * v_grad.size(1), v_grad.size(2)), weights_v.transpose(0, 1), beta=1.0)
        v = v.view(v_grad.size(0), v_grad.size(1), v_grad.size(2))
        v_grad = v_grad.view(v_grad.size(0) * v_grad.size(1), v_grad.size(2))
        weights_v_grad = torch.mm(input, v_grad)
        return v, weights_q_grad, weights_k_grad, weights_v_grad, None
class StridedBmm1Func(torch.autograd.Function):

    @staticmethod
    @amp.custom_fwd(cast_inputs=torch.half)
    def forward(ctx, input1, input2):
        ctx.save_for_backward(input1, input2)
        output = torch.empty((input1.size(0), input1.size(1), input2.size(2)), dtype=input1.dtype, device=torch.device('cuda'))
        if (input1.dtype == torch.float16) and (input2.dtype == torch.float16):
            output = strided_batched_gemm.strided_batched_gemm(0.0, output, 1.0, input1, input2)
        else:
            output = torch.bmm(input1, input2, out=output)
        return output.detach()

    @staticmethod
    @amp.custom_bwd
    def backward(ctx, grad_output):
        input1, input2 = ctx.saved_tensors
        grad_input1 = torch.empty((input1.size(1), input2.size(0), input1.size(2)), dtype=input1.dtype, device=torch.device('cuda')).transpose(1, 0)
        grad_input2 = torch.empty((input2.size(2), input2.size(0), input2.size(1)), dtype=input2.dtype, device=torch.device('cuda')).transpose(1, 0)
        if (grad_output.dtype == torch.float16) and (input1.dtype == torch.float16) and (input2.dtype == torch.float16):
            grad_input1 = strided_batched_gemm.strided_batched_gemm(0.0, grad_input1, 1.0, grad_output, input2.transpose(1, 2))
            grad_input2 = strided_batched_gemm.strided_batched_gemm(0.0, grad_input2, 1.0, grad_output.transpose(1, 2), input1)
            grad_input2 = grad_input2.transpose(1, 2)
        else:
            grad_input1 = torch.bmm(grad_output, input2.transpose(1, 2), out=grad_input1)
            grad_input2 = torch.bmm(grad_output.transpose(1, 2), input1, out=grad_input2).transpose(1, 2)
        return grad_input1, grad_input2


class StridedBmm2Func(torch.autograd.Function):

    @staticmethod
    @amp.custom_fwd(cast_inputs=torch.half)
    def forward(ctx, input1, input2):
        ctx.save_for_backward(input1, input2)
        output = torch.empty((input1.size(1), input1.size(0), input2.size(2)), dtype=input1.dtype, device=torch.device('cuda')).transpose(1, 0)
        if (input1.dtype == torch.float16) and (input2.dtype == torch.float16):
            output = strided_batched_gemm.strided_batched_gemm(0.0, output, 1.0, input1, input2)
        else:
            output = torch.bmm(input1, input2, out=output)
        return output.detach()

    @staticmethod
    @amp.custom_bwd
    def backward(ctx, grad_output):
        input1, input2 = ctx.saved_tensors
        grad_input2 = torch.empty((input2.size(1), input2.size(0), input2.size(2)), dtype=input2.dtype, device=torch.device('cuda')).transpose(1, 0)
        grad_input1 = torch.empty((input1.size(0), input1.size(1), input1.size(2)), dtype=input2.dtype, device=torch.device('cuda'))
        if (grad_output.dtype == torch.float16) and (input1.dtype == torch.float16) and (input2.dtype == torch.float16):
            grad_input1 = strided_batched_gemm.strided_batched_gemm(0.0, grad_input1, 1.0, grad_output, input2.transpose(1, 2))
            grad_input2 = strided_batched_gemm.strided_batched_gemm(0.0, grad_input2, 1.0, input1.transpose(1, 2), grad_output)
        else:
            grad_input1 = torch.bmm(grad_output, input2.transpose(1, 2))
            grad_input2 = torch.bmm(input1.transpose(1, 2), grad_output, out=grad_input2)
        return grad_input1, grad_input2
def query_linear(input: Tensor, weights_q: Tensor, scale: float):
    if not torch.jit.is_scripting():
        return QueryLinear.apply(input, weights_q, scale)
    else:
        q = scale * torch.einsum('ij,jk->ik', input.view(input.size(0) * input.size(1), -1), weights_q)
        q = q.view(input.shape)
        return q


def key_value_linears(input: Tensor, weights_k: Tensor, weights_v: Tensor):
    if not torch.jit.is_scripting():
        return KeyValueLinears.apply(input, weights_k, weights_v)
    else:
        k = torch.einsum('ij,jk->ik', input.view(input.size(0) * input.size(1), -1), weights_k)
        k = k.view(input.shape)
        v = torch.einsum('ij,jk->ik', input.view(input.size(0) * input.size(1), -1), weights_v)
        v = v.view(input.shape)
        return k, v


def self_attn_linears(input: Tensor, weights_q: Tensor, weights_k: Tensor, weights_v: Tensor, scale: float):
    if not torch.jit.is_scripting():
        return SelfAttentionLinears.apply(input, weights_q, weights_k, weights_v, scale)
    else:
        q = scale * torch.einsum('ij,jk->ik', input.view(input.size(0) * input.size(1), -1), weights_q)
        q = q.view(input.shape)
        k = torch.einsum('ij,jk->ik', input.view(input.size(0) * input.size(1), -1), weights_k)
        k = k.view(input.shape)
        v = torch.einsum('ij,jk->ik', input.view(input.size(0) * input.size(1), -1), weights_v)
        v = v.view(input.shape)
        return q, k, v


def strided_bmm1(input1: Tensor, input2: Tensor):
    if not torch.jit.is_scripting():
        return StridedBmm1Func.apply(input1, input2)
    else:
        return torch.einsum('ijk,ikn->ijn', input1, input2)


def strided_bmm2(input1: Tensor, input2: Tensor):
    if not torch.jit.is_scripting():
        return StridedBmm2Func.apply(input1, input2)
    else:
        return torch.einsum('ijk,ikn->ijn', input1, input2)
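The wrappers above pick the custom batched-GEMM kernels in eager mode and fall back to an equivalent einsum when the module is being scripted. A small sanity check of that equivalence using only stock PyTorch ops (the compiled extension is not needed for this):

import torch

a = torch.randn(6, 4, 5)  # batch x M x K, made-up sizes
b = torch.randn(6, 5, 3)  # batch x K x N

out_bmm = torch.bmm(a, b)
out_einsum = torch.einsum('ijk,ikn->ijn', a, b)
print(torch.allclose(out_bmm, out_einsum, atol=1e-6))  # True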
class MultiheadAttention(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """

    def __init__(self, embed_dim, num_heads, dropout=0., bias=False):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.scaling = self.head_dim ** -0.5
        self._mask = torch.empty(0)

        #self.in_proj_weight = Parameter(torch.Tensor(3*embed_dim, embed_dim))
        self.in_proj_weight_q = Parameter(torch.Tensor(embed_dim, embed_dim))
        self.in_proj_weight_k = Parameter(torch.Tensor(embed_dim, embed_dim))
        self.in_proj_weight_v = Parameter(torch.Tensor(embed_dim, embed_dim))
        if bias:
            #self.in_proj_bias = Parameter(torch.Tensor(3*embed_dim))
            self.in_proj_bias_q = Parameter(torch.Tensor(embed_dim))
            self.in_proj_bias_k = Parameter(torch.Tensor(embed_dim))
            self.in_proj_bias_v = Parameter(torch.Tensor(embed_dim))
        else:
            #self.register_parameter('in_proj_bias', None)
            self.register_parameter('in_proj_bias_k', None)
            self.register_parameter('in_proj_bias_q', None)
            self.register_parameter('in_proj_bias_v', None)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.cache_id = str(id(self))

        self.reset_parameters()

    def reset_parameters(self):
        #nn.init.xavier_uniform_(self.in_proj_weight)
        nn.init.xavier_uniform_(self.in_proj_weight_q)
        nn.init.xavier_uniform_(self.in_proj_weight_k)
        nn.init.xavier_uniform_(self.in_proj_weight_v)
        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.in_proj_bias_k is not None:
            #nn.init.constant_(self.in_proj_bias, 0.)
            nn.init.constant_(self.in_proj_bias_q, 0.)
            nn.init.constant_(self.in_proj_bias_k, 0.)
            nn.init.constant_(self.in_proj_bias_v, 0.)
            nn.init.constant_(self.out_proj.bias, 0.)

    def forward(self, query: Tensor, key: Tensor, value: Tensor,
                mask_future_timesteps: bool,
                key_padding_mask: Optional[Tensor],
                incremental_state: Optional[Dict[str, Dict[str, Tensor]]],
                need_weights: bool,
                static_kv: bool):
        """Input shape: Time x Batch x Channel

        Self-attention can be implemented by passing in the same arguments for
        query, key and value. Future timesteps can be masked with the
        `mask_future_timesteps` argument. Padding elements can be excluded from
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.
        """
        if torch.jit.is_scripting():
            kv_same = torch.equal(key, value)
            qkv_same = torch.equal(query, value) and kv_same
        else:
            qkv_same, kv_same = self._fast_same_check(query, key, value)

        tgt_len, bsz, embed_dim = query.size()
        assert embed_dim == self.embed_dim
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        assert key.size() == value.size()

        k = v = query.new_empty(0)
        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
        else:
            saved_state = None

        if qkv_same:
            # self-attention
            q, k, v = self_attn_linears(query, self.in_proj_weight_q, self.in_proj_weight_k, self.in_proj_weight_v, self.scaling)
        elif kv_same:
            # encoder-decoder attention
            q = query_linear(query, self.in_proj_weight_q, self.scaling)
            if not (saved_state is not None and 'prev_key' in saved_state and static_kv):
                k, v = key_value_linears(key, self.in_proj_weight_k, self.in_proj_weight_v)
        else:
            q = torch.addmm(query.view(query.size(0) * query.size(1), query.size(2)), query.view(query.size(0) * query.size(1), query.size(2)), self.in_proj_weight_q, beta=0.0, alpha=self.scaling)
            if not (saved_state is not None and 'prev_key' in saved_state and static_kv):
                k = F.linear(key, self.in_proj_weight_k, self.in_proj_bias_k)
                v = F.linear(value, self.in_proj_weight_v, self.in_proj_bias_v)

        if saved_state is not None:
            if 'prev_key' in saved_state:
                k = torch.cat((saved_state['prev_key'], k), dim=0)
            if 'prev_value' in saved_state:
                v = torch.cat((saved_state['prev_value'], v), dim=0)
            saved_state['prev_key'] = k
            saved_state['prev_value'] = v
            self._set_input_buffer(incremental_state, saved_state)

        src_len = k.size(0)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        k = k.contiguous().view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        v = v.contiguous().view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)

        attn_weights = strided_bmm1(q, k.transpose(1, 2))
        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        # only apply masking at training time (when incremental state is None)
        if mask_future_timesteps and incremental_state is None:
            assert query.size() == key.size(), \
                'mask_future_timesteps only applies to self-attention'
            attn_weights += self.buffered_mask(attn_weights).unsqueeze(0)
        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.float().masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            ).type_as(attn_weights)  # FP16 support: cast to float and back
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

        attn = strided_bmm2(attn_weights, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)

        if need_weights:
            # average attention weights over heads
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_weights = attn_weights.sum(dim=1) / self.num_heads
        else:
            attn_weights = attn_weights.new_empty(0)  # Can't set to None because jit script reasons

        return attn, attn_weights

    def in_proj_qkv(self, query):
        return self._in_proj(query).chunk(3, dim=-1)

    def in_proj_kv(self, key):
        return self._in_proj(key, start=self.embed_dim).chunk(2, dim=-1)

    def in_proj_q(self, query):
        return self._in_proj(query, end=self.embed_dim)

    def in_proj_k(self, key):
        return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim)

    def in_proj_v(self, value):
        return self._in_proj(value, start=2 * self.embed_dim)

    def _in_proj(self, input, start=None, end=None):
        weight = self.in_proj_weight
        bias = self.in_proj_bias
        if end is not None:
            weight = weight[:end, :]
            if bias is not None:
                bias = bias[:end]
        if start is not None:
            weight = weight[start:, :]
            if bias is not None:
                bias = bias[start:]
        return F.linear(input, weight, bias)

    def buffered_mask(self, tensor):
        dim = tensor.size(-1)
        if self._mask.size(0) == 0:
            #TODO: try torch.new_full instead
            self._mask = torch.triu(utils.fill_with_neg_inf(tensor.new_empty(dim, dim)), 1)
        if self._mask.size(0) < dim:
            self._mask = torch.triu(utils.fill_with_neg_inf(self._mask.resize_(dim, dim)), 1)
        return self._mask[:dim, :dim]

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder buffered internal state (for incremental generation)."""
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer[k] = input_buffer[k].index_select(1, new_order)
            self._set_input_buffer(incremental_state, input_buffer)

    def _get_input_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Tensor]]]):
        if incremental_state is None or self.cache_id not in incremental_state:
            return {}
        return incremental_state[self.cache_id]

    def _set_input_buffer(self, incremental_state: Optional[Dict[str, Dict[str, Tensor]]], buffer: Dict[str, Tensor]):
        if incremental_state is not None:
            incremental_state[self.cache_id] = buffer

    @torch.jit.unused
    def _fast_same_check(self, q, k, v):
        qkv_same = q.data_ptr() == k.data_ptr() == v.data_ptr()
        kv_same = k.data_ptr() == v.data_ptr()
        return qkv_same, kv_same
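A usage sketch for the class above. It assumes the compiled strided_batched_gemm extension and the rest of this fairseq tree are importable and that a CUDA device is available; the sizes are made-up and the argument order follows the forward signature:

import torch
from fairseq.modules import MultiheadAttention  # path taken from this repo's layout

attn = MultiheadAttention(embed_dim=512, num_heads=8, dropout=0.1).cuda()
x = torch.randn(10, 2, 512, device='cuda')  # Time x Batch x Channel, per the docstring

out, weights = attn(
    query=x, key=x, value=x,        # same tensor -> the self-attention fast path
    mask_future_timesteps=True,     # causal masking (training-time only)
    key_padding_mask=None,
    incremental_state=None,
    need_weights=True,
    static_kv=False,
)
print(out.shape)      # torch.Size([10, 2, 512])
print(weights.shape)  # torch.Size([2, 10, 10]), attention averaged over heads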
PyTorch/NLP/Transformer/fairseq/modules/strided_batched_gemm/strided_batched_gemm.cpp deleted 100644 → 0
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <torch/torch.h>
#include <vector>
at::Tensor strided_batched_gemm_cuda(float beta, at::Tensor in_result, float alpha, at::Tensor batch1, at::Tensor batch2);

// C++ interface
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

at::Tensor strided_batched_gemm(float beta, at::Tensor in_result, float alpha, at::Tensor batch1, at::Tensor batch2) {
  //CHECK_INPUT(in_result);
  //CHECK_INPUT(batch1);
  //CHECK_INPUT(batch2);
  AT_ASSERTM(in_result.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(batch1.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(batch2.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(in_result.size(0) == batch1.size(0), "equal number of batches expected");
  AT_ASSERTM(in_result.size(0) == batch2.size(0), "equal number of batches expected");
  AT_ASSERTM(in_result.size(1) == batch1.size(1), "wrong matrix size");
  AT_ASSERTM(in_result.size(2) == batch2.size(2), "wrong matrix size");
  AT_ASSERTM(batch1.size(2) == batch2.size(1), "wrong matrix size");
  AT_ASSERTM(batch1.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(batch2.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");
  AT_ASSERTM(in_result.type().scalarType() == at::ScalarType::Half, "Only HALF is supported");

  return strided_batched_gemm_cuda(beta, in_result, alpha, batch1, batch2);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("strided_batched_gemm", &strided_batched_gemm, "Special strided batched gemm.");
}
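The PYBIND11_MODULE above is what the `import strided_batched_gemm` statement in multihead_attention.py binds against. A minimal build sketch using PyTorch's standard extension machinery; the module and source file names come from this listing, while the compiler flags and CUTLASS include path are illustrative assumptions rather than the repository's actual build script:

# setup.py (illustrative sketch only)
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='strided_batched_gemm',
    ext_modules=[
        CUDAExtension(
            name='strided_batched_gemm',
            sources=['strided_batched_gemm.cpp', 'strided_batched_gemm_cuda.cu'],
            include_dirs=['cutlass'],  # assumes a CUTLASS checkout at this path
            extra_compile_args={'cxx': ['-O2'],
                                'nvcc': ['-O3', '--expt-relaxed-constexpr']},
        )
    ],
    cmdclass={'build_ext': BuildExtension},
)

Once built and installed, `import strided_batched_gemm` in multihead_attention.py resolves to this extension.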
PyTorch/NLP/Transformer/fairseq/modules/strided_batched_gemm/strided_batched_gemm_cuda.cu deleted 100644 → 0
// Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include <iostream>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "THC/THC.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/wmma_gemm_traits.h"
// symbol to be automatically resolved by PyTorch libs
extern THCState *state;

cublasOperation_t convertTransToCublasOperation(char trans) {
  if (trans == 't') return CUBLAS_OP_T;
  else if (trans == 'n') return CUBLAS_OP_N;
  else if (trans == 'c') return CUBLAS_OP_C;
  else {
    THError("trans must be one of: t, n, c");
    return CUBLAS_OP_T;
  }
}

void CublasGemm(THCState *state, char transa, char transb, long m, long n, long k,
                float alpha, const half *a, long lda, long strideA,
                const half *b, long ldb, long strideB, float beta,
                half *c, long ldc, long strideC, long batchCount) {
  cublasOperation_t opa = convertTransToCublasOperation(transa);
  cublasOperation_t opb = convertTransToCublasOperation(transb);

  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  //cublasSetStream(handle, THCState_getCurrentStream(state));
  float fAlpha = alpha;
  float fBeta = beta;
  THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
  THCublasCheck(cublasGemmStridedBatchedEx(handle,
                                           opa, opb, (int)m, (int)n, (int)k,
                                           (void*)&fAlpha, a, CUDA_R_16F, (int)lda, strideA,
                                           b, CUDA_R_16F, (int)ldb, strideB,
                                           (void*)&fBeta, c, CUDA_R_16F, (int)ldc, strideC,
                                           (int)batchCount, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
  THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
}
template<cutlass::MatrixLayout::Kind A_LAYOUT, cutlass::MatrixLayout::Kind B_LAYOUT, int SRC_A, int SRC_B, int DST_C>
void CutlassGemm_FP32Accum(cudaStream_t stream, long m, long n, long k,
                           float alpha, const half *a, long lda, long strideA,
                           const half *b, long ldb, long strideB, float beta,
                           half *c, long ldc, long strideC, long batchCount) {
  //printf("CUTLASS-> %c%c M: %ld N: %ld K: %ld %d%d%d LDA: %ld LDB: %ld LDC: %ld strideA: %ld strideB: %ld strideC: %ld Alpha: %f Beta: %f\n", ((int)A_LAYOUT == 0 ? 'T' : 'N'), ((int)B_LAYOUT ==0 ? 'T' : 'N'), m, n, k, SRC_A,SRC_B,DST_C, lda, ldb, ldc, strideA, strideB, strideC, alpha, beta);
  typedef cutlass::gemm::WmmaGemmTraits<
    A_LAYOUT,
    B_LAYOUT,
    cutlass::Shape<32, 16, 16>,
    half,
    half,
    half,
    cutlass::gemm::LinearScaling<float>,
    float,
    typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp<typename cutlass::Shape<32, 16, 16> >::Shape,
    typename cutlass::Shape<16, 16, 16>,
    SRC_A,      //kScalarsPerLdgA_
    SRC_B,      //kScalarsPerLdgB_
    SRC_A,      //KScalarsPerLdsA_
    SRC_B,      //KScalarsPerLdsB_
    DST_C,      //kScalarsPerLdgCAndStgD_
    DST_C / 2,  //kScalarsPerStsD_
    DST_C / 2   //kScalarsPerLdsD_
  > WmmaGemmTraits;

  typedef cutlass::gemm::Gemm<WmmaGemmTraits> Gemm;
  typename Gemm::Params params;

  int result = params.initialize(
    m,          // M dimension for each batch
    n,          // N dimension for each batch
    k,          // K dimension for each batch
    alpha,      // scalar alpha
    a,
    lda,
    strideA,    // distance in memory between the first element of neighboring batch
    b,
    ldb,
    strideB,    // distance in memory between the first element of neighboring batch
    beta,       // scalar beta
    c,          // source matrix C
    ldc,
    strideC,    // distance in memory between the first element of neighboring batch
    c,          // destination matrix C (may be different memory than source C matrix)
    ldc,
    strideC,    // distance in memory between the first element of neighboring batch
    batchCount);

  AT_ASSERTM(result == 0, "Failed to initialize CUTLASS Gemm::Params object.");

  // Launch the CUTLASS GEMM kernel.
  THCudaCheck(Gemm::launch(params));
}
void gemm_switch_fp32accum(THCState *state, char transa, char transb, long m, long n, long k,
                           float alpha, const half *a, long lda, long strideA,
                           const half *b, long ldb, long strideB, float beta,
                           half *c, long ldc, long strideC, long batchCount) {
  //cudaStream_t stream = THCState_getCurrentStream(state);
  //printf("GEMM -> %c%c M: %i N: %i K: %i Alpha: %f Beta: %f\n", (transa == 't' ? 'T' : 'N'), (transb =='t' ? 'T' : 'N'), m, n, k, alpha, beta);
  auto stream = c10::cuda::getCurrentCUDAStream();
  if ((transa == 't') && (transb == 'n')) {
    if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { CublasGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor, cutlass::MatrixLayout::kColumnMajor, 8, 8, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor, cutlass::MatrixLayout::kColumnMajor, 8, 8, 2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor, cutlass::MatrixLayout::kColumnMajor, 8, 4, 8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor, cutlass::MatrixLayout::kColumnMajor, 8, 4, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor, cutlass::MatrixLayout::kColumnMajor, 8, 4, 2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor, cutlass::MatrixLayout::kColumnMajor, 8, 2, 8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor, cutlass::MatrixLayout::kColumnMajor, 8, 2, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor, cutlass::MatrixLayout::kColumnMajor, 8, 2, 2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor, cutlass::MatrixLayout::kColumnMajor, 4, 8, 8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor, cutlass::MatrixLayout::kColumnMajor, 4, 8, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor, cutlass::MatrixLayout::kColumnMajor, 4, 8, 2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor, cutlass::MatrixLayout::kColumnMajor, 4, 4, 8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor, cutlass::MatrixLayout::kColumnMajor, 4, 4, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor, cutlass::MatrixLayout::kColumnMajor, 4, 4, 2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor, cutlass::MatrixLayout::kColumnMajor, 4, 2, 8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor, cutlass::MatrixLayout::kColumnMajor, 4, 2, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor, cutlass::MatrixLayout::kColumnMajor, 4, 2, 2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor, cutlass::MatrixLayout::kColumnMajor, 2, 8, 8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor, cutlass::MatrixLayout::kColumnMajor, 2, 8, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor, cutlass::MatrixLayout::kColumnMajor, 2, 8, 2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor, cutlass::MatrixLayout::kColumnMajor, 2, 4, 8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor, cutlass::MatrixLayout::kColumnMajor, 2, 4, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor, cutlass::MatrixLayout::kColumnMajor, 2, 4, 2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor, cutlass::MatrixLayout::kColumnMajor, 2, 2, 8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor, cutlass::MatrixLayout::kColumnMajor, 2, 2, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kRowMajor, cutlass::MatrixLayout::kColumnMajor, 2, 2, 2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else { CublasGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
  } else if ((transa == 'n') && (transb == 'n')) {
    if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { CublasGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kColumnMajor, 8, 8, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kColumnMajor, 8, 8, 2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kColumnMajor, 8, 4, 8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kColumnMajor, 8, 4, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kColumnMajor, 8, 4, 2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kColumnMajor, 8, 2, 8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kColumnMajor, 8, 2, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kColumnMajor, 8, 2, 2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kColumnMajor, 4, 8, 8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kColumnMajor, 4, 8, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kColumnMajor, 4, 8, 2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kColumnMajor, 4, 4, 8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kColumnMajor, 4, 4, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kColumnMajor, 4, 4, 2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kColumnMajor, 4, 2, 8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kColumnMajor, 4, 2, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kColumnMajor, 4, 2, 2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kColumnMajor, 2, 8, 8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kColumnMajor, 2, 8, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kColumnMajor, 2, 8, 2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kColumnMajor, 2, 4, 8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kColumnMajor, 2, 4, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kColumnMajor, 2, 4, 2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kColumnMajor, 2, 2, 8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kColumnMajor, 2, 2, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kColumnMajor, 2, 2, 2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else { CublasGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
  } else if ((transa == 'n') && (transb == 't')) {
    if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { CublasGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kRowMajor, 8, 8, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kRowMajor, 8, 8, 2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kRowMajor, 8, 4, 8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kRowMajor, 8, 4, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kRowMajor, 8, 4, 2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kRowMajor, 8, 2, 8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kRowMajor, 8, 2, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kRowMajor, 8, 2, 2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kRowMajor, 4, 8, 8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kRowMajor, 4, 8, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kRowMajor, 4, 8, 2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kRowMajor, 4, 4, 8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kRowMajor, 4, 4, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kRowMajor, 4, 2, 8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kRowMajor, 4, 2, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kRowMajor, 4, 2, 2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kRowMajor, 2, 8, 8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kRowMajor, 2, 8, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kRowMajor, 2, 8, 2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kRowMajor, 2, 4, 8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kRowMajor, 2, 4, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kRowMajor, 2, 4, 2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kRowMajor, 2, 2, 8>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kRowMajor, 2, 2, 4>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { CutlassGemm_FP32Accum<cutlass::MatrixLayout::kColumnMajor, cutlass::MatrixLayout::kRowMajor, 2, 2, 2>(stream, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
    else { CublasGemm(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); }
  } else {
    AT_ASSERTM(false, "TransA and TransB are invalid");
  }
}
void
adjustLdLevel3
(
char
transa
,
char
transb
,
int64_t
m
,
int64_t
n
,
int64_t
k
,
int64_t
*
lda
,
int64_t
*
ldb
,
int64_t
*
ldc
)
{
int
transa_
=
((
transa
==
't'
)
||
(
transa
==
'T'
));
int
transb_
=
((
transb
==
't'
)
||
(
transb
==
'T'
));
// Note: leading dimensions generally are checked that they are > 0 and at least as big the result
// requires (even if the value won't be used).
if
(
n
<=
1
)
*
ldc
=
std
::
max
<
int64_t
>
(
m
,
1
);
if
(
transa_
)
{
if
(
m
<=
1
)
*
lda
=
std
::
max
<
int64_t
>
(
k
,
1
);
}
else
{
if
(
k
<=
1
)
*
lda
=
std
::
max
<
int64_t
>
(
m
,
1
);
}
if
(
transb_
)
{
if
(
k
<=
1
)
*
ldb
=
std
::
max
<
int64_t
>
(
n
,
1
);
}
else
{
if
(
n
<=
1
)
*
ldb
=
std
::
max
<
int64_t
>
(
k
,
1
);
}
}
void
HgemmStridedBatched
(
THCState
*
state
,
char
transa
,
char
transb
,
long
m
,
long
n
,
long
k
,
float
alpha
,
const
half
*
a
,
long
lda
,
long
strideA
,
const
half
*
b
,
long
ldb
,
long
strideB
,
float
beta
,
half
*
c
,
long
ldc
,
long
strideC
,
long
batchCount
)
{
if
(
(
m
>=
INT_MAX
)
||
(
n
>=
INT_MAX
)
||
(
k
>=
INT_MAX
)
||
(
lda
>=
INT_MAX
)
||
(
ldb
>=
INT_MAX
)
||
(
ldc
>=
INT_MAX
)
||
(
batchCount
>=
INT_MAX
)
)
{
THError
(
"Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount"
"with the bound [val] <= %d"
,
INT_MAX
);
}
adjustLdLevel3
(
transa
,
transb
,
m
,
n
,
k
,
&
lda
,
&
ldb
,
&
ldc
);
//gemm_switch(state, transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount);
gemm_switch_fp32accum
(
state
,
transa
,
transb
,
m
,
n
,
k
,
alpha
,
a
,
lda
,
strideA
,
b
,
ldb
,
strideB
,
beta
,
c
,
ldc
,
strideC
,
batchCount
);
}
at
::
Tensor
strided_batched_gemm_cuda
(
float
beta
,
at
::
Tensor
in_result
,
float
alpha
,
at
::
Tensor
batch1
,
at
::
Tensor
batch2
)
{
bool
transpose_result
;
char
transpose_batch1
,
transpose_batch2
;
int64_t
lda
,
ldb
,
ldc
;
at
::
Tensor
result
,
input1
,
input2
;
if
(
in_result
.
stride
(
1
)
==
1
)
{
transpose_result
=
false
;
result
=
in_result
;
ldc
=
result
.
stride
(
2
);
}
else
if
(
in_result
.
stride
(
2
)
==
1
)
{
transpose_result
=
true
;
at
::
Tensor
swap
=
batch2
;
batch2
=
batch1
;
batch1
=
swap
;
result
=
in_result
;
ldc
=
result
.
stride
(
1
);
}
else
{
AT_ASSERTM
(
false
,
"result should be contiguous"
);
}
if
(
batch1
.
stride
(
transpose_result
?
2
:
1
)
==
1
&&
batch1
.
stride
(
transpose_result
?
1
:
2
)
!=
0
)
{
transpose_batch1
=
'n'
;
input1
=
batch1
;
lda
=
input1
.
stride
(
transpose_result
?
1
:
2
);
}
else
if
(
batch1
.
stride
(
transpose_result
?
1
:
2
)
==
1
&&
batch1
.
stride
(
transpose_result
?
2
:
1
)
!=
0
)
{
transpose_batch1
=
't'
;
input1
=
batch1
;
lda
=
input1
.
stride
(
transpose_result
?
2
:
1
);
}
else
{
AT_ASSERTM
(
false
,
"input1 should be contiguous"
);
}
if
(
batch2
.
stride
(
transpose_result
?
2
:
1
)
==
1
&&
batch2
.
stride
(
transpose_result
?
1
:
2
)
!=
0
)
{
transpose_batch2
=
'n'
;
input2
=
batch2
;
ldb
=
input2
.
stride
(
transpose_result
?
1
:
2
);
}
else
if
(
batch2
.
stride
(
transpose_result
?
1
:
2
)
==
1
&&
batch2
.
stride
(
transpose_result
?
2
:
1
)
!=
0
)
{
transpose_batch2
=
't'
;
input2
=
batch2
;
ldb
=
input2
.
stride
(
transpose_result
?
2
:
1
);
}
else
{
AT_ASSERTM
(
false
,
"input2 should be contiguous"
);
}
int64_t
num_batches
=
result
.
size
(
0
);
HgemmStridedBatched
(
state
,
transpose_batch1
,
transpose_batch2
,
result
.
size
(
transpose_result
?
2
:
1
),
result
.
size
(
transpose_result
?
1
:
2
),
input1
.
size
(
transpose_result
?
1
:
2
),
alpha
,
static_cast
<
const
half
*>
(
input1
.
data_ptr
()),
lda
,
input1
.
stride
(
0
),
static_cast
<
const
half
*>
(
input2
.
data_ptr
()),
ldb
,
input2
.
stride
(
0
),
beta
,
static_cast
<
half
*>
(
result
.
data_ptr
()),
ldc
,
result
.
stride
(
0
),
num_batches
);
return
in_result
;
}
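For orientation only: the operation this kernel implements is a half-precision strided batched GEMM with FP32 accumulation, i.e. per batch c = beta * c + alpha * op(a) * op(b), with the CUTLASS tile chosen from the alignment of lda/ldb/ldc. The Python sketch below is a rough functional reference using plain torch.baddbmm (assuming PyTorch and a CUDA device are available); it shows the semantics, not the fused CUTLASS path.

# Functional reference (not the fused kernel): out = beta * result + alpha * (batch1 @ batch2)
import torch

def reference_strided_batched_gemm(beta, result, alpha, batch1, batch2):
    # All tensors are (num_batches, ...) half-precision CUDA tensors,
    # mirroring the at::Tensor arguments of strided_batched_gemm_cuda.
    return torch.baddbmm(result, batch1, batch2, beta=beta, alpha=alpha)

if torch.cuda.is_available():
    b1 = torch.randn(8, 64, 32, device='cuda', dtype=torch.half)
    b2 = torch.randn(8, 32, 48, device='cuda', dtype=torch.half)
    res = torch.zeros(8, 64, 48, device='cuda', dtype=torch.half)
    out = reference_strided_batched_gemm(0.0, res, 1.0, b1, b2)
    print(out.shape)  # torch.Size([8, 64, 48])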
PyTorch/NLP/Transformer/fairseq/optim/__init__.py deleted 100644 → 0 View file @ c056df78
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import importlib
import os

from .fairseq_optimizer import FairseqOptimizer


OPTIMIZER_REGISTRY = {}
OPTIMIZER_CLASS_NAMES = set()


def build_optimizer(args, params):
    params = filter(lambda p: p.requires_grad, params)
    return OPTIMIZER_REGISTRY[args.optimizer](args, params)


def register_optimizer(name):
    """Decorator to register a new optimizer."""

    def register_optimizer_cls(cls):
        if name in OPTIMIZER_REGISTRY:
            raise ValueError('Cannot register duplicate optimizer ({})'.format(name))
        if not issubclass(cls, FairseqOptimizer):
            raise ValueError('Optimizer ({}: {}) must extend FairseqOptimizer'.format(name, cls.__name__))
        if cls.__name__ in OPTIMIZER_CLASS_NAMES:
            # We use the optimizer class name as a unique identifier in
            # checkpoints, so all optimizer must have unique class names.
            raise ValueError('Cannot register optimizer with duplicate class name ({})'.format(cls.__name__))
        OPTIMIZER_REGISTRY[name] = cls
        OPTIMIZER_CLASS_NAMES.add(cls.__name__)
        return cls

    return register_optimizer_cls


# automatically import any Python files in the optim/ directory
for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith('.py') and not file.startswith('_'):
        module = file[:file.find('.py')]
        importlib.import_module('fairseq.optim.' + module)
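The file above is a decorator-based registry plus an automatic module import loop: build_optimizer simply looks up args.optimizer by name. A minimal standalone sketch of that registry pattern is shown below; the names REGISTRY, register, 'sgd_example' and ExampleSGD are hypothetical and exist only for illustration.

# Minimal sketch of the decorator-registry pattern used in fairseq.optim.
REGISTRY = {}

def register(name):
    def wrapper(cls):
        if name in REGISTRY:
            raise ValueError('Cannot register duplicate entry ({})'.format(name))
        REGISTRY[name] = cls
        return cls
    return wrapper

@register('sgd_example')          # hypothetical entry, not part of fairseq
class ExampleSGD:
    def __init__(self, args, params):
        self.args, self.params = args, params

# Lookup by name, the same way build_optimizer resolves args.optimizer:
opt_cls = REGISTRY['sgd_example']
print(opt_cls.__name__)  # ExampleSGD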
PyTorch/NLP/Transformer/fairseq/optim/adam.py deleted 100644 → 0 View file @ c056df78
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import FairseqOptimizer, register_optimizer
from apex.optimizers.fused_adam import FusedAdam


@register_optimizer('adam')
class FairseqAdam(FairseqOptimizer):

    def __init__(self, args, params):
        super().__init__(args, params)
        self._optimizer = FusedAdam(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        parser.add_argument('--adam-betas', default=(0.9, 0.999), nargs=2, type=float, metavar='B1 B2',
                            help='betas for Adam optimizer')
        parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D',
                            help='epsilon for Adam optimizer')

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            'lr': self.args.lr[0],
            'betas': self.args.adam_betas,
            'eps': self.args.adam_eps,
            'weight_decay': self.args.weight_decay,
        }
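FairseqAdam depends on apex's FusedAdam. As a sketch only, the same kwarg mapping as optimizer_config can be exercised against the standard torch.optim.Adam; this is an assumed substitute for environments without apex, not what the file above actually does. The defaults below mirror the --adam-betas/--adam-eps defaults declared above.

import torch

# Hypothetical stand-in: same kwarg names as FairseqAdam.optimizer_config,
# but backed by torch.optim.Adam instead of apex.optimizers.fused_adam.FusedAdam.
def build_plain_adam(params, lr, adam_betas=(0.9, 0.999), adam_eps=1e-8, weight_decay=0.0):
    return torch.optim.Adam(params, lr=lr, betas=adam_betas, eps=adam_eps,
                            weight_decay=weight_decay)

model = torch.nn.Linear(4, 4)
optimizer = build_plain_adam(model.parameters(), lr=0.25)
print(optimizer.param_groups[0]['lr'])  # 0.25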
PyTorch/NLP/Transformer/fairseq/optim/fairseq_optimizer.py deleted 100644 → 0 View file @ c056df78
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch.optim


class FairseqOptimizer(object):

    def __init__(self, args, params):
        super().__init__()
        self.args = args
        self.params = params

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        pass

    @property
    def optimizer(self):
        """Return a torch.optim.optimizer.Optimizer instance."""
        if not hasattr(self, '_optimizer'):
            raise NotImplementedError
        if not isinstance(self._optimizer, torch.optim.Optimizer):
            raise ValueError('_optimizer must be an instance of torch.optim.Optimizer')
        return self._optimizer

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        raise NotImplementedError

    def get_lr(self):
        """Return the current learning rate."""
        return self.optimizer.param_groups[0]['lr']

    def set_lr(self, lr):
        """Set the learning rate."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def state_dict(self):
        """Return the optimizer's state dict."""
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        """Load an optimizer state dict.

        In general we should prefer the configuration of the existing optimizer
        instance (e.g., learning rate) over that found in the state_dict. This
        allows us to resume training from a checkpoint using a new set of
        optimizer args.
        """
        self.optimizer.load_state_dict(state_dict)

        # override learning rate, momentum, etc. with latest values
        for group in self.optimizer.param_groups:
            group.update(self.optimizer_config)

    def step(self, closure=None):
        """Performs a single optimization step."""
        return self.optimizer.step(closure)

    def zero_grad(self):
        """Clears the gradients of all optimized parameters."""
        for group in self.optimizer.param_groups:
            for p in group['params']:
                p.grad = None
        return self.optimizer.zero_grad()
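The load_state_dict method above deliberately lets the current optimizer_config win over hyperparameters stored in a checkpoint. A small standalone sketch of that override step using torch.optim.SGD (assumed available; not fairseq code):

import torch

model = torch.nn.Linear(2, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
checkpoint = opt.state_dict()          # checkpoint carries lr=0.1

new_opt = torch.optim.SGD(model.parameters(), lr=0.5)
new_opt.load_state_dict(checkpoint)    # lr reverts to 0.1 (from the checkpoint)

# the override pass from FairseqOptimizer.load_state_dict:
optimizer_config = {'lr': 0.5}
for group in new_opt.param_groups:
    group.update(optimizer_config)

print(new_opt.param_groups[0]['lr'])   # 0.5 -- the current config wins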
PyTorch/NLP/Transformer/fairseq/optim/lr_scheduler/__init__.py deleted 100644 → 0 View file @ c056df78
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import importlib
import os

from .fairseq_lr_scheduler import FairseqLRScheduler


LR_SCHEDULER_REGISTRY = {}


def build_lr_scheduler(args, optimizer):
    return LR_SCHEDULER_REGISTRY[args.lr_scheduler](args, optimizer)


def register_lr_scheduler(name):
    """Decorator to register a new LR scheduler."""

    def register_lr_scheduler_cls(cls):
        if name in LR_SCHEDULER_REGISTRY:
            raise ValueError('Cannot register duplicate LR scheduler ({})'.format(name))
        if not issubclass(cls, FairseqLRScheduler):
            raise ValueError('LR Scheduler ({}: {}) must extend FairseqLRScheduler'.format(name, cls.__name__))
        LR_SCHEDULER_REGISTRY[name] = cls
        return cls

    return register_lr_scheduler_cls


# automatically import any Python files in the optim/lr_scheduler/ directory
for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith('.py') and not file.startswith('_'):
        module = file[:file.find('.py')]
        importlib.import_module('fairseq.optim.lr_scheduler.' + module)
PyTorch/NLP/Transformer/fairseq/optim/lr_scheduler/fixed_schedule.py deleted 100644 → 0 View file @ c056df78
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from . import FairseqLRScheduler, register_lr_scheduler


@register_lr_scheduler('fixed')
class FixedSchedule(FairseqLRScheduler):
    """Decay the LR on a fixed schedule."""

    def __init__(self, args, optimizer):
        super().__init__(args, optimizer)

        # set defaults
        args.warmup_updates = getattr(args, 'warmup_updates', 0) or 0

        self.lr = args.lr[0]
        if args.warmup_updates > 0:
            self.warmup_factor = 1. / args.warmup_updates
        else:
            self.warmup_factor = 1

    @staticmethod
    def add_args(parser):
        """Add arguments to the parser for this LR scheduler."""
        parser.add_argument('--force-anneal', '--fa', type=int, metavar='N',
                            help='force annealing at specified epoch')
        parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
                            help='warmup the learning rate linearly for the first N updates')

    def get_next_lr(self, epoch):
        lrs = self.args.lr
        if self.args.force_anneal is None or epoch < self.args.force_anneal:
            # use fixed LR schedule
            next_lr = lrs[min(epoch, len(lrs) - 1)]
        else:
            # anneal based on lr_shrink
            next_lr = lrs[-1] * self.args.lr_shrink ** (epoch + 1 - self.args.force_anneal)
        return next_lr

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        super().step(epoch, val_loss)
        self.lr = self.get_next_lr(epoch)
        self.optimizer.set_lr(self.warmup_factor * self.lr)
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        if self.args.warmup_updates > 0 and num_updates <= self.args.warmup_updates:
            self.warmup_factor = num_updates / float(self.args.warmup_updates)
            self.optimizer.set_lr(self.warmup_factor * self.lr)
        return self.optimizer.get_lr()
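The effective learning rate produced by FixedSchedule is warmup_factor * lr, where warmup_factor ramps linearly over --warmup-updates and the per-epoch base value comes from get_next_lr. A standalone sketch of that arithmetic, with hypothetical settings and no fairseq imports:

# Hypothetical settings: lr=[0.25], lr_shrink=0.1, force_anneal=5, warmup_updates=4000
lrs, lr_shrink, force_anneal, warmup_updates = [0.25], 0.1, 5, 4000

def next_lr(epoch):
    if force_anneal is None or epoch < force_anneal:
        return lrs[min(epoch, len(lrs) - 1)]                      # fixed schedule
    return lrs[-1] * lr_shrink ** (epoch + 1 - force_anneal)      # anneal by lr_shrink

def effective_lr(epoch, num_updates):
    warmup_factor = min(num_updates, warmup_updates) / float(warmup_updates)
    return warmup_factor * next_lr(epoch)

print(effective_lr(0, 1000))   # 0.0625  (a quarter of the way through warmup)
print(effective_lr(0, 4000))   # 0.25    (warmup finished)
print(effective_lr(6, 4000))   # 0.0025  (annealed twice: 0.25 * 0.1 ** 2)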
PyTorch/NLP/Transformer/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py deleted 100644 → 0 View file @ c056df78
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.optim.lr_scheduler

from . import FairseqLRScheduler, register_lr_scheduler


@register_lr_scheduler('reduce_lr_on_plateau')
class ReduceLROnPlateau(FairseqLRScheduler):
    """Decay the LR by a factor every time the validation loss plateaus."""

    def __init__(self, args, optimizer):
        super().__init__(args, optimizer)
        if len(args.lr) > 1:
            raise ValueError(
                'Cannot use a fixed learning rate schedule with reduce_lr_on_plateau.'
                ' Consider --lr-scheduler=fixed instead.'
            )
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer.optimizer, patience=0, factor=args.lr_shrink)

    def state_dict(self):
        """Return the LR scheduler state dict."""
        return {
            'best': self.lr_scheduler.best,
            'last_epoch': self.lr_scheduler.last_epoch,
        }

    def load_state_dict(self, state_dict):
        """Load an LR scheduler state dict."""
        self.lr_scheduler.best = state_dict['best']
        if 'last_epoch' in state_dict:
            self.lr_scheduler.last_epoch = state_dict['last_epoch']

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        if val_loss is not None:
            self.lr_scheduler.step(val_loss, epoch)
        else:
            self.lr_scheduler.last_epoch = epoch
        return self.optimizer.get_lr()
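The wrapper above delegates to PyTorch's own scheduler with patience=0, so the LR shrinks by lr_shrink on every epoch in which the validation loss fails to improve. A standalone sketch of that behavior, assuming only PyTorch:

import torch

model = torch.nn.Linear(2, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.25)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, patience=0, factor=0.1)

for val_loss in [1.0, 0.9, 0.95, 0.97]:   # loss stops improving after the second epoch
    sched.step(val_loss)
    print(opt.param_groups[0]['lr'])
# 0.25, 0.25, 0.025, 0.0025 -- shrinks on every non-improving epoch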
PyTorch/NLP/Transformer/fairseq/options.py deleted 100644 → 0 View file @ c056df78
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
#-------------------------------------------------------------------------
#
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os

import torch

from fairseq.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY
from fairseq.criterions import CRITERION_REGISTRY
from fairseq.optim import OPTIMIZER_REGISTRY
from fairseq.optim.lr_scheduler import LR_SCHEDULER_REGISTRY


def get_training_parser():
    parser = get_parser('Trainer')
    add_dataset_args(parser, train=True, gen=True)
    add_distributed_training_args(parser)
    add_model_args(parser)
    add_optimization_args(parser)
    add_checkpoint_args(parser)
    add_inference_args(parser)
    add_perf_args(parser)
    return parser


def get_inference_parser():
    parser = get_parser('Generation')
    add_dataset_args(parser, gen=True)
    add_inference_args(parser)
    add_perf_args(parser)
    return parser


def parse_args_and_arch(parser, input_args=None, parse_known=False):
    # The parser doesn't know about model/criterion/optimizer-specific args, so
    # we parse twice. First we parse the model/criterion/optimizer, then we
    # parse a second time after adding the *-specific arguments.
    # If input_args is given, we will parse those args instead of sys.argv.
    args, _ = parser.parse_known_args(input_args)

    # Add model-specific args to parser.
    if hasattr(args, 'arch'):
        model_specific_group = parser.add_argument_group(
            'Model-specific configuration',
            # Only include attributes which are explicitly given as command-line
            # arguments or which have default values.
            argument_default=argparse.SUPPRESS,
        )
        ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group)

    # Add *-specific args to parser.
    if hasattr(args, 'optimizer'):
        OPTIMIZER_REGISTRY[args.optimizer].add_args(parser)
    if hasattr(args, 'lr_scheduler'):
        LR_SCHEDULER_REGISTRY[args.lr_scheduler].add_args(parser)

    # Parse a second time.
    if parse_known:
        args, extra = parser.parse_known_args(input_args)
    else:
        args = parser.parse_args(input_args)
        extra = None

    # Post-process args.
    if hasattr(args, 'max_sentences_valid') and args.max_sentences_valid is None:
        args.max_sentences_valid = args.max_sentences
    args.max_positions = (args.max_source_positions, args.max_target_positions)
    if hasattr(args, 'target_bleu') and (args.online_eval or args.target_bleu) and not args.remove_bpe:
        args.remove_bpe = '@@ '

    # Apply architecture configuration.
    if hasattr(args, 'arch'):
        ARCH_CONFIG_REGISTRY[args.arch](args)

    if parse_known:
        return args, extra
    else:
        return args


def get_parser(desc):
    parser = argparse.ArgumentParser(
        description='Facebook AI Research Sequence-to-Sequence Toolkit -- ' + desc)
    parser.add_argument('--log-interval', type=int, default=500, metavar='N',
                        help='print aggregated stats and flush json log every N iteration')
    parser.add_argument('--seed', default=1, type=int, metavar='N',
                        help='pseudo random number generator seed')
    parser.add_argument('--amp', action='store_true',
                        help='use Automatic Mixed Precision')
    parser.add_argument('--stat-file', type=str, default='run_log.json',
                        help='Name of the file containing DLLogger output')
    parser.add_argument('--save-dir', metavar='DIR', default='results',
                        help='path to save checkpoints and logs')
    parser.add_argument('--do-sanity-check', action='store_true',
                        help='Perform evaluation on test set before running the training')
    return parser


def add_dataset_args(parser, train=False, gen=False):
    group = parser.add_argument_group('Dataset and data loading')
    group.add_argument('--max-tokens', type=int, metavar='N',
                       help='maximum number of tokens in a batch')
    group.add_argument('--max-sentences', '--batch-size', type=int, metavar='N',
                       help='maximum number of sentences in a batch')
    parser.add_argument('-s', '--source-lang', default=None, metavar='SRC',
                        help='source language')
    parser.add_argument('-t', '--target-lang', default=None, metavar='TARGET',
                        help='target language')
    parser.add_argument('--raw-text', action='store_true',
                        help='load raw text dataset')
    parser.add_argument('--left-pad-source', default=True, type=bool, metavar='BOOL',
                        help='pad the source on the left (default: True)')
    parser.add_argument('--left-pad-target', default=False, type=bool, metavar='BOOL',
                        help='pad the target on the left (default: False)')
    parser.add_argument('--max-source-positions', default=1024, type=int, metavar='N',
                        help='max number of tokens in the source sequence')
    parser.add_argument('--max-target-positions', default=1024, type=int, metavar='N',
                        help='max number of tokens in the target sequence')
    parser.add_argument('--pad-sequence', default=1, type=int, metavar='N',
                        help='Pad sequences to a multiple of N')
    if train:
        parser.add_argument('data', metavar='DIR', help='path to data directory')
        group.add_argument('--train-subset', default='train', metavar='SPLIT',
                           choices=['train', 'valid', 'test'],
                           help='data subset to use for training (train, valid, test)')
        group.add_argument('--valid-subset', default='valid', metavar='SPLIT',
                           help='comma separated list of data subsets to use for validation'
                                ' (train, valid, valid1, test, test1)')
        group.add_argument('--max-sentences-valid', type=int, metavar='N',
                           help='maximum number of sentences in a validation batch'
                                ' (defaults to --max-sentences)')
    if gen:
        group.add_argument('--gen-subset', default='test', metavar='SPLIT',
                           help='data subset to generate (train, valid, test)')
        group.add_argument('--num-shards', default=1, type=int, metavar='N',
                           help='shard generation over N shards')
        group.add_argument('--shard-id', default=0, type=int, metavar='ID',
                           help='id of the shard to generate (id < num_shards)')
    return group


def add_distributed_training_args(parser):
    group = parser.add_argument_group('Distributed training')
    group.add_argument('--distributed-world-size', type=int, metavar='N',
                       default=torch.cuda.device_count(),
                       help='total number of GPUs across all nodes (default: all visible GPUs)')
    group.add_argument('--distributed-rank', default=os.getenv('LOCAL_RANK', 0), type=int,
                       help='rank of the current worker')
    group.add_argument('--local_rank', default=0, type=int,
                       help='rank of the current worker')
    group.add_argument('--distributed-backend', default='nccl', type=str,
                       help='distributed backend')
    group.add_argument('--distributed-init-method', default=None, type=str,
                       help='typically tcp://hostname:port that will be used to '
                            'establish initial connetion')
    group.add_argument('--distributed-port', default=-1, type=int,
                       help='port number (not required if using --distributed-init-method)')
    group.add_argument('--device-id', default=0, type=int,
                       help='which GPU to use (usually configured automatically)')
    return group


def add_optimization_args(parser):
    group = parser.add_argument_group('Optimization')
    group.add_argument('--max-epoch', '--me', default=0, type=int, metavar='N',
                       help='force stop training at specified epoch')
    group.add_argument('--max-update', '--mu', default=0, type=int, metavar='N',
                       help='force stop training at specified update')
    group.add_argument('--target-bleu', default=0.0, type=float, metavar='TARGET',
                       help='force stop training after reaching target bleu')
    group.add_argument('--clip-norm', default=25, type=float, metavar='NORM',
                       help='clip threshold of gradients')
    group.add_argument('--update-freq', default=[1], nargs='+', type=int,
                       help='update parameters every N_i batches, when in epoch i')

    # Optimizer definitions can be found under fairseq/optim/
    group.add_argument('--optimizer', default='nag', metavar='OPT',
                       choices=OPTIMIZER_REGISTRY.keys(),
                       help='optimizer: {} (default: nag)'.format(', '.join(OPTIMIZER_REGISTRY.keys())))
    group.add_argument('--lr', '--learning-rate', default=[0.25], nargs='+', type=float,
                       help='learning rate for the first N epochs; all epochs >N using LR_N'
                            ' (note: this may be interpreted differently depending on --lr-scheduler)')
    group.add_argument('--momentum', default=0.99, type=float, metavar='M',
                       help='momentum factor')
    group.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD',
                       help='weight decay')

    # Learning rate schedulers can be found under fairseq/optim/lr_scheduler/
    group.add_argument('--lr-scheduler', default='reduce_lr_on_plateau',
                       help='learning rate scheduler: {} (default: reduce_lr_on_plateau)'.format(
                           ', '.join(LR_SCHEDULER_REGISTRY.keys())))
    group.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS',
                       help='learning rate shrink factor for annealing, lr_new = (lr * lr_shrink)')
    group.add_argument('--min-lr', default=1e-5, type=float, metavar='LR',
                       help='minimum learning rate')

    # Criterion args
    parser.add_argument('--label-smoothing', default=0., type=float, metavar='D',
                        help='epsilon for label smoothing, 0 means no label smoothing')

    return group


def add_checkpoint_args(parser):
    group = parser.add_argument_group('Checkpointing')
    group.add_argument('--restore-file', default='checkpoint_last.pt',
                       help='filename in save-dir from which to load checkpoint')
    group.add_argument('--save-interval', type=int, default=1, metavar='N',
                       help='save a checkpoint every N epochs')
    group.add_argument('--no-save', action='store_true',
                       help='don\'t save models or checkpoints')
    group.add_argument('--no-epoch-checkpoints', action='store_true',
                       help='only store last and best checkpoints')
    group.add_argument('--validate-interval', type=int, default=1, metavar='N',
                       help='validate every N epochs')
    return group


def add_common_eval_args(group):
    group.add_argument('--path', metavar='FILE',
                       help='path(s) to model file(s), colon separated')
    group.add_argument('--file', metavar='FILE', default=None, type=str,
                       help='path to a file with input data for inference')
    group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
                       help='remove BPE tokens before scoring')
    group.add_argument('--cpu', action='store_true', help='generate on CPU')
    group.add_argument('--quiet', action='store_true',
                       help='only print final scores')


def add_inference_args(parser):
    group = parser.add_argument_group('Generation')
    add_common_eval_args(group)
    group.add_argument('--beam', default=4, type=int, metavar='N',
                       help='beam size')
    group.add_argument('--nbest', default=1, type=int, metavar='N',
                       help='number of hypotheses to output')
    group.add_argument('--max-len-a', default=0, type=float, metavar='N',
                       help=('generate sequences of maximum length ax + b, '
                             'where x is the source length'))
    group.add_argument('--max-len-b', default=200, type=int, metavar='N',
                       help=('generate sequences of maximum length ax + b, '
                             'where x is the source length'))
    group.add_argument('--min-len', default=1, type=float, metavar='N',
                       help=('minimum generation length'))
    group.add_argument('--no-early-stop', action='store_true',
                       help=('continue searching even after finalizing k=beam '
                             'hypotheses; this is more correct, but increases '
                             'generation time by 50%%'))
    group.add_argument('--unnormalized', action='store_true',
                       help='compare unnormalized hypothesis scores')
    group.add_argument('--no-beamable-mm', action='store_true',
                       help='don\'t use BeamableMM in attention layers')
    group.add_argument('--lenpen', default=1, type=float,
                       help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
    group.add_argument('--unkpen', default=0, type=float,
                       help='unknown word penalty: <0 produces more unks, >0 produces fewer')
    group.add_argument('--replace-unk', nargs='?', const=True, default=None,
                       help='perform unknown replacement (optionally with alignment dictionary)')
    group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
                       help='initialize generation by target prefix of given length')
    group.add_argument('--sampling', action='store_true',
                       help='sample hypotheses instead of using beam search')
    group.add_argument('--sampling-topk', default=-1, type=int, metavar='PS',
                       help='sample from top K likely next words instead of all words')
    group.add_argument('--sampling-temperature', default=1, type=float, metavar='N',
                       help='temperature for random sampling')
    group.add_argument('--print-alignment', action='store_true',
                       help='if set, uses attention feedback to compute and print alignment to source tokens')
    group.add_argument('--online-eval', action='store_true',
                       help='score model at the end of epoch')
    group.add_argument('--save-predictions', action='store_true',
                       help='Save predictions produced with online evaluation')
    group.add_argument('--test-cased-bleu', action='store_true',
                       help='Use cased bleu for online eval')
    group.add_argument('--bpe-codes', default=None, type=str, metavar='CODES',
                       help='file with bpe codes')
    group.add_argument('--buffer-size', default=64, type=int, metavar='N',
                       help='read this many sentences into a buffer before processing them')
    group.add_argument('--fp16', action='store_true', help='use fp16 precision')
    return group


def add_model_args(parser):
    group = parser.add_argument_group('Model configuration')

    # Model definitions can be found under fairseq/models/
    #
    # The model architecture can be specified in several ways.
    # In increasing order of priority:
    # 1) model defaults (lowest priority)
    # 2) --arch argument
    group.add_argument('--arch', '-a', default='fconv', metavar='ARCH', required=True,
                       choices=ARCH_MODEL_REGISTRY.keys(),
                       help='model architecture: {} (default: fconv)'.format(
                           ', '.join(ARCH_MODEL_REGISTRY.keys())),
                       )

    # Criterion definitions can be found under fairseq/criterions/
    group.add_argument('--criterion', default='cross_entropy', metavar='CRIT',
                       choices=CRITERION_REGISTRY.keys(),
                       help='training criterion: {} (default: cross_entropy)'.format(
                           ', '.join(CRITERION_REGISTRY.keys())),
                       )

    return group


def add_perf_args(parser):
    group = parser.add_argument_group('Performance')
    group.add_argument('--fuse-dropout-add', action='store_true',
                       help='Fuse dropout and residual adds.')
    group.add_argument('--fuse-relu-dropout', action='store_true',
                       help='Fuse Relu and Dropout.')
    group.add_argument('--fuse-layer-norm', action='store_true',
                       help='Use APEX\'s FusedLayerNorm instead of torch.nn.LayerNorm')
    return group
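parse_args_and_arch above relies on a two-pass argparse pattern: parse_known_args first to discover which component (architecture, optimizer, LR scheduler) was selected, let that component register its own flags, then parse again. A standalone sketch of that pattern with a hypothetical mini-registry (COMPONENT_ARGS and add_adam_args are illustrative names, not fairseq code):

import argparse

def add_adam_args(parser):
    # flags that only exist once the 'adam' component is selected
    parser.add_argument('--adam-eps', type=float, default=1e-8)

COMPONENT_ARGS = {'adam': add_adam_args}   # hypothetical mini-registry

parser = argparse.ArgumentParser()
parser.add_argument('--optimizer', default='adam', choices=COMPONENT_ARGS.keys())

argv = ['--optimizer', 'adam', '--adam-eps', '1e-6']

# Pass 1: find out which component was chosen, ignoring flags we don't know yet.
args, _ = parser.parse_known_args(argv)

# Let the chosen component add its own flags, then parse for real.
COMPONENT_ARGS[args.optimizer](parser)
args = parser.parse_args(argv)
print(args.adam_eps)  # 1e-06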