OpenDAS / Megatron-LM

Commit 17d897e0
Authored Jan 29, 2021 by Mostofa Patwary

WIP: main_retriver_merge

Parent: b69bc7ef

Showing 8 changed files with 185 additions and 42 deletions (+185 -42)
megatron/arguments.py             +17  -6
megatron/data/ict_dataset.py      +19  -3
megatron/learning_rates.py         +7  -1
megatron/model/language_model.py  +13  -0
megatron/model/transformer.py      +4  -1
megatron/training.py              +18  -4
megatron/utils.py                 +37  -1
pretrain_ict.py                   +70 -26
megatron/arguments.py

@@ -39,7 +39,7 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_validation_args(parser)
     parser = _add_data_args(parser)
     parser = _add_autoresume_args(parser)
-    parser = _add_realm_args(parser)
+    parser = _add_biencoder_args(parser)

     # Custom arguments.
     if extra_args_provider is not None:
@@ -310,6 +310,8 @@ def _add_training_args(parser):
     group.add_argument('--checkpoint-activations', action='store_true',
                        help='Checkpoint activation to allow for training '
                        'with larger models, sequences, and batch sizes.')
+    group.add_argument('--override-checkpoint-version', type=float, default=None,
+                       help='Override checkpoint version')
     group.add_argument('--distribute-checkpointed-activations',
                        action='store_true',
                        help='If set, distribute checkpointed activations '
@@ -567,12 +569,19 @@ def _add_autoresume_args(parser):
     return parser


-def _add_realm_args(parser):
-    group = parser.add_argument_group(title='realm')
+def _add_biencoder_args(parser):
+    group = parser.add_argument_group(title='biencoder')

     # network size
     group.add_argument('--ict-head-size', type=int, default=None,
                        help='Size of block embeddings to be used in ICT and REALM (paper default: 128)')
+    group.add_argument('--projection-dim', type=int, default=0,
+                       help='Size of projection head used in biencoder (paper default: 128)')
+    group.add_argument('--shared-query-context-model', action='store_true',
+                       help='Whether to share the parameters of the query and context models or not')
+    group.add_argument('--pool-type', type=str, default='cls-token',
+                       choices=['avg', 'cls-token', 'max'],
+                       help='different options are: avg | cls-token | max, default=cls-token')

     # checkpointing
     group.add_argument('--ict-load', type=str, default=None,
@@ -589,14 +598,16 @@ def _add_realm_args(parser):
                        help='Whether to use one sentence documents in ICT')

     # training
-    group.add_argument('--report-topk-accuracies', nargs='+', default=[],
+    group.add_argument('--report-topk-accuracies', nargs='+', type=int, default=[],
                        help="Which top-k accuracies to report (e.g. '1 5 20')")
+    group.add_argument('--retriever-score-scaling', action='store_true',
+                       help="Whether to scale retriever scores by inverse square root of hidden size")

     # faiss index
     group.add_argument('--faiss-use-gpu', action='store_true',
                        help='Whether create the FaissMIPSIndex on GPU')
-    group.add_argument('--block-data-path', type=str, default=None,
-                       help='Where to save/load BlockData to/from')
+    #group.add_argument('--block-data-path', type=str, default=None,
+    #                   help='Where to save/load BlockData to/from')

     # indexer
     group.add_argument('--indexer-batch-size', type=int, default=128,
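For reference, a minimal standalone argparse sketch of the new biencoder flags introduced above (plain argparse, not wired into Megatron's parse_args; the sample command-line values are made up):

import argparse

# Standalone sketch: the new biencoder-related flags from this commit, parsed
# with plain argparse so their defaults and types can be checked in isolation.
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='biencoder')
group.add_argument('--projection-dim', type=int, default=0)
group.add_argument('--shared-query-context-model', action='store_true')
group.add_argument('--pool-type', type=str, default='cls-token',
                   choices=['avg', 'cls-token', 'max'])
group.add_argument('--report-topk-accuracies', nargs='+', type=int, default=[])
group.add_argument('--retriever-score-scaling', action='store_true')

args = parser.parse_args('--pool-type max --retriever-score-scaling '
                         '--report-topk-accuracies 1 5 20'.split())
print(args.pool_type, args.retriever_score_scaling, args.report_topk_accuracies)
# max True [1, 5, 20]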
megatron/data/ict_dataset.py

@@ -9,6 +9,16 @@ from megatron import get_args
 from megatron.data.dataset_utils import get_indexed_dataset_
 from megatron.data.realm_dataset_utils import get_block_samples_mapping

+def make_attention_mask(source_block, target_block):
+    """
+    Returns a 2-dimensional (2-D) attention mask
+    :param source_block: 1-D array
+    :param target_block: 1-D array
+    """
+    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
+    mask = mask.astype(np.int64)
+    # (source_length, target_length)
+    return mask
+
 def get_ict_dataset(use_titles=True, query_in_block_prob=1):
     """Get a dataset which uses block samples mappings to get ICT/block indexing data (via get_block())
@@ -93,14 +103,20 @@ class ICTDataset(Dataset):
         block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]

         query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
-        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
+        context_tokens, context_pad_mask = self.concat_and_pad_tokens(block, title)
+
+        query_mask = make_attention_mask(query_tokens, query_tokens)
+        context_mask = make_attention_mask(context_tokens, context_tokens)
+
         block_data = sample_data.as_array()

         sample = {
             'query_tokens': query_tokens,
+            'query_mask': query_mask,
             'query_pad_mask': query_pad_mask,
-            'block_tokens': block_tokens,
-            'block_pad_mask': block_pad_mask,
+            'context_tokens': context_tokens,
+            'context_mask': context_mask,
+            'context_pad_mask': context_pad_mask,
             'block_data': block_data,
         }
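To see what the new make_attention_mask helper produces, here is a small self-contained sketch on toy token ids (the ids and padding layout are made up; 0 is treated as padding because the mask keeps only positions with id >= 1):

import numpy as np

def make_attention_mask(source_block, target_block):
    # Same logic as the helper added in ict_dataset.py: the outer product keeps
    # only (source, target) pairs where both positions hold real tokens (id >= 1).
    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
    return mask.astype(np.int64)  # shape (source_length, target_length)

query = np.array([101, 2023, 102, 0, 0])  # toy ids, 0 = pad
print(make_attention_mask(query, query))
# The last two rows and columns (the pad positions) are all zeros.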
megatron/learning_rates.py

@@ -59,6 +59,12 @@ class AnnealingLR(object):
         """Learning rate decay functions from:
           https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""

+        #print_rank_0(self.warmup_steps)
+        #print_rank_0(self.num_steps)
+        #print_rank_0(self.warmup_steps)
+        #print_rank_0(self.max_lr)
+        #print_rank_0(self.max_lr * float(self.num_steps) / float(self.warmup_steps))
+
         # Use linear warmup for the initial part.
         if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
             return self.max_lr * float(self.num_steps) / \
@@ -97,7 +103,7 @@ class AnnealingLR(object):
         new_lr = self.get_lr()
         for group in self.optimizer.param_groups:
             group['lr'] = new_lr
+        #print_rank_0(new_lr)

     def state_dict(self):
         state_dict = {
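The warmup branch touched by the first hunk scales the peak learning rate by the fraction of warmup completed; here is a standalone restatement of that arithmetic (the denominator, cut off by the diff context, is assumed to be warmup_steps as in the standard Megatron AnnealingLR):

def warmup_lr(max_lr, num_steps, warmup_steps):
    # Linear warmup: ramp from 0 to max_lr over warmup_steps iterations.
    assert warmup_steps > 0 and num_steps <= warmup_steps
    return max_lr * float(num_steps) / float(warmup_steps)

print(warmup_lr(1e-4, 500, 2000))  # 2.5e-05, a quarter of the way through warmup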
megatron/model/language_model.py

@@ -374,6 +374,19 @@ class TransformerLanguageModelBase(MegatronModule):
         # Transformer.
         if self._transformer_key in state_dict:
             state_dict_ = state_dict[self._transformer_key]
+        # for compatiability with t5 architecture
+        # this is temporary unless t5_main is merged
+        elif 'encoder' in state_dict:
+            state_dict_ = state_dict['encoder']
+            # for forward compatibility for t5 architecture
+            state_dict_attention = {}
+            for key in state_dict_.keys():
+                if '.self_attention.' in key:
+                    state_dict_attention[key.replace(".self_attention.",
+                        ".attention.")] = state_dict_[key]
+                else:
+                    state_dict_attention[key] = state_dict_[key]
+            state_dict_ = state_dict_attention
         else:
             # for backward compatibility.
             state_dict_ = {}
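The new elif branch simply remaps T5-style checkpoint keys onto the older attention naming; a minimal sketch of that renaming loop on a toy state dict (the key names here are illustrative, not taken from a real checkpoint):

state_dict_ = {
    'layers.0.self_attention.query_key_value.weight': 'w0',  # illustrative keys
    'layers.0.mlp.dense_h_to_4h.weight': 'w1',
}

# Rename '.self_attention.' to '.attention.' and keep everything else as-is,
# mirroring the compatibility branch added to load_state_dict above.
state_dict_attention = {}
for key in state_dict_.keys():
    if '.self_attention.' in key:
        state_dict_attention[key.replace(".self_attention.", ".attention.")] = state_dict_[key]
    else:
        state_dict_attention[key] = state_dict_[key]

print(state_dict_attention)
# {'layers.0.attention.query_key_value.weight': 'w0',
#  'layers.0.mlp.dense_h_to_4h.weight': 'w1'}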
megatron/model/transformer.py

@@ -214,6 +214,9 @@ class ParallelSelfAttention(MegatronModule):
         mixed_x_layer, _ = self.query_key_value(hidden_states)

         checkpoint_version = get_checkpoint_version()
+        if get_args().override_checkpoint_version is not None:
+            checkpoint_version = get_args().override_checkpoint_version
+
         if checkpoint_version is not None:
             if checkpoint_version == 0:
                 # [s, b, (3 * np * hn)] --> [s, b, (np * 3 * hn)]
@@ -472,7 +475,7 @@ class ParallelTransformerLayer(MegatronModule):
         # MLP.
         mlp_output, mlp_bias = self.mlp(layernorm_output)

         # Second residual connection.
         if self.apply_residual_connection_post_layernorm:
             residual = layernorm_output
megatron/training.py

@@ -48,7 +48,7 @@ from megatron.model import get_params_for_weight_decay_optimization
 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.data.data_loaders import build_pretraining_data_loader
-from megatron.utils import report_memory
+from megatron.utils import report_memory, params_grad_norm, params_global_norm, print_model, print_grads


 def print_datetime(string):
@@ -663,11 +663,25 @@ def train_step(forward_step_func, data_iterator,
             optimizer.clip_master_grads(args.clip_grad)
     timers('backward-clip-grad').stop()

+    #print_rank_0("after backward")
+    #print_grads(model)
+    print_model(model)
+    print_rank_0(params_global_norm(model))
+    print_rank_0(params_grad_norm(model))
+
     # Update parameters.
     timers('optimizer').start()
     optimizer.step()
     timers('optimizer').stop()

+    #print_rank_0("after optimizer")
+    #print_model(model)
+    print_rank_0(params_global_norm(model))
+    #print_rank_0(params_grad_norm(model))
+    #sys.exit()
+
     # Update learning rate.
     skipped_iter = 0
     if not (args.fp16 and optimizer.overflow):
@@ -905,9 +919,9 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         # Exiting based on iterations
         if args.exit_interval and iteration % args.exit_interval == 0:
-            if not saved_checkpoint:
-                save_checkpoint_and_time(iteration, model, optimizer,
-                                         lr_scheduler)
+            #if not saved_checkpoint:
+            #    save_checkpoint_and_time(iteration, model, optimizer,
+            #                             lr_scheduler)
             torch.distributed.barrier()
             print_datetime('exiting program at iteration {}'.format(iteration))
             sys.exit()
megatron/utils.py

@@ -150,4 +150,40 @@ def get_ltor_masks_and_position_ids(data,
     return attention_mask, loss_mask, position_ids

+
+def params_grad_norm(model):
+    print_rank_0("params_grad_norm")
+    norm2 = torch.cuda.FloatTensor([0.0])
+    for param in model.parameters():
+        if param.grad is None:
+            continue
+        norm = torch.norm(param.grad.data.float(), 2)
+        norm2 += norm * norm
+    torch.distributed.all_reduce(norm2)
+    norm = norm2 ** 0.5
+    return norm.item()
+
+
+def params_global_norm(model):
+    print_rank_0("params_global_norm")
+    norm2 = torch.cuda.FloatTensor([0.0])
+    for param in model.parameters():
+        norm = torch.norm(param.data.float(), 2)
+        norm2 += norm * norm
+    torch.distributed.all_reduce(norm2)
+    norm = norm2 ** 0.5
+    return norm.item()
+
+
+def print_model(model):
+    print_rank_0("print-model")
+    for name, param in model.named_parameters():
+        if param.requires_grad:
+            #print("{} {}".format(name, param.data), flush=True)
+            print_rank_0("{} {}".format(name, param.data))
+    return
+
+
+def print_grads(model):
+    print_rank_0("print-grads")
+    for name, param in model.named_parameters():
+        if param.grad is None:
+            continue
+        print_rank_0("{} {}".format(name, param.grad))
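The two norm helpers compute one global L2 norm by summing squared per-tensor norms (and, in the diff, all-reducing the sum across data-parallel ranks); here is a single-process CPU sketch of the same arithmetic with the all_reduce omitted:

import torch

def global_l2_norm(tensors):
    # Sum of squared per-tensor L2 norms, then square root -- the same arithmetic
    # as params_global_norm/params_grad_norm, minus torch.distributed.all_reduce.
    norm2 = torch.zeros(1)
    for t in tensors:
        n = torch.norm(t.float(), 2)
        norm2 += n * n
    return (norm2 ** 0.5).item()

params = [torch.ones(4), torch.full((2, 2), 2.0)]
print(global_l2_norm(params))                                       # ~4.4721
print(torch.norm(torch.cat([p.flatten() for p in params])).item())  # same value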
pretrain_ict.py

@@ -14,6 +14,8 @@
 # limitations under the License.

 """Pretrain BERT for Inverse Cloze Task"""
+import sys
+import math

 import torch
 import torch.distributed as dist
@@ -26,14 +28,16 @@ from megatron import mpu
 from megatron.data.dataset_utils import build_train_valid_test_datasets
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
-from megatron.model.realm_model import general_ict_model_provider
-from megatron.data.realm_dataset_utils import get_ict_batch
+from megatron.model.biencoder_model import biencoder_model_provider
+from megatron.data.biencoder_dataset_utils import get_ict_batch


 def pretrain_ict_model_provider():
     args = get_args()
-    return general_ict_model_provider(False, False)
+    model = biencoder_model_provider(only_context_model=False,
+                                     only_query_model=False,
+                                     shared_query_context_model=args.shared_query_context_model)
+    return model


 def get_group_world_size_rank():
@@ -72,7 +76,6 @@ class AllgatherFromDataParallelRegion(torch.autograd.Function):
         output = output_list[rank].contiguous()
         return output

-
 def forward_step(data_iterator, model, input_tensor):
     """Forward step."""
     args = get_args()
@@ -80,37 +83,76 @@ def forward_step(data_iterator, model, input_tensor):
     # Get the batch.
     timers('batch-generator').start()
-    query_tokens, query_pad_mask, \
-        block_tokens, block_pad_mask, block_indices = get_ict_batch(data_iterator)
+    query_tokens, query_mask, \
+        context_tokens, context_mask, context_indices = get_ict_batch(data_iterator)
     timers('batch-generator').stop()

+    # Query and Context Types
+    query_types = torch.cuda.LongTensor(*query_tokens.shape).fill_(0)
+    context_types = torch.cuda.LongTensor(*context_tokens.shape).fill_(0)
+
+    #print_rank_0(query_tokens)
+    #print_rank_0(context_tokens)
+    #print_rank_0(torch.sum(query_types))
+    #print_rank_0(torch.sum(query_mask))
+    #print_rank_0(torch.sum(context_types))
+    #print_rank_0(torch.sum(context_mask))
+    #print_rank_0(params_global_norm(model))
+    #print_rank_0(params_grad_norm(model))
+
     # Forward model.
-    query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)
+    query_logits, context_logits = model(query_tokens, query_mask, query_types,
+                                         context_tokens, context_mask, context_types)
+    #print_rank_0(query_logits)
+    #print_rank_0(context_logits)

     micro_batch_size = query_logits.shape[0]
     # recall we assert that tensor_model_parallel_size == 1
-    global_batch_size = dist.get_world_size() * micro_batch_size
-    all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
-    all_block_logits = AllgatherFromDataParallelRegion.apply(block_logits)
+    #global_batch_size = dist.get_world_size() * micro_batch_size
+    #all_query_logits = AllgatherFromDataParallelRegion.apply(query_logits)
+    #all_context_logits = AllgatherFromDataParallelRegion.apply(context_logits)
+
+    global_batch_size = micro_batch_size
+    all_query_logits = query_logits
+    all_context_logits = context_logits

-    # scores are inner products between query and block embeddings
-    retrieval_scores = all_query_logits.float().matmul(torch.transpose(all_block_logits, 0, 1).float())
-    softmaxed = F.softmax(retrieval_scores, dim=1)
-    sorted_vals, sorted_indices = torch.topk(softmaxed, k=softmaxed.shape[1], sorted=True)
+    # scores are inner products between query and context embeddings
+    retrieval_scores = torch.matmul(all_query_logits,
+                                    torch.transpose(all_context_logits, 0, 1))
+
+    # scaling the retriever scores
+    if args.retriever_score_scaling:
+        retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)
+
+    softmax_scores = F.log_softmax(retrieval_scores, dim=1)
+    sorted_vals, sorted_indices = torch.topk(softmax_scores, k=softmax_scores.shape[1], sorted=True)

     def topk_accuracy(k):
-        return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) for i in range(global_batch_size)]) / global_batch_size])
+        return torch.cuda.FloatTensor([sum([int(i in sorted_indices[i, :k]) \
+            for i in range(global_batch_size)]) / global_batch_size])

     topk_accs = [topk_accuracy(int(k)) for k in args.report_topk_accuracies]

-    retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
-    retrieval_loss = retrieval_loss.float()
-    averaged_losses = average_losses_across_data_parallel_group([retrieval_loss, *topk_accs])
+    labels = torch.arange(global_batch_size).long().cuda()
+    loss = F.nll_loss(softmax_scores, labels, reduction='mean')
+    reduced_losses = average_losses_across_data_parallel_group([loss, *topk_accs])
+
+    # Scale the retrieval loss
+    loss = loss * mpu.get_data_parallel_world_size()
+
+    #retrieval_loss = torch.nn.CrossEntropyLoss()(retrieval_scores, torch.arange(global_batch_size).long().cuda())
+    #retrieval_loss = retrieval_loss.float()
+    #averaged_losses = average_losses_across_data_parallel_group([retrieval_loss, *topk_accs])

     # create stats_dict with retrieval loss and all specified top-k accuracies
-    topk_acc_dict = {'top{}_acc'.format(k): v for k, v in zip(args.report_topk_accuracies, averaged_losses[1:])}
-    stats_dict = dict(retrieval_loss=averaged_losses[0], **topk_acc_dict)
-    return retrieval_loss, stats_dict
+    topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
+        zip(args.report_topk_accuracies, reduced_losses[1:])}
+    stats_dict = dict(loss=reduced_losses[0], **topk_acc_dict)
+
+    #print_rank_0(loss)
+    #print_rank_0(stats_dict)
+    #sys.exit()
+    return loss, stats_dict


 def train_valid_test_datasets_provider(train_val_test_num_samples):
@@ -136,5 +178,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
 if __name__ == "__main__":

     pretrain(train_valid_test_datasets_provider, pretrain_ict_model_provider, forward_step,
              args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})
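The reworked forward_step boils down to an in-batch retrieval loss: query and context embeddings are compared by inner product, optionally scaled by sqrt(hidden_size), and the correct context for query i is the one at index i in the batch. A small self-contained sketch of that computation with random toy embeddings (the shapes and values are made up):

import math
import torch
import torch.nn.functional as F

batch, hidden = 4, 8                      # toy sizes
query_logits = torch.randn(batch, hidden)
context_logits = torch.randn(batch, hidden)

# Inner-product scores; every other context in the batch acts as a negative.
scores = torch.matmul(query_logits, context_logits.t())
scores = scores / math.sqrt(hidden)       # effect of --retriever-score-scaling
log_probs = F.log_softmax(scores, dim=1)

labels = torch.arange(batch)              # positive context i for query i
loss = F.nll_loss(log_probs, labels, reduction='mean')
print(loss.item())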