OpenDAS / Megatron-LM

Commit b9fcb7b4, authored Apr 28, 2021 by mpatwary
parent 957d1c9a

    adding dpr code

Showing 3 changed files, with 751 additions and 0 deletions:

    tasks/orqa/supervised/data.py        +301  -0
    tasks/orqa/supervised/eval_utils.py  +211  -0
    tasks/orqa/supervised/finetune.py    +239  -0
tasks/orqa/supervised/data.py  (new file, 0 → 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ORQA dataset."""
import json
import random
from abc import ABC
from abc import abstractmethod

import numpy as np
from torch.utils.data import Dataset

from megatron import print_rank_0, get_args
from megatron.data.biencoder_dataset_utils import make_attention_mask
from megatron.data.biencoder_dataset_utils import make_history_mask
def build_token_types_from_context_list(ctx_list, tokenizer, max_seq_length):
    ctx_id_list, ctx_types_list = [], []
    for context in ctx_list:
        title_ids = tokenizer.tokenize(context['title'])
        ctx_ids = tokenizer.tokenize(context['text'])
        ctx_ids = title_ids + [tokenizer.sep_id] + ctx_ids

        ctx_ids, ctx_types, _ = build_tokens_types_paddings_from_ids(
            ctx_ids, max_seq_length, tokenizer.cls, tokenizer.sep,
            tokenizer.pad)
        ctx_id_list.append(ctx_ids)
        ctx_types_list.append(ctx_types)

    return ctx_id_list, ctx_types_list
def build_tokens_types_paddings_from_text(query, context,
                                          tokenizer, max_seq_length):
    """Build token types and paddings, trim if needed, and pad if needed."""

    query_ids = tokenizer.tokenize(query)
    query_ids, query_types, query_pad_mask = \
        build_tokens_types_paddings_from_ids(query_ids, max_seq_length,
            tokenizer.cls, tokenizer.sep, tokenizer.pad)

    # Prepend the title of the context before its text.
    # NOTE: a None context would fail in the call below; callers always
    # supply a positive context here.
    extended_ctx_ids = None
    if context is not None:
        title_ids = tokenizer.tokenize(context['title'])
        ctx_ids = tokenizer.tokenize(context['text'])
        extended_ctx_ids = title_ids + [tokenizer.sep] + ctx_ids

    ctx_ids, ctx_types, ctx_pad_mask = \
        build_tokens_types_paddings_from_ids(extended_ctx_ids,
            max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad)

    return query_ids, query_types, query_pad_mask, \
           ctx_ids, ctx_types, ctx_pad_mask
# Similar to the code in tasks/data_utils.py, with some changes.
def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
                                         cls_id, sep_id, pad_id):
    """Build token types and paddings, trim if needed, and pad if needed."""
    enc_ids = []
    tokentypes_enc = []

    # [CLS].
    enc_ids.append(cls_id)
    tokentypes_enc.append(0)

    # A.
    len_src = len(text_ids)
    enc_ids.extend(text_ids)
    tokentypes_enc.extend([0] * len_src)

    # Cap the size so that [SEP] still fits.
    if len(enc_ids) > max_seq_length - 1:
        enc_ids = enc_ids[0: max_seq_length - 1]
        tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]

    # [SEP].
    enc_ids.append(sep_id)
    tokentypes_enc.append(0)

    num_tokens_enc = len(enc_ids)

    # Padding.
    padding_length = max_seq_length - len(enc_ids)
    if padding_length > 0:
        enc_ids.extend([pad_id] * padding_length)
        tokentypes_enc.extend([pad_id] * padding_length)

    pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length)
    pad_mask = np.array(pad_mask, dtype=np.int64)

    return enc_ids, tokentypes_enc, pad_mask
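
As a quick illustration of the trim-and-pad behavior above, here is a hypothetical toy run (the id values 101/102/0 are BERT-style [CLS]/[SEP]/[PAD] ids chosen for the example, not taken from the repo):

ids, types, mask = build_tokens_types_paddings_from_ids(
    text_ids=[7, 8, 9], max_seq_length=8, cls_id=101, sep_id=102, pad_id=0)
# ids   == [101, 7, 8, 9, 102, 0, 0, 0]
# types == [0, 0, 0, 0, 0, 0, 0, 0]   (token types are padded with pad_id)
# mask  == [1, 1, 1, 1, 1, 0, 0, 0]   (1 = real token, 0 = padding)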
def build_sample(query_ids, query_types, query_pad_mask,
                 ctx_ids, ctx_types, ctx_pad_mask, answers,
                 neg_ctx_id_list=None, neg_ctx_types_list=None,
                 include_neg=False):
    """Convert to numpy and return a sample consumed by the batch producer."""

    query_ids = np.array(query_ids, dtype=np.int64)
    query_types = np.array(query_types, dtype=np.int64)
    query_mask = make_attention_mask(query_ids, query_ids)

    ctx_ids = np.array(ctx_ids, dtype=np.int64)
    ctx_types = np.array(ctx_types, dtype=np.int64)
    ctx_mask = make_attention_mask(ctx_ids, ctx_ids)

    sample = ({
        'query': query_ids,
        'query_mask': query_mask,
        'query_types': query_types,
        'query_pad_mask': query_pad_mask,
        'context': ctx_ids,
        'context_mask': ctx_mask,
        'context_types': ctx_types,
        'context_pad_mask': ctx_pad_mask,
        'reference': answers
    })

    if include_neg:
        neg_ctx_ids = np.array(neg_ctx_id_list, dtype=np.int64)
        neg_ctx_id_types = np.array(neg_ctx_types_list, dtype=np.int64)
        neg_ctx_mask = np.array([make_attention_mask(ids, ids)
                                 for ids in neg_ctx_ids], dtype=np.int64)

        sample['neg_context'] = neg_ctx_ids
        sample['neg_context_types'] = neg_ctx_id_types
        sample['neg_context_mask'] = neg_ctx_mask

    return sample
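
make_attention_mask is imported from megatron.data.biencoder_dataset_utils; a minimal sketch of what it is assumed to compute, for readers without the repo at hand (this is our paraphrase, not the repo's code):

import numpy as np

def make_attention_mask_sketch(source_block, target_block):
    # Assumed behavior: a (len(source), len(target)) 0/1 mask that is 1
    # only where both positions hold real (non-pad, id >= 1) tokens.
    mask = (target_block[None, :] >= 1) * (source_block[:, None] >= 1)
    return mask.astype(np.int64)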
class OpenRetrievalAbstractDataset(ABC, Dataset):
    """Open Retrieval base dataset class."""

    def __init__(self, task_name, dataset_name, datapaths, tokenizer,
                 max_seq_length, evaluate=False):
        # Store inputs.
        args = get_args()
        self.evaluate = evaluate
        self.val_av_rank_hard_neg = args.val_av_rank_hard_neg
        self.val_av_rank_other_neg = args.val_av_rank_other_neg
        self.train_with_neg = args.train_with_neg
        self.train_hard_neg = args.train_hard_neg

        self.task_name = task_name
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        print_rank_0(' > building {} dataset for {}:'.format(self.task_name,
                                                             self.dataset_name))
        # Process the files.
        string = '  > paths:'
        for path in datapaths:
            string += ' ' + path
        print_rank_0(string)

        self.samples = []
        for datapath in datapaths:
            self.samples.extend(self.process_samples_from_single_path(datapath))

        if args.sample_rate < 1:
            # Subsample the dataset.
            k = int(len(self.samples) * args.sample_rate)
            self.samples = random.sample(self.samples, k)

        print_rank_0('  >> total number of samples: {}'.format(
            len(self.samples)))
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        raw_sample = self.samples[idx]

        query_ids, query_types, query_pad_mask, ctx_ids, ctx_types, \
            ctx_pad_mask = build_tokens_types_paddings_from_text(
                raw_sample['question'], raw_sample['pos_context'],
                self.tokenizer, self.max_seq_length)

        if self.evaluate:
            neg_ctx_list = \
                raw_sample['negative_context'][:self.val_av_rank_other_neg] + \
                raw_sample['hard_negative_context'][:self.val_av_rank_hard_neg]
            neg_ctx_id_list, neg_ctx_types_list = \
                build_token_types_from_context_list(neg_ctx_list,
                    self.tokenizer, self.max_seq_length)
        elif self.train_with_neg:
            hard_negative_ctx = raw_sample['hard_negative_context']
            negative_ctx = raw_sample['negative_context']
            random.shuffle(hard_negative_ctx)
            random.shuffle(negative_ctx)

            neg_ctx_list = hard_negative_ctx[:self.train_hard_neg]
            # In the Google NQ data used by the DPR paper, somewhat more than
            # 50 training examples are missing hard negatives. In those
            # cases, fill in with simple negatives instead.
            if len(neg_ctx_list) < self.train_hard_neg:
                neg_ctx_list += negative_ctx[:self.train_hard_neg -
                                             len(neg_ctx_list)]

            neg_ctx_id_list, neg_ctx_types_list = \
                build_token_types_from_context_list(neg_ctx_list,
                    self.tokenizer, self.max_seq_length)
        else:
            neg_ctx_id_list = None
            neg_ctx_types_list = None

        sample = build_sample(query_ids, query_types, query_pad_mask,
                              ctx_ids, ctx_types, ctx_pad_mask,
                              raw_sample['answers'],
                              neg_ctx_id_list, neg_ctx_types_list,
                              include_neg=self.evaluate or self.train_with_neg)

        return sample
    @staticmethod
    @abstractmethod
    def process_samples_from_single_path(filename):
        """Abstract method that takes a filename and returns a list of
        dataset samples, each sample being a dict with the keys
        'question', 'pos_context', 'hard_negative_context',
        'negative_context', and 'answers'.
        """
        pass
def normalize_question(question):
    # Strip a trailing question mark, if present.
    if question and question[-1] == '?':
        question = question[:-1]
    return question
class NQSupervisedDataset(OpenRetrievalAbstractDataset):

    def __init__(self, name, datapaths, tokenizer, max_seq_length,
                 evaluate=False):
        super().__init__('natural_questions_ret',
                         name,
                         datapaths,
                         tokenizer,
                         max_seq_length,
                         evaluate=evaluate)
    @staticmethod
    def process_samples_from_single_path(filename):
        """Implement the abstract method."""
        print_rank_0(' > Processing {} ...'.format(filename))
        samples = []
        total = 0

        with open(filename, 'r', encoding="utf-8") as f:
            data = json.load(f)
            for row in data:
                question = normalize_question(row['question'])
                pos_context = row['positive_ctxs'][0]

                # Hard negative contexts (may be an empty list).
                hard_neg_context = row['hard_negative_ctxs']

                # Negative contexts (may be an empty list).
                neg_context = row['negative_ctxs']

                answers = row['answers']
                sample = {'question': question,
                          'pos_context': pos_context,
                          'hard_negative_context': hard_neg_context,
                          'negative_context': neg_context,
                          'answers': answers}
                total += 1
                samples.append(sample)

                if total % 5000 == 0:
                    print_rank_0('  > processed {} so far ...'.format(total))

        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        return samples
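
process_samples_from_single_path reads DPR-style retriever training files (one json list per file, as released with the DPR paper's Natural Questions data). A representative record, with hypothetical values trimmed for illustration:

# One element of the json list loaded above (illustrative values).
record = {
    'question': 'who sings does he love me with reba?',
    'answers': ['Linda Davis'],
    'positive_ctxs': [{'title': 'Does He Love You',
                       'text': '... recorded by Reba McEntire and Linda Davis ...'}],
    'negative_ctxs': [{'title': '...', 'text': '...'}],
    'hard_negative_ctxs': [{'title': '...', 'text': '...'}],
}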
tasks/orqa/supervised/eval_utils.py  (new file, 0 → 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation utilities."""
from collections import OrderedDict
import math
import time

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from megatron import get_args, print_rank_0
from megatron import mpu
from megatron.utils import average_losses_across_data_parallel_group
from tasks.finetune_utils import build_data_loader
def task_collate_fn(batch_data):
    # Generate a batch by stacking each field across samples.
    tensorized = OrderedDict()
    for d in batch_data:
        for k, v in d.items():
            tensorized.setdefault(k, []).append(v)

    tensorized['query'] = torch.LongTensor(tensorized['query'])
    tensorized['query_mask'] = torch.LongTensor(tensorized['query_mask'])
    tensorized['query_types'] = torch.LongTensor(tensorized['query_types'])
    tensorized['query_pad_mask'] = \
        torch.LongTensor(tensorized['query_pad_mask'])
    tensorized['context'] = torch.LongTensor(tensorized['context'])
    tensorized['context_mask'] = \
        torch.LongTensor(tensorized['context_mask'])
    tensorized['context_types'] = \
        torch.LongTensor(tensorized['context_types'])
    tensorized['context_pad_mask'] = \
        torch.LongTensor(tensorized['context_pad_mask'])

    # Negative contexts are concatenated along the batch dimension so that
    # each query is scored against every positive and negative context.
    # 'reference' stays a plain list of answer strings.
    if 'neg_context' in tensorized:
        tensorized['neg_context'] = \
            torch.LongTensor(np.concatenate(tensorized['neg_context']))
        tensorized['neg_context_mask'] = \
            torch.LongTensor(np.concatenate(tensorized['neg_context_mask']))
        tensorized['neg_context_types'] = \
            torch.LongTensor(np.concatenate(tensorized['neg_context_types']))

    return tensorized
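
A minimal sketch of how task_collate_fn plugs into a PyTorch DataLoader; in this repo the equivalent wiring is assumed to happen inside tasks.finetune_utils.build_data_loader:

from torch.utils.data import DataLoader

# Hypothetical standalone wiring; `nq_dataset` would be an
# NQSupervisedDataset instance from tasks/orqa/supervised/data.py.
loader = DataLoader(nq_dataset, batch_size=32, shuffle=False,
                    collate_fn=task_collate_fn)
# Each yielded batch maps field names to LongTensors, e.g.
# batch['query'].shape == (32, max_seq_length).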
def process_batch(batch):
    """Process batch and produce inputs for the model."""
    # (mask < 0.5) flips the 0/1 pad masks into boolean masks that are
    # True at the positions to be masked out.
    query_tokens = batch['query'].long().cuda()
    query_mask = (batch['query_mask'] < 0.5).cuda()
    query_types = batch['query_types'].long().cuda()
    query_pad_mask = batch['query_pad_mask'].long().cuda()

    context_tokens = batch['context'].long().cuda()
    context_mask = (batch['context_mask'] < 0.5).cuda()
    context_types = batch['context_types'].long().cuda()
    context_pad_mask = batch['context_pad_mask'].long().cuda()

    if 'neg_context' in batch:
        neg_context_tokens = batch['neg_context'].long().cuda()
        neg_context_mask = (batch['neg_context_mask'] < 0.5).cuda()
        neg_context_types = batch['neg_context_types'].long().cuda()
    else:
        neg_context_tokens = None
        neg_context_mask = None
        neg_context_types = None

    reference = batch['reference']

    return query_tokens, query_mask, query_types, query_pad_mask, \
           context_tokens, context_mask, context_types, context_pad_mask, \
           neg_context_tokens, neg_context_mask, neg_context_types, reference
def accuracy_func_provider(single_dataset_provider, rank0sampler=False):
    """Provide a function that calculates accuracies."""
    args = get_args()

    print_rank_0("accuracy_func_provider is CALLED")

    # Build the dataloader.
    datapath = args.valid_data
    dataset = single_dataset_provider(datapath)

    drop_last = False
    if mpu.get_data_parallel_world_size() > 1 and not rank0sampler:
        drop_last = True

    print_rank_0(datapath)
    print_rank_0(rank0sampler)

    dataloader = build_data_loader(dataset,
                                   args.eval_micro_batch_size,
                                   num_workers=args.num_workers,
                                   drop_last=drop_last,
                                   task_collate_fn=task_collate_fn)
    dataloaders = (dataset.dataset_name, dataloader)

    def metrics_func(model, epoch, output_predictions=False):
        print_rank_0('calculating metrics by accuracy func in ORQA...')

        if output_predictions:
            assert rank0sampler
        name, dataloader = dataloaders
        if args.task == "RET-FINETUNE-NQ":
            start_time = time.time()
            stats_dict, total = retrieval_loss(model, dataloader)
            format_string = ""
            for k, v in stats_dict.items():
                format_string += "|{} = {:.2f}".format(k, v / total)
            print_rank_0("epoch:{}{}".format(epoch, format_string))
            print_rank_0("time taken to calculate metrics: {:.3f}".format(
                time.time() - start_time))
        else:
            raise AssertionError("{} task is not supported".format(args.task))

    return metrics_func
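
The returned metrics_func serves as an end-of-epoch callback; the call pattern assumed here (the exact invocation lives in tasks.finetune_utils) is roughly:

# Assumed call pattern (illustrative):
metrics_func = accuracy_func_provider(single_dataset_provider)
metrics_func(model, epoch)  # prints the epoch's average rank and top-k accuracies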
def retrieval_loss(model, dataloader):
    args = get_args()
    total = 0
    topk_stats_dict = {'top{}_acc'.format(k): 0 for k in
                       args.retriever_report_topk_accuracies}
    stats_dict = dict(rank=0, **topk_stats_dict)

    assert len(model) == 1
    unwrapped_model = model[0]
    unwrapped_model.eval()

    with torch.no_grad():
        # For all the batches in the dataset.
        for batch in dataloader:
            # Run the model forward.
            query_tokens, query_mask, query_types, _, \
                context_tokens, context_mask, context_types, _, \
                neg_context_tokens, neg_context_mask, neg_context_types, \
                reference = process_batch(batch)

            query_logits, context_logits = unwrapped_model(
                query_tokens, query_mask, query_types,
                torch.cat([context_tokens, neg_context_tokens]),
                torch.cat([context_mask, neg_context_mask]),
                torch.cat([context_types, neg_context_types]))

            retrieval_scores = torch.matmul(query_logits,
                torch.transpose(context_logits, 0, 1))

            if args.retriever_score_scaling:
                retrieval_scores = retrieval_scores / \
                    math.sqrt(args.hidden_size)

            local_batch_size = query_logits.shape[0]
            # The gold context for query i sits at index i.
            labels = torch.arange(local_batch_size).long().cuda()

            softmax_scores = F.softmax(retrieval_scores, dim=1)
            sorted_vals, sorted_indices = torch.topk(softmax_scores,
                k=softmax_scores.shape[1], sorted=True)

            def topk_accuracy(k):
                return torch.cuda.FloatTensor(
                    [sum([int(labels[i] in sorted_indices[i, :k])
                          for i in range(local_batch_size)])])

            def get_rank():
                return torch.cuda.FloatTensor(
                    [sum([torch.nonzero(labels[i] == sorted_indices[i])[0][0]
                          for i in range(local_batch_size)])])

            topk_accs = [topk_accuracy(k)
                         for k in args.retriever_report_topk_accuracies]
            rank = get_rank()
            losses = average_losses_across_data_parallel_group(
                [rank, *topk_accs])

            # Accumulate stats_dict with the average rank and all
            # specified top-k accuracies.
            topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in
                             zip(args.retriever_report_topk_accuracies,
                                 losses[1:])}
            temp_stats_dict = dict(rank=losses[0], **topk_acc_dict)
            for k in stats_dict.keys():
                stats_dict[k] += temp_stats_dict[k]
            total += local_batch_size

    unwrapped_model.train()

    return stats_dict, total
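
The top-k bookkeeping in retrieval_loss can be sanity-checked on a toy score matrix: the gold context for query i sits at column i, and top-k accuracy counts how often column i appears among the k highest-scoring columns of row i (values below are made up):

import torch

scores = torch.tensor([[0.90, 0.05, 0.05],   # query 0: gold context ranked 1st
                       [0.60, 0.10, 0.30]])  # query 1: gold context ranked 3rd
labels = torch.arange(scores.shape[0])
_, sorted_indices = torch.topk(scores, k=scores.shape[1], sorted=True)
top1 = sum(int(labels[i] in sorted_indices[i, :1]) for i in range(2))  # -> 1
top3 = sum(int(labels[i] in sorted_indices[i, :3]) for i in range(2))  # -> 2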
tasks/orqa/supervised/finetune.py  (new file, 0 → 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ORQA finetuning/evaluation."""
from functools import partial
import math

import torch
import torch.nn.functional as F

from megatron import get_args
from megatron import get_timers
from megatron import get_tokenizer
from megatron import mpu
from megatron import print_rank_0
from megatron.utils import average_losses_across_data_parallel_group
from megatron.model.biencoder_model import biencoder_model_provider
from pretrain_ict import get_group_world_size_rank
from tasks.finetune_utils import finetune
from tasks.orqa.supervised.eval_utils import accuracy_func_provider
from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn
def orqa(Dataset):
    def cross_entropy_forward_step(batch, model):
        """Simple forward step with cross-entropy loss."""
        args = get_args()
        timers = get_timers()
        tokenizer = get_tokenizer()

        # Get the batch.
        timers('batch generator').start()
        try:
            batch_ = next(batch)
        except BaseException:
            batch_ = batch

        query_tokens, query_mask, query_types, query_pad_mask, \
            context_tokens, context_mask, context_types, context_pad_mask, \
            neg_context_tokens, neg_context_mask, neg_context_types, \
            reference = process_batch(batch_)

        timers('batch generator').stop()
        local_batch_size = query_tokens.shape[0]

        # Text representation of the queries and contexts (not used below;
        # handy for debugging).
        query_list, context_list = [], []
        for i in range(local_batch_size):
            query_list.append(tokenizer.decode(query_tokens[i].tolist()))
            context_list.append(tokenizer.decode(context_tokens[i].tolist()))

        if neg_context_tokens is not None:
            context_tokens = torch.cat([context_tokens, neg_context_tokens])
            context_mask = torch.cat([context_mask, neg_context_mask])
            context_types = torch.cat([context_types, neg_context_types])

        # Forward model.
        output_tensor = model(query_tokens, query_mask, query_types,
                              context_tokens, context_mask, context_types)
        return output_tensor, partial(cross_entropy_loss_func_,
                                      query_tokens, context_tokens)
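
    # cross_entropy_forward_step follows the Megatron fine-tuning convention
    # of returning the model output together with a loss closure; roughly,
    # the training loop is assumed to consume the pair like this
    # (illustrative, not the repo's code):
    #
    #     output_tensor, loss_func = cross_entropy_forward_step(batch, model)
    #     loss, stats = loss_func(output_tensor)
    #     # stats == {'lm loss': ..., 'correct_prediction_count': ...}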
    def cross_entropy_loss_func_(query_tokens, context_tokens, output_tensor):
        args = get_args()

        local_batch_size = query_tokens.shape[0]
        group, rank, world_size = get_group_world_size_rank()
        # Recall that we assert model_parallel_size == 1.
        global_batch_size = world_size * local_batch_size

        query_logits, context_logits = output_tensor

        if world_size > 1:
            # Gather the context logits from all ranks.
            input_ = torch.empty_like(context_logits).copy_(
                context_logits).detach_()
            tensor_list = [torch.empty_like(input_)
                           for _ in range(world_size)]
            tensor_list[rank].copy_(input_)
            torch.distributed.all_gather(tensor_list, input_, group=group)

            # Check that the all-gather placed ranks in order.
            assert tensor_list[rank].sum().item() == \
                context_logits.sum().item()

            # Swap the local shard back in to preserve its gradient.
            tensor_list[rank] = context_logits
            all_context_logits = torch.cat(tensor_list, dim=0).contiguous()

            # Gather the query logits from all ranks.
            input_ = torch.empty_like(query_logits).copy_(
                query_logits).detach_()
            tensor_list = [torch.empty_like(input_)
                           for _ in range(world_size)]
            tensor_list[rank].copy_(input_)
            torch.distributed.all_gather(tensor_list, input_, group=group)

            # Check that the all-gather placed ranks in order.
            assert tensor_list[rank].sum().item() == \
                query_logits.sum().item()

            # Swap the local shard back in to preserve its gradient.
            tensor_list[rank] = query_logits
            all_query_logits = torch.cat(tensor_list, dim=0).contiguous()
        else:
            all_query_logits = query_logits
            all_context_logits = context_logits

        retrieval_scores = torch.matmul(all_query_logits,
            torch.transpose(all_context_logits, 0, 1))
        # Scale the retrieval scores.
        if args.retriever_score_scaling:
            retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)

        if args.train_with_neg:
            # If the world size is 3, the local batch size is 4, and the
            # local context size is 8, then we want
            # labels = [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19].
            labels = []
            local_context_size = context_tokens.shape[0]
            for i in range(world_size):
                j = i * local_context_size
                labels.extend(list(range(j, j + local_batch_size)))
            labels = torch.LongTensor(labels).cuda()
            assert len(labels) == global_batch_size
        else:
            labels = torch.arange(global_batch_size).long().cuda()

        # Cross-entropy loss.
        softmax_scores = F.log_softmax(retrieval_scores, dim=1)
        loss = F.nll_loss(softmax_scores, labels, reduction='mean')

        max_score, max_idxs = torch.max(softmax_scores, 1)
        correct_predictions_count = (max_idxs == labels).sum().float()

        # Reduce the loss for logging.
        reduced_loss = average_losses_across_data_parallel_group(
            [loss, correct_predictions_count])

        # Loss scaling for correct losses in supervised retrieval.
        loss = loss * mpu.get_data_parallel_world_size()

        return loss, {'lm loss': reduced_loss[0],
                      'correct_prediction_count': reduced_loss[1]}
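
    # The all_gather blocks above rely on a common trick:
    # torch.distributed.all_gather does not propagate gradients, so each rank
    # gathers detached copies and then swaps its own slice back for the
    # autograd-tracked local tensor. A self-contained sketch of the idea
    # (names are ours, not the repo's):
    #
    #     import torch.distributed as dist
    #
    #     def gather_with_local_grad(local_logits, group=None):
    #         world_size = dist.get_world_size(group=group)
    #         rank = dist.get_rank(group=group)
    #         gathered = [torch.empty_like(local_logits)
    #                     for _ in range(world_size)]
    #         dist.all_gather(gathered, local_logits.detach(), group=group)
    #         # all_gather outputs carry no grad; restore the local tensor so
    #         # gradients flow into this rank's contribution.
    #         gathered[rank] = local_logits
    #         return torch.cat(gathered, dim=0)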
    def train_valid_datasets_provider():
        """Build the train and validation datasets."""
        args = get_args()
        tokenizer = get_tokenizer()

        train_dataset = Dataset('training', args.train_data, tokenizer,
                                args.retriever_seq_length, evaluate=False)
        valid_dataset = Dataset('validation', args.valid_data, tokenizer,
                                args.retriever_seq_length, evaluate=True)
        return train_dataset, valid_dataset
    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()
        print_rank_0('building retriever model for {} ...'.format(args.task))

        model = biencoder_model_provider(
            only_context_model=False,
            only_query_model=False,
            biencoder_shared_query_context_model=args.biencoder_shared_query_context_model,
            pre_process=pre_process,
            post_process=post_process)
        return model
    def single_dataset_provider(datapath):
        args = get_args()
        tokenizer = get_tokenizer()

        name = datapath[0].split('/')[-1].split('.')[0]
        return Dataset(name, datapath, tokenizer,
                       args.retriever_seq_length, evaluate=True)

    def metrics_func_provider():
        """Provide metrics callback function."""
        return accuracy_func_provider(single_dataset_provider)
"""Finetune/evaluate."""
finetune
(
train_valid_datasets_provider
,
model_provider
,
forward_step
=
cross_entropy_forward_step
,
end_of_epoch_callback_provider
=
metrics_func_provider
,
task_collate_fn
=
task_collate_fn
)
#,end_of_training_callback_provider=rank0_metrics_func_provider)
def main():
    args = get_args()

    if args.task == 'RET-FINETUNE-NQ':
        from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset
    else:
        raise NotImplementedError('ORQA task {} is not implemented.'.format(
            args.task))

    orqa(Dataset)