OpenDAS / Megatron-LM

Commit 2a3b445d
Cosmetic changes
Authored Jun 26, 2020 by Neel Kant
Parent ac967fa0

Showing 2 changed files with 7 additions and 7 deletions:

  megatron/data/realm_dataset.py   +0 -1
  pretrain_bert_ict.py             +7 -6
megatron/data/realm_dataset.py

@@ -65,7 +65,6 @@ class ICTDataset(Dataset):
         query_tokens, query_pad_mask = self.concat_and_pad_tokens(query)
         block_tokens, block_pad_mask = self.concat_and_pad_tokens(block, title)
-        print(self.tokenizer.decode_token_ids(block_tokens), '\n')
         block_data = np.array([start_idx, end_idx, doc_idx, block_idx]).astype(np.int64)
         sample = {
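For context, the two concat_and_pad_tokens calls in the hunk above build fixed-length query and block sequences for the Inverse Cloze Task. Below is a minimal sketch of that general pattern; the signature, the [CLS]/[SEP]/pad ids, the max_seq_length argument, and the title handling are illustrative assumptions, not the actual realm_dataset.py implementation.

import numpy as np

def concat_and_pad_tokens(tokens, cls_id, sep_id, pad_id, max_seq_length, title=None):
    # Sketch only: prepend [CLS], optionally insert a title segment, append [SEP],
    # then pad to max_seq_length and return the ids plus a pad mask.
    if title is None:
        tokens = [cls_id] + list(tokens) + [sep_id]
    else:
        tokens = [cls_id] + list(title) + [sep_id] + list(tokens) + [sep_id]
    tokens = tokens[:max_seq_length]
    pad_mask = [1] * len(tokens) + [0] * (max_seq_length - len(tokens))
    tokens = tokens + [pad_id] * (max_seq_length - len(tokens))
    return np.array(tokens, dtype=np.int64), np.array(pad_mask, dtype=np.int64)

# Example: a 5-token query padded to length 10.
ids, mask = concat_and_pad_tokens([11, 12, 13, 14, 15], cls_id=101, sep_id=102,
                                  pad_id=0, max_seq_length=10)

The removed print line simply decoded and dumped each block's tokens, which is debug output rather than part of the sample construction.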
pretrain_bert_ict.py

@@ -33,8 +33,11 @@ num_batches = 0
 def general_model_provider(only_query_model=False, only_block_model=False):
     """Build the model."""
     args = get_args()
-    assert args.ict_head_size is not None, \
-        "Need to specify --ict-head-size to provide an ICTBertModel"
+    if args.ict_head_size is None:
+        raise ValueError(
+            "Need to specify --ict-head-size to provide an ICTBertModel")
+    assert args.model_parallel_size == 1, \
+        "Model parallel size > 1 not supported for ICT"
     print_rank_0('building ICTBertModel...')
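The hunk above replaces an assert on --ict-head-size with an explicit ValueError and adds a hard requirement that model parallelism is disabled. A small sketch of that validation pattern follows; the Args class is an illustrative stand-in for Megatron's parsed arguments, modeling only the two fields this hunk touches. The explicit exception still fires under python -O, which strips assert statements entirely.

class Args:
    # Illustrative stand-in for Megatron's parsed args.
    ict_head_size = None
    model_parallel_size = 1

def check_ict_args(args):
    # Explicit exception: raised even under `python -O`, unlike a bare assert.
    if args.ict_head_size is None:
        raise ValueError(
            "Need to specify --ict-head-size to provide an ICTBertModel")
    # Requirement carried over from the diff: ICT only supports
    # model parallel size == 1.
    assert args.model_parallel_size == 1, \
        "Model parallel size > 1 not supported for ICT"

args = Args()
args.ict_head_size = 128
check_ict_args(args)   # passes; with ict_head_size = None it raises ValueError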
@@ -89,7 +92,6 @@ def forward_step(data_iterator, model):
     timers('batch generator').stop()

     # Forward model.
-    # retrieval_scores = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask).float()
     query_logits, block_logits = model(query_tokens, query_pad_mask, block_tokens, block_pad_mask)

     data_parallel_size = dist.get_world_size() / args.model_parallel_size
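The forward pass now returns separate query and block logits instead of pre-computed retrieval scores (the commented-out line this hunk drops). The sketch below shows one common way such in-batch ICT retrieval scores and a loss can be formed from those two tensors; it illustrates the general technique, not the loss code in this file.

import torch
import torch.nn.functional as F

def ict_retrieval_loss(query_logits, block_logits):
    # Score every query against every block in the (global) batch;
    # the correct block for query i sits on the diagonal.
    retrieval_scores = query_logits @ block_logits.t()        # [batch, batch]
    targets = torch.arange(retrieval_scores.size(0), device=retrieval_scores.device)
    return F.cross_entropy(retrieval_scores, targets)

# Example with random embeddings for a batch of 4 and hidden size 8.
loss = ict_retrieval_loss(torch.randn(4, 8), torch.randn(4, 8))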
@@ -100,11 +102,11 @@ def forward_step(data_iterator, model):
     all_query_logits = torch.zeros(all_logits_shape).type(query_logits.dtype).cuda()
     all_block_logits = all_query_logits.clone().cuda()

-    # record this processes' data and then merge with other processes below
+    # record this processes' data
     all_query_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = query_logits
     all_block_logits[args.rank * batch_size:(args.rank + 1) * batch_size] = block_logits

-    # currently this assumes model parallel size == 1.
+    # merge data from all processes
     dist.all_reduce(all_query_logits)
     dist.all_reduce(all_block_logits)
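The slice-assign followed by all_reduce above is a simple way to gather per-rank logits into a global batch: each rank writes its own rows into a zero-initialized buffer, and a sum reduction fills in the rows owned by the other ranks. A self-contained sketch of that pattern, assuming torch.distributed is already initialized and each process is one data-parallel rank (model parallel size == 1, as the new assert requires):

import torch
import torch.distributed as dist

def gather_logits_by_all_reduce(local_logits, rank, world_size):
    # Zero buffer sized for the global batch; each rank fills only its own rows.
    batch_size, hidden = local_logits.shape
    all_logits = torch.zeros(world_size * batch_size, hidden,
                             dtype=local_logits.dtype, device=local_logits.device)
    all_logits[rank * batch_size:(rank + 1) * batch_size] = local_logits
    # all_reduce defaults to SUM, so the zero rows are overwritten by the other
    # ranks' contributions while this rank's own slice is preserved.
    dist.all_reduce(all_logits)
    return all_logits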
@@ -153,6 +155,5 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
 if __name__ == "__main__":

     pretrain(train_valid_test_datasets_provider, model_provider, forward_step,
              args_defaults={'tokenizer_type': 'BertWordPieceLowerCase'})