OpenDAS / Megatron-LM, commit 8ceed7c7

changed gpt2 masking to binary and masked_fill

Authored Apr 14, 2020 by Mohammad
Parent: c0a59a66
Showing 5 changed files with 7 additions and 13 deletions (+7, -13):
megatron/model/gpt2_model.py        +1 -2
megatron/text_generation_utils.py   +1 -2
megatron/utils.py                   +3 -5
pretrain_gpt2.py                    +1 -2
tasks/zeroshot_gpt2/evaluate.py     +1 -2
megatron/model/gpt2_model.py

```diff
@@ -27,8 +27,7 @@ from .utils import scaled_init_method_normal

 def gpt2_attention_mask_func(attention_scores, ltor_mask):
-    attention_scores = torch.mul(attention_scores, ltor_mask) - \
-                       10000.0 * (1.0 - ltor_mask)
+    attention_scores.masked_fill_(ltor_mask, -10000.0)
     return attention_scores
```
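The substance of this change: the old code applied the mask arithmetically on a float mask (multiply scores by a 0/1 mask, then subtract 10000.0 at masked positions), while the new code does an in-place `masked_fill_` with a boolean mask in which `True` marks positions to hide. A minimal standalone sketch (not from the repo) showing that the two paths produce identical attention scores:

```python
import torch

# Minimal sketch, not repo code: compares the old additive float
# masking against the new boolean masked_fill_ path on random scores.
seq_len = 4
scores = torch.randn(1, 1, seq_len, seq_len)

# Old convention: float causal mask, 1.0 = attend, 0.0 = masked.
float_mask = torch.tril(torch.ones(1, 1, seq_len, seq_len))
old = torch.mul(scores, float_mask) - 10000.0 * (1.0 - float_mask)

# New convention: boolean mask, True = masked. This matches the
# `attention_mask < 0.5` conversion added in megatron/utils.py below.
bool_mask = float_mask < 0.5
new = scores.clone()
new.masked_fill_(bool_mask, -10000.0)

# Kept positions pass through unchanged; masked positions become
# exactly -10000.0 under both schemes.
assert torch.allclose(old, new)
```

The in-place version also avoids allocating the temporaries (`1.0 - ltor_mask` and the elementwise product) that the old expression created on every forward pass.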
megatron/text_generation_utils.py

```diff
@@ -42,8 +42,7 @@ def get_batch(context_tokens):
         tokenizer.eod,
         args.reset_position_ids,
         args.reset_attention_mask,
-        args.eod_mask_loss,
-        args.fp16)
+        args.eod_mask_loss)

     return tokens, attention_mask, position_ids
```
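This same one-line edit recurs in pretrain_gpt2.py and tasks/zeroshot_gpt2/evaluate.py below: each call site drops the trailing `args.fp16` argument because `get_ltor_masks_and_position_ids` (see megatron/utils.py next) no longer takes an `fp16` parameter. A sketch of the updated call, with surrounding names assumed from the diff context:

```python
# Sketch of the updated call site (names taken from the diff context;
# runnable only inside the Megatron-LM codebase).
from megatron.utils import get_ltor_masks_and_position_ids

attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
    tokens,
    tokenizer.eod,
    args.reset_position_ids,
    args.reset_attention_mask,
    args.eod_mask_loss)  # args.fp16 no longer passed
```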
megatron/utils.py

```diff
@@ -119,8 +119,7 @@ def get_ltor_masks_and_position_ids(data,
                                     eod_token,
                                     reset_position_ids,
                                     reset_attention_mask,
-                                    eod_mask_loss,
-                                    fp16):
+                                    eod_mask_loss):
     """Build masks and position id for left to right model."""

     # Extract batch size and sequence length.
@@ -170,8 +169,7 @@ def get_ltor_masks_and_position_ids(data,
                 position_ids[b, (i + 1):] -= (i + 1 - prev_index)
                 prev_index = i + 1

-    # Convert
-    if fp16:
-        attention_mask = attention_mask.half()
+    # Convert attention mask to binary:
+    attention_mask = (attention_mask < 0.5)

     return attention_mask, loss_mask, position_ids
```
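The key behavioral change in `get_ltor_masks_and_position_ids`: instead of optionally casting the float mask to fp16, the function now converts it once to a boolean tensor with `True` at the positions to be blocked. A small standalone sketch (not repo code) of what the conversion does to a causal mask:

```python
import torch

# Standalone sketch of the new conversion step. The float causal mask
# uses 1.0 where a token may attend and 0.0 elsewhere.
seq_len = 4
attention_mask = torch.tril(torch.ones(1, 1, seq_len, seq_len))

# New code: binary mask, True where attention must be blocked.
attention_mask = (attention_mask < 0.5)

print(attention_mask[0, 0])
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])
```

A bool mask is dtype-agnostic, which is why the `if fp16: attention_mask = attention_mask.half()` branch and the `fp16` parameter could be deleted: `masked_fill_` accepts a boolean mask regardless of whether the attention scores are fp16 or fp32.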
pretrain_gpt2.py

```diff
@@ -65,8 +65,7 @@ def get_batch(data_iterator):
         tokenizer.eod,
         args.reset_position_ids,
         args.reset_attention_mask,
-        args.eod_mask_loss,
-        args.fp16)
+        args.eod_mask_loss)

     return tokens, labels, loss_mask, attention_mask, position_ids
```
tasks/zeroshot_gpt2/evaluate.py

```diff
@@ -71,8 +71,7 @@ def process_batch(batch):
         tokenizer.eod,
         args.reset_position_ids,
         args.reset_attention_mask,
-        args.eod_mask_loss,
-        args.fp16)
+        args.eod_mask_loss)

     return tokens, labels, attention_mask, position_ids, loss_mask
```