OpenDAS / Megatron-LM

Commit f223ff67 authored Nov 27, 2019 by Mohammad Shoeybi

    refactoring

parent 0f873f97

Showing 3 changed files with 61 additions and 60 deletions:

    generate_samples.py    +2   -2
    megatron/utils.py      +57  -0
    pretrain_gpt2.py       +2   -58
generate_samples.py  (+2 -2)

@@ -28,7 +28,7 @@ from arguments import get_args
 from megatron.utils import Timers
 from megatron.utils import initialize_distributed
 from megatron.utils import set_random_seed
-from pretrain_gpt2 import get_masks_and_position_ids
+from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import load_checkpoint
 from megatron.data_utils import make_tokenizer
 from configure_data import configure_data
@@ -91,7 +91,7 @@ def get_batch(context_tokens, args):
     tokens = tokens.to(device)
 
     # Get the masks and postition ids.
-    attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
+    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
         tokens,
         args.eod_token,
         args.reset_position_ids,
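The hunk above truncates the call after args.reset_position_ids. For context, a minimal sketch of the relocated helper in a no-reset configuration, as the sampling path uses it; the toy context tensor, the eod_token id of 0, and the flag values are assumptions for illustration, not part of the commit.

# Illustrative sketch only: exercises the relocated helper with all reset
# flags off (assumed values); only the import path and signature come from
# this commit.
import torch

from megatron.utils import get_ltor_masks_and_position_ids

context = torch.randint(1, 1000, (1, 16), dtype=torch.long)  # assumed toy context
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
    context,
    0,       # eod_token, assumed id for the sketch
    False,   # reset_position_ids
    False,   # reset_attention_mask
    False)   # eod_mask_loss

# With reset_attention_mask=False the mask batch dimension is 1: a single
# lower-triangular causal mask shared across the batch.
assert attention_mask.shape == (1, 1, 16, 16)
assert torch.equal(position_ids, torch.arange(16).unsqueeze(0))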
megatron/utils.py  (+57 -0)

@@ -31,6 +31,63 @@ from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model import get_params_for_weight_decay_optimization
 
 
+def get_ltor_masks_and_position_ids(data,
+                                    eod_token,
+                                    reset_position_ids,
+                                    reset_attention_mask,
+                                    eod_mask_loss):
+    """Build masks and position id for left to right model."""
+
+    # Extract batch size and sequence length.
+    batch_size, seq_length = data.size()
+
+    # Attention mask (lower triangular).
+    if reset_attention_mask:
+        att_mask_batch = batch_size
+    else:
+        att_mask_batch = 1
+    attention_mask = torch.tril(torch.ones(
+        (att_mask_batch, seq_length, seq_length), device=data.device)).view(
+            att_mask_batch, 1, seq_length, seq_length)
+
+    # Loss mask.
+    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
+    if eod_mask_loss:
+        loss_mask[data == eod_token] = 0.0
+
+    # Position ids.
+    position_ids = torch.arange(seq_length, dtype=torch.long,
+                                device=data.device)
+    position_ids = position_ids.unsqueeze(0).expand_as(data)
+    # We need to clone as the ids will be modifed based on batch index.
+    if reset_position_ids:
+        position_ids = position_ids.clone()
+
+    if reset_position_ids or reset_attention_mask:
+        # Loop through the batches:
+        for b in range(batch_size):
+
+            # Find indecies where EOD token is.
+            eod_index = position_ids[b, data[b] == eod_token]
+            # Detach indecies from positions if going to modify positions.
+            if reset_position_ids:
+                eod_index = eod_index.clone()
+
+            # Loop through EOD indecies:
+            prev_index = 0
+            for j in range(eod_index.size()[0]):
+                i = eod_index[j]
+                # Mask attention loss.
+                if reset_attention_mask:
+                    attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
+                # Reset positions.
+                if reset_position_ids:
+                    position_ids[b, (i + 1):] -= (i + 1 - prev_index)
+                    prev_index = i + 1
+
+    return attention_mask, loss_mask, position_ids
+
+
 def reduce_losses(losses):
     reduced_losses = torch.cat(
         [loss.clone().detach().view(1) for loss in losses])
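To make the behavior of the new utility concrete, a small usage sketch on a toy batch; the token values and the eod_token id of 0 are illustrative assumptions, while the import path and signature come from the hunk above.

# Illustrative sketch only: toy values, not part of the commit.
import torch

from megatron.utils import get_ltor_masks_and_position_ids

# Two sequences of length 6; token id 0 plays the role of the EOD token.
tokens = torch.tensor([[5, 7, 0, 3, 4, 0],
                       [2, 2, 9, 0, 1, 1]], dtype=torch.long)

attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
    tokens,
    0,      # eod_token (assumed id)
    True,   # reset_position_ids
    True,   # reset_attention_mask
    True)   # eod_mask_loss

# attention_mask: [2, 1, 6, 6] lower-triangular mask, additionally zeroed
#   across document boundaries because reset_attention_mask=True.
# loss_mask:      [2, 6] ones, with zeros at the EOD positions.
# position_ids:   [2, 6] positions that restart from 0 after each EOD token,
#   e.g. the first row becomes [0, 1, 2, 0, 1, 2].
print(attention_mask.shape, loss_mask, position_ids)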
pretrain_gpt2.py  (+2 -58)

@@ -21,6 +21,7 @@ from configure_data import configure_data
 from gpt2_data_loader import make_gpt2_dataloaders
 from megatron import mpu
 from megatron.model import GPT2Model
+from megatron.utils import get_ltor_masks_and_position_ids
 from megatron.utils import print_rank_0
 from megatron.utils import reduce_losses
 from megatron.utils import vocab_size_with_padding
@@ -47,63 +48,6 @@ def model_provider(args):
 
     return model
 
 
-def get_masks_and_position_ids(data,
-                               eod_token,
-                               reset_position_ids,
-                               reset_attention_mask,
-                               eod_mask_loss):
-    """Build masks and position id."""
-
-    # Extract batch size and sequence length.
-    batch_size, seq_length = data.size()
-
-    # Attention mask (lower triangular).
-    if reset_attention_mask:
-        att_mask_batch = batch_size
-    else:
-        att_mask_batch = 1
-    attention_mask = torch.tril(torch.ones(
-        (att_mask_batch, seq_length, seq_length), device=data.device)).view(
-            att_mask_batch, 1, seq_length, seq_length)
-
-    # Loss mask.
-    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
-    if eod_mask_loss:
-        loss_mask[data == eod_token] = 0.0
-
-    # Position ids.
-    position_ids = torch.arange(seq_length, dtype=torch.long,
-                                device=data.device)
-    position_ids = position_ids.unsqueeze(0).expand_as(data)
-    # We need to clone as the ids will be modifed based on batch index.
-    if reset_position_ids:
-        position_ids = position_ids.clone()
-
-    if reset_position_ids or reset_attention_mask:
-        # Loop through the batches:
-        for b in range(batch_size):
-
-            # Find indecies where EOD token is.
-            eod_index = position_ids[b, data[b] == eod_token]
-            # Detach indecies from positions if going to modify positions.
-            if reset_position_ids:
-                eod_index = eod_index.clone()
-
-            # Loop through EOD indecies:
-            prev_index = 0
-            for j in range(eod_index.size()[0]):
-                i = eod_index[j]
-                # Mask attention loss.
-                if reset_attention_mask:
-                    attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
-                # Reset positions.
-                if reset_position_ids:
-                    position_ids[b, (i + 1):] -= (i + 1 - prev_index)
-                    prev_index = i + 1
-
-    return attention_mask, loss_mask, position_ids
-
-
 def get_batch(data_iterator, args, timers):
     """Generate a batch"""
@@ -126,7 +70,7 @@ def get_batch(data_iterator, args, timers):
     tokens = tokens_[:, :-1].contiguous()
 
     # Get the masks and postition ids.
-    attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
+    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
         tokens,
         args.eod_token,
         args.reset_position_ids,
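For readers following the call-site change, a sketch of how get_batch in pretrain_gpt2.py plausibly fits together after the refactor; only the slicing line and the helper call appear in the hunks above, so the wrapper function, the labels line, and the trailing arguments are labelled assumptions.

# Illustrative sketch only: the wrapper, the labels line, and the trailing
# arguments are assumptions; the slicing and the helper call mirror the hunk.
from megatron.utils import get_ltor_masks_and_position_ids


def get_batch_sketch(tokens_, args):
    # tokens_ is assumed to hold seq_length + 1 token ids per sample.
    labels = tokens_[:, 1:].contiguous()    # assumed shift-by-one targets
    tokens = tokens_[:, :-1].contiguous()   # model inputs, as in the hunk

    # Get the masks and position ids via the relocated helper.
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        tokens,
        args.eod_token,
        args.reset_position_ids,
        args.reset_attention_mask,          # assumed from the new signature
        args.eod_mask_loss)                 # assumed from the new signature

    return tokens, labels, loss_mask, attention_mask, position_ids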