OpenDAS / Megatron-LM

Commit b8bb0b49
Debugging done on Circe
Authored Jun 15, 2020 by Mohammad
Parent: e3c57c82
Showing 4 changed files with 18 additions and 10 deletions:

    megatron/arguments.py                   +1   -0
    megatron/data/realm_dataset_utils.py    +4   -1
    megatron/data/realm_index.py            +10  -7
    megatron/training.py                    +3   -2
megatron/arguments.py

@@ -391,6 +391,7 @@ def _add_data_args(parser):
     group.add_argument('--faiss-use-gpu', action='store_true')
     group.add_argument('--index-reload-interval', type=int, default=500)
     group.add_argument('--use-regular-masking', action='store_true')
+    group.add_argument('--use-random-spans', action='store_true')
     group.add_argument('--allow-trivial-doc', action='store_true')
     group.add_argument('--ner-data-path', type=str, default=None)
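This hunk only registers the new flag; the behavior it gates lives in realm_dataset_utils.py below. A minimal standalone sketch of how a store_true flag like this resolves (hypothetical parser, not the repo's argument setup):

import argparse

# Hypothetical stand-in for Megatron's data-args group.
parser = argparse.ArgumentParser()
group = parser.add_argument_group('data')
group.add_argument('--use-random-spans', action='store_true')

print(parser.parse_args(['--use-random-spans']).use_random_spans)  # True when passed
print(parser.parse_args([]).use_random_spans)                      # False by default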
megatron/data/realm_dataset_utils.py

@@ -28,6 +28,9 @@ def build_realm_training_sample(sample, max_seq_length,
                                     cls_id, sep_id, mask_id, max_predictions_per_seq, np_rng)
     elif block_ner_mask is not None:
         block_ner_mask = list(itertools.chain(*block_ner_mask))[:max_seq_length - 2]
+        if args.use_random_spans:
+            rand_idx = np.random.randint(len(block_ner_mask))
+            block_ner_mask = block_ner_mask[rand_idx:] + block_ner_mask[:rand_idx]
         block_ner_mask = [0] + block_ner_mask + [0]
         masked_tokens, masked_positions, masked_labels = get_arrays_using_ner_mask(tokens, block_ner_mask, mask_id)
     else:
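The new branch rotates the flattened NER mask by a random offset before the [0] padding for the CLS/SEP positions is added, so the span pattern lands at a random position in the sequence. A small worked example (mask values invented):

# Invented toy mask: 1 marks tokens inside an NER span.
block_ner_mask = [1, 1, 0, 0, 0, 1]
rand_idx = 2  # np.random.randint(len(block_ner_mask)) would draw this randomly

rotated = block_ner_mask[rand_idx:] + block_ner_mask[:rand_idx]
print(rotated)              # [0, 0, 0, 1, 1, 1]
print([0] + rotated + [0])  # [0, 0, 0, 0, 1, 1, 1, 0] -- padded for CLS/SEP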
@@ -182,7 +185,7 @@ def get_block_samples_mapping(block_dataset, title_dataset, data_prefix, num_epo
     indexmap_filename += '.npy'

     # Build the indexed mapping if not exist.
-    if torch.distributed.get_rank() == 0 and \
+    if mpu.get_data_parallel_rank() == 0 and \
             not os.path.isfile(indexmap_filename):
         print(' > WARNING: could not find index map file {}, building '
               'the indices on rank 0 ...'.format(indexmap_filename))
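The guard now keys on the data-parallel rank rather than the global rank. Under model parallelism the two differ: torch.distributed.get_rank() == 0 is true for exactly one process, while mpu.get_data_parallel_rank() == 0 is true for one process per model-parallel group. A simplified sketch of the distinction, assuming a contiguous group layout (this is an assumption, not the repo's mpu implementation):

world_size = 8
model_parallel_size = 2  # assumed layout: ranks [0,1], [2,3], ... form model-parallel groups

for rank in range(world_size):
    data_parallel_rank = rank // model_parallel_size  # holds only under the assumed layout
    old_check = (rank == 0)                  # torch.distributed.get_rank() == 0
    new_check = (data_parallel_rank == 0)    # mpu.get_data_parallel_rank() == 0
    print(rank, old_check, new_check)
# Only rank 0 passed the old check; ranks 0 and 1 both pass the new one,
# so each model-parallel partner of the first replica takes the same branch.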
megatron/data/realm_index.py

@@ -15,12 +15,16 @@ def detach(tensor):
 class BlockData(object):
-    def __init__(self):
-        args = get_args()
+    def __init__(self, block_data_path=None):
         self.embed_data = dict()
         self.meta_data = dict()
-        block_data_path = os.path.splitext(args.block_data_path)[0]
-        self.temp_dir_name = block_data_path + '_tmp'
+        if block_data_path is None:
+            args = get_args()
+            block_data_path = args.block_data_path
+        self.block_data_path = block_data_path
+        block_data_name = os.path.splitext(self.block_data_path)[0]
+        self.temp_dir_name = block_data_name + '_tmp'

     def state(self):
         return {
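The constructor now accepts an explicit path and only falls back to get_args() when none is given, so BlockData no longer requires Megatron's global args to be initialized first. A hedged usage sketch (file names invented):

block_data = BlockData('embeddings/blocks.pkl')  # explicit path, no global args needed
print(block_data.block_data_path)  # 'embeddings/blocks.pkl'
print(block_data.temp_dir_name)    # 'embeddings/blocks_tmp' (extension stripped, '_tmp' appended)

block_data = BlockData()  # falls back to get_args().block_data_path, as before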
@@ -54,7 +58,7 @@ class BlockData(object):
     def save_shard(self, rank):
         if not os.path.isdir(self.temp_dir_name):
-            os.mkdir(self.temp_dir_name)
+            os.makedirs(self.temp_dir_name, exist_ok=True)

         # save the data for each shard
         with open('{}/{}.pkl'.format(self.temp_dir_name, rank), 'wb') as data_file:
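The isdir()/mkdir() pair is racy when several ranks call save_shard() concurrently: another rank can create the directory between the check and the mkdir, and os.mkdir then raises FileExistsError. os.makedirs(..., exist_ok=True) is idempotent. A minimal sketch of the difference (directory name invented):

import os

path = 'blocks_tmp'  # invented name
os.makedirs(path, exist_ok=True)
os.makedirs(path, exist_ok=True)  # no-op on repeat: no error
try:
    os.mkdir(path)                # what the old code did after its isdir() check
except FileExistsError:
    print('os.mkdir raises if another rank won the race')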
@@ -73,8 +77,7 @@ class BlockData(object):
             self.meta_data.update(data['meta_data'])
             # assert (len(self.embed_data) == old_size + shard_size) or (str(ignore_shard) in fname)

-        args = get_args()
-        with open(args.block_data_path, 'wb') as final_file:
+        with open(self.block_data_path, 'wb') as final_file:
             pickle.dump(self.state(), final_file)
         shutil.rmtree(self.temp_dir_name, ignore_errors=True)
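With the path stored on the instance by the new __init__, this consolidation step can drop its get_args() call. A hedged sketch of the merge-and-save flow as it appears from this hunk; the shard-loading loop and the 'embed_data' key are assumptions, only the final save and cleanup are shown in the diff:

import glob, os, pickle, shutil

def consolidate(block_data):
    # Assumed: each rank wrote '{temp_dir}/{rank}.pkl' via save_shard().
    for fname in glob.glob(os.path.join(block_data.temp_dir_name, '*.pkl')):
        with open(fname, 'rb') as f:
            data = pickle.load(f)
        block_data.embed_data.update(data['embed_data'])  # 'embed_data' key assumed
        block_data.meta_data.update(data['meta_data'])
    with open(block_data.block_data_path, 'wb') as final_file:
        pickle.dump(block_data.state(), final_file)       # single consolidated pickle
    shutil.rmtree(block_data.temp_dir_name, ignore_errors=True)  # drop the shard dir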
megatron/training.py

@@ -422,8 +422,9 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
         elif iteration < 20:
-            print("moving right along", flush=True)
-            report_memory("iteration {}".format(iteration))
+            #print("moving right along", flush=True)
+            #report_memory("iteration {}".format(iteration))
+            pass
         loss_dict, skipped_iter = train_step(forward_step_func,
                                              train_data_iterator,
                                              model,
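The added pass is required, not cosmetic: with both debug statements commented out, the elif body would be empty, which is a SyntaxError in Python. A tiny illustration:

iteration = 5
if iteration == 0:
    print('start')
elif iteration < 20:
    #print('moving right along')  # comments alone leave the block empty
    pass                          # `pass` keeps the branch syntactically valid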