OpenDAS / Megatron-LM

Commit b9bd1a11, authored Jul 13, 2020 by Neel Kant
Additional refactoring
Parent commit: ca0cdfaa
Showing 6 changed files with 69 additions and 162 deletions.
indexer.py                          +11   -13
megatron/data/dataset_utils.py       +1    -1
megatron/data/realm_dataset.py       +0  -115
megatron/data/realm_index.py        +52   -21
megatron/tokenizer/tokenizer.py      +5    -5
megatron/training.py                 +0    -7
indexer.py  (+11, -13)
```diff
-import os
-import sys
-import time
 import torch
 import torch.distributed as dist
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

-from megatron import get_args, get_adlr_autoresume, print_rank_0
+from megatron import get_args
 from megatron import mpu
 from megatron.checkpointing import get_checkpoint_tracker_filename, get_checkpoint_name
 from megatron.data.dataset_utils import get_indexed_dataset_
-from megatron.data.realm_dataset import ICTDataset
-from megatron.data.realm_dataset_utils import BlockSampleData
-from megatron.data.realm_index import detach, BlockData
+from megatron.data.ict_dataset import ICTDataset
+from megatron.data.realm_index import detach, BlockData, FaissMIPSIndex
 from megatron.data.samplers import DistributedBatchSampler
 from megatron.initialize import initialize_megatron
 from megatron.training import get_model
-from pretrain_bert_ict import get_batch, general_ict_model_provider
+from pretrain_ict import get_batch, general_ict_model_provider


 def pprint(*args):
```
```diff
@@ -30,17 +25,21 @@ class IndexBuilder(object):
         self.model = None
         self.dataloader = None
         self.block_data = None
+
+        # need to know whether we're using a REALM checkpoint (args.load) or ICT checkpoint
+        assert not (args.load and args.ict_load)
+        self.using_realm_chkpt = args.ict_load is None
+
         self.load_attributes()
         self.is_main_builder = args.rank == 0
         self.iteration = self.total_processed = 0

     def load_attributes(self):
         """Load the necessary attributes: model, dataloader and empty BlockData"""
-        # TODO: handle from_realm_chkpt correctly
-        self.model = load_ict_checkpoint(only_block_model=True, from_realm_chkpt=False)
+        self.model = load_ict_checkpoint(only_block_model=True, from_realm_chkpt=self.using_realm_chkpt)
         self.model.eval()
         self.dataloader = iter(get_one_epoch_dataloader(get_ict_dataset()))
-        self.block_data = BlockData()
+        self.block_data = BlockData(load_from_path=False)

     def track_and_report_progress(self, batch_size):
         """Utility function for tracking progress"""
```
```diff
@@ -141,7 +140,6 @@ def get_ict_dataset(use_titles=True, query_in_block_prob=1):
         num_epochs=1,
         max_num_samples=None,
         max_seq_length=args.seq_length,
-        short_seq_prob=0.0001,  # doesn't matter
         seed=1,
         query_in_block_prob=query_in_block_prob,
         use_titles=use_titles,
```
megatron/data/dataset_utils.py  (+1, -1)
```diff
@@ -417,7 +417,6 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
             num_epochs=None,
             max_num_samples=train_valid_test_num_samples[index],
             max_seq_length=max_seq_length,
-            short_seq_prob=short_seq_prob,
             seed=seed
         )
```

```diff
@@ -434,6 +433,7 @@ def build_train_valid_test_datasets(data_prefix, data_impl, splits_string,
         dataset = BertDataset(
             indexed_dataset=indexed_dataset,
             masked_lm_prob=masked_lm_prob,
+            short_seq_prob=short_seq_prob,
             **kwargs
         )
```
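The net effect is that short_seq_prob no longer travels in the shared kwargs but is passed to BertDataset explicitly; presumably the other dataset built from the same kwargs (the ICT dataset, whose old constructor is deleted below) no longer accepts the argument.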
megatron/data/realm_dataset.py  (deleted, 100644 → 0; +0, -115)
The file is removed in its entirety; its previous contents were:

```python
import collections
import itertools
import random

import numpy as np
from torch.utils.data import Dataset

from megatron import get_tokenizer
from megatron.data.realm_dataset_utils import BlockSampleData, get_block_samples_mapping, join_str_list


class ICTDataset(Dataset):
    """Dataset containing sentences and their blocks for an inverse cloze task."""
    def __init__(self, name, block_dataset, title_dataset, data_prefix,
                 num_epochs, max_num_samples, max_seq_length, query_in_block_prob,
                 short_seq_prob, seed, use_titles=True, use_one_sent_docs=False):
        self.name = name
        self.seed = seed
        self.max_seq_length = max_seq_length
        self.query_in_block_prob = query_in_block_prob
        self.block_dataset = block_dataset
        self.title_dataset = title_dataset
        self.short_seq_prob = short_seq_prob
        self.rng = random.Random(self.seed)
        self.use_titles = use_titles
        self.use_one_sent_docs = use_one_sent_docs

        self.samples_mapping = get_block_samples_mapping(
            block_dataset, title_dataset, data_prefix, num_epochs,
            max_num_samples, max_seq_length, seed, name, use_one_sent_docs)
        self.tokenizer = get_tokenizer()
        self.vocab_id_list = list(self.tokenizer.inv_vocab.keys())
        self.vocab_id_to_token_list = self.tokenizer.inv_vocab
        self.cls_id = self.tokenizer.cls
        self.sep_id = self.tokenizer.sep
        self.mask_id = self.tokenizer.mask
        self.pad_id = self.tokenizer.pad

    def __len__(self):
        return self.samples_mapping.shape[0]

    def __getitem__(self, idx):
        """Get an ICT example of a pseudo-query and the block of text from which it was extracted"""
        sample_data = self.samples_mapping[idx]
        start_idx, end_idx, doc_idx, block_idx = sample_data.as_tuple()

        if self.use_titles:
            title = self.title_dataset[int(doc_idx)]
            title_pad_offset = 3 + len(title)
        else:
            title = None
            title_pad_offset = 2
        block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
        assert len(block) > 1 or self.use_one_sent_docs or self.query_in_block_prob == 1

        # randint() is inclusive for Python rng
        rand_sent_idx = self.rng.randint(0, len(block) - 1)

        # keep the query in the context query_in_block_prob fraction of the time.
        if self.rng.random() < self.query_in_block_prob:
            query = block[rand_sent_idx].copy()
        else:
            query = block.pop(rand_sent_idx)

        # still need to truncate because blocks are concluded when
        # the sentence lengths have exceeded max_seq_length.
        query = query[:self.max_seq_length - 2]
        block = list(itertools.chain(*block))[:self.max_seq_length - title_pad_offset]

        query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
        block_data = sample_data.as_array()

        sample = {
            'query_tokens': query_tokens,
            'query_pad_mask': query_pad_mask,
            'block_tokens': block_tokens,
            'block_pad_mask': block_pad_mask,
            'block_data': block_data,
        }
        return sample

    def get_block(self, start_idx, end_idx, doc_idx):
        """Get the IDs for an evidence block plus the title of the corresponding document"""
        block = [self.block_dataset[i] for i in range(start_idx, end_idx)]
        title = self.title_dataset[int(doc_idx)]

        block = list(itertools.chain(*block))[:self.max_seq_length - (3 + len(title))]
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return block_tokens, block_pad_mask

    def get_null_block(self):
        """Get empty block and title - used in REALM pretraining"""
        block, title = [], []
        block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)

        return block_tokens, block_pad_mask

    def concat_and_pad_tokens(self, tokens, title=None):
        """Concat with special tokens and pad sequence to self.max_seq_length"""
        tokens = list(tokens)
        if title is None:
            tokens = [self.cls_id] + tokens + [self.sep_id]
        else:
            title = list(title)
            tokens = [self.cls_id] + title + [self.sep_id] + tokens + [self.sep_id]
        assert len(tokens) <= self.max_seq_length

        num_pad = self.max_seq_length - len(tokens)
        pad_mask = [1] * len(tokens) + [0] * num_pad
        tokens += [self.pad_id] * num_pad
        return np.array(tokens), np.array(pad_mask)
```
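For reference, the packing that the removed concat_and_pad_tokens performed is easy to reproduce in isolation. A toy sketch with made-up token ids and max_seq_length=12 (the real ids come from the BERT tokenizer):

```python
import numpy as np

# Toy special-token ids and sequence length; values are illustrative only.
cls_id, sep_id, pad_id, max_seq_length = 101, 102, 0, 12
title = [7000, 7001]
block = [2001, 2002, 2003, 2004]

# [CLS] title [SEP] block [SEP], then pad to max_seq_length with a matching mask.
tokens = [cls_id] + title + [sep_id] + block + [sep_id]
num_pad = max_seq_length - len(tokens)
pad_mask = np.array([1] * len(tokens) + [0] * num_pad)
tokens = np.array(tokens + [pad_id] * num_pad)

print(tokens)    # [ 101 7000 7001  102 2001 2002 2003 2004  102    0    0    0]
print(pad_mask)  # [1 1 1 1 1 1 1 1 1 0 0 0]
```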
megatron/data/realm_index.py  (+52, -21)
```diff
-from collections import defaultdict
 import itertools
 import os
 import pickle
```

```diff
@@ -8,7 +7,7 @@ import faiss
 import numpy as np
 import torch

-from megatron import get_args, mpu
+from megatron import get_args


 def detach(tensor):
```

```diff
@@ -17,7 +16,7 @@ def detach(tensor):
 class BlockData(object):
     """Serializable data structure for holding data for blocks -- embeddings and necessary metadata for REALM"""
-    def __init__(self, block_data_path=None, rank=None):
+    def __init__(self, block_data_path=None, load_from_path=True, rank=None):
         self.embed_data = dict()
         self.meta_data = dict()
         if block_data_path is None:
```

```diff
@@ -27,6 +26,9 @@ class BlockData(object):
         self.block_data_path = block_data_path
         self.rank = rank

+        if load_from_path:
+            self.load_from_file()
+
         block_data_name = os.path.splitext(self.block_data_path)[0]
         self.temp_dir_name = block_data_name + '_tmp'
```

```diff
@@ -43,18 +45,23 @@ class BlockData(object):
         """
         self.embed_data = dict()

-    @classmethod
-    def load_from_file(cls, fname):
+    def load_from_file(self):
+        """Populate members from instance saved to file"""
         print("\n> Unpickling BlockData", flush=True)
-        state_dict = pickle.load(open(fname, 'rb'))
+        state_dict = pickle.load(open(self.block_data_path, 'rb'))
         print(">> Finished unpickling BlockData\n", flush=True)

-        new_index = cls()
-        new_index.embed_data = state_dict['embed_data']
-        new_index.meta_data = state_dict['meta_data']
-        return new_index
+        self.embed_data = state_dict['embed_data']
+        self.meta_data = state_dict['meta_data']

     def add_block_data(self, block_indices, block_embeds, block_metas, allow_overwrite=False):
+        """Add data for set of blocks
+        :param block_indices: 1D array of unique int ids for the blocks
+        :param block_embeds: 2D array of embeddings of the blocks
+        :param block_metas: 2D array of metadata for the blocks.
+            In the case of REALM this will be [start_idx, end_idx, doc_idx]
+        """
         for idx, embed, meta in zip(block_indices, block_embeds, block_metas):
             if not allow_overwrite and idx in self.embed_data:
                 raise ValueError("Unexpectedly tried to overwrite block data")
```
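With this change, loading moves out of a classmethod and into construction: BlockData(path) now unpickles embed_data and meta_data immediately, and indexer.py opts out with load_from_path=False when it is about to build fresh data. A standalone sketch of just that load path, not the real class (file name and contents are toy values):

```python
import pickle

# Write a toy pickle in the same {'embed_data': ..., 'meta_data': ...} layout BlockData uses.
state = {'embed_data': {0: [0.1, 0.2]}, 'meta_data': {0: (0, 3, 7)}}
with open('block_data.pkl', 'wb') as f:
    pickle.dump(state, f)

class BlockDataSketch:
    """Mimics the refactored BlockData: the constructor loads unless load_from_path=False."""
    def __init__(self, block_data_path=None, load_from_path=True, rank=None):
        self.embed_data, self.meta_data = dict(), dict()
        self.block_data_path = block_data_path
        self.rank = rank
        if load_from_path:
            self.load_from_file()

    def load_from_file(self):
        """Populate members from the pickled state, as the new instance method does."""
        with open(self.block_data_path, 'rb') as f:
            state_dict = pickle.load(f)
        self.embed_data = state_dict['embed_data']
        self.meta_data = state_dict['meta_data']

store = BlockDataSketch('block_data.pkl')                        # auto-loads on construction
empty = BlockDataSketch('block_data.pkl', load_from_path=False)  # stays empty, as indexer.py now requests
print(store.meta_data, empty.meta_data)                          # {0: (0, 3, 7)} {}
```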
```diff
@@ -63,6 +70,7 @@ class BlockData(object):
             self.meta_data[idx] = meta

     def save_shard(self):
+        """Save the block data that was created this in this process"""
        if not os.path.isdir(self.temp_dir_name):
            os.makedirs(self.temp_dir_name, exist_ok=True)
```

```diff
@@ -104,9 +112,9 @@ class BlockData(object):
 class FaissMIPSIndex(object):
     """Wrapper object for a BlockData which similarity search via FAISS under the hood"""
-    def __init__(self, index_type, embed_size, use_gpu=False):
-        self.index_type = index_type
+    def __init__(self, embed_size, block_data=None, use_gpu=False):
         self.embed_size = embed_size
+        self.block_data = block_data
         self.use_gpu = use_gpu
         self.id_map = dict()
```

```diff
@@ -114,10 +122,7 @@ class FaissMIPSIndex(object):
         self._set_block_index()

     def _set_block_index(self):
-        INDEX_TYPES = ['flat_ip']
-        if self.index_type not in INDEX_TYPES:
-            raise ValueError("Invalid index type specified")
-
+        """Create a Faiss Flat index with inner product as the metric to search against"""
         print("\n> Building index", flush=True)
         self.block_mips_index = faiss.index_factory(self.embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
```

```diff
@@ -129,29 +134,52 @@ class FaissMIPSIndex(object):
             config.useFloat16 = True
             self.block_mips_index = faiss.GpuIndexFlat(res, self.block_mips_index, config)
-            print(">>> Finished building index on GPU {}\n".format(self.block_mips_index.getDevice()), flush=True)
+            print(">> Initialized index on GPU {}".format(self.block_mips_index.getDevice()), flush=True)
         else:
             # CPU index supports IDs so wrap with IDMap
             self.block_mips_index = faiss.IndexIDMap(self.block_mips_index)
-            print(">> Finished building index\n", flush=True)
+            print(">> Initialized index on CPU", flush=True)
+
+        # if we were constructed with a BlockData, then automatically load it when the FAISS structure is built
+        if self.block_data is not None:
+            self.add_block_embed_data(self.block_data)

     def reset_index(self):
         """Delete existing index and create anew"""
         del self.block_mips_index

+        # reset the block data so that _set_block_index will reload it as well
+        if self.block_data is not None:
+            block_data_path = self.block_data.block_data_path
+            del self.block_data
+            self.block_data = BlockData.load_from_file(block_data_path)
+
         self._set_block_index()

     def add_block_embed_data(self, all_block_data):
         """Add the embedding of each block to the underlying FAISS index"""
+        # this assumes the embed_data is a dict : {int: np.array<float>}
         block_indices, block_embeds = zip(*all_block_data.embed_data.items())

+        # the embeddings have to be entered in as float32 even though the math internally is done with float16.
+        block_embeds_arr = np.float32(np.array(block_embeds))
+        block_indices_arr = np.array(block_indices)
+
+        # faiss GpuIndex doesn't work with IDMap wrapper so store ids to map back with
         if self.use_gpu:
             for i, idx in enumerate(block_indices):
                 self.id_map[i] = idx

+        # we no longer need the embedding data since it's in the index now
         all_block_data.clear()

         if self.use_gpu:
-            self.block_mips_index.add(np.float32(np.array(block_embeds)))
+            self.block_mips_index.add(block_embeds_arr)
         else:
-            self.block_mips_index.add_with_ids(np.float32(np.array(block_embeds)), np.array(block_indices))
+            self.block_mips_index.add_with_ids(block_embeds_arr, block_indices_arr)
+
+        print(">>> Finished adding block data to index", flush=True)

     def search_mips_index(self, query_embeds, top_k, reconstruct=True):
         """Get the top-k blocks by the index distance metric.
```

```diff
@@ -160,12 +188,15 @@ class FaissMIPSIndex(object):
             if False: return [num_queries x k] array of distances, and another for indices
         """
         query_embeds = np.float32(detach(query_embeds))

         with torch.no_grad():
             if reconstruct:
+                # get the vectors themselves
                 top_k_block_embeds = self.block_mips_index.search_and_reconstruct(query_embeds, top_k)
                 return top_k_block_embeds
             else:
+                # get distances and indices of closest vectors
                 distances, block_indices = self.block_mips_index.search(query_embeds, top_k)
                 if self.use_gpu:
                     fresh_indices = np.zeros(block_indices.shape)
```
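The CPU code path above boils down to a flat inner-product index wrapped in an IDMap so that FAISS returns the caller's block ids rather than insertion positions. A self-contained sketch of that pattern, using toy sizes and random vectors (requires the faiss package; names here are illustrative, not the wrapper class itself):

```python
import faiss
import numpy as np

embed_size, num_blocks, top_k = 8, 100, 5

# Flat index with inner product as the similarity metric, as in _set_block_index.
index = faiss.index_factory(embed_size, 'Flat', faiss.METRIC_INNER_PRODUCT)
# The CPU index supports IDs, so wrap it so search() reports our own block ids.
index = faiss.IndexIDMap(index)

block_embeds = np.random.rand(num_blocks, embed_size).astype(np.float32)  # float32, as the diff notes
block_ids = np.arange(1000, 1000 + num_blocks, dtype=np.int64)            # made-up unique block ids
index.add_with_ids(block_embeds, block_ids)

queries = np.random.rand(2, embed_size).astype(np.float32)
distances, ids = index.search(queries, top_k)  # ids come back as the original block ids
print(ids.shape)  # (2, 5)
```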
megatron/tokenizer/tokenizer.py  (+5, -5)
```diff
@@ -31,11 +31,11 @@ def build_tokenizer(args):
     # Select and instantiate the tokenizer.
     assert args.vocab_file is not None
     if args.tokenizer_type == 'BertWordPieceLowerCase':
-        tokenizer = BertWordPieceTokenizer(vocab_file=args.vocab_file,
-                                           lower_case=True)
+        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
+                                            lower_case=True)
     elif args.tokenizer_type == 'BertWordPieceCase':
-        tokenizer = BertWordPieceTokenizer(vocab_file=args.vocab_file,
-                                           lower_case=False)
+        tokenizer = _BertWordPieceTokenizer(vocab_file=args.vocab_file,
+                                            lower_case=False)
     elif args.tokenizer_type == 'GPT2BPETokenizer':
         assert args.merge_file is not None
         tokenizer = _GPT2BPETokenizer(args.vocab_file, args.merge_file)
```

```diff
@@ -124,7 +124,7 @@ class AbstractTokenizer(ABC):
                                   'tokenizer'.format(self.name))


-class BertWordPieceTokenizer(AbstractTokenizer):
+class _BertWordPieceTokenizer(AbstractTokenizer):
     """Original BERT wordpiece tokenizer."""

     def __init__(self, vocab_file, lower_case=True):
```
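The rename adds a leading underscore, matching the existing _GPT2BPETokenizer and marking the wordpiece class as module-private; build_tokenizer remains the intended way to obtain a tokenizer instance.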
megatron/training.py  (+0, -7)
```diff
@@ -74,11 +74,6 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
     args = get_args()
     timers = get_timers()

-    if args.rank == 0 and args.cased_data_path is not None:
-        import stanza
-        stanza.download('en', processors={'ner': 'conll03'}, dir='stanza')
-
     # Model, optimizer, and learning rate.
     timers('model and optimizer').start()
     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
```

```diff
@@ -237,8 +232,6 @@ def setup_model_and_optimizer(model_provider_func):
 def backward_step(optimizer, model, loss):
     """Backward step."""
-    # if args.rank == 0:
-    #     torch.save(lick)
     args = get_args()
     timers = get_timers()
```
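Both removals are cleanup: the conditional stanza download (an English pipeline with a CoNLL03 NER processor, previously fetched on rank 0 whenever args.cased_data_path was set) and a pair of leftover debug comments in backward_step.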