OpenDAS / Megatron-LM

Commit d7509658
authored Jun 29, 2021 by zihanl

change folder name and add dialog training

parent 6f72a285
Showing 16 changed files with 93 additions and 51 deletions
dialogctrl/dialog_dataset.py             +22  -9
dialogctrl/ner/gen_entityctrl_data.py     +0  -0
dialogctrl/ner/ner_demo.py                +0  -0
dialogctrl/ner/run_command.sh             +0  -0
dialogctrl/ner/src/config.py              +0  -0
dialogctrl/ner/src/dataloader.py          +0  -0
dialogctrl/ner/src/metrics.py             +0  -0
dialogctrl/ner/src/model.py               +0  -0
dialogctrl/ner/src/trainer.py             +0  -0
dialogctrl/ner/src/utils.py               +0  -0
dialogctrl/ner/train_ner.py               +0  -0
dialogctrl/utils.py                      +14  -14
megatron/arguments.py                     +4  -1
megatron/tokenizer/tokenizer.py           +5  -4
megatron/training.py                     +42  -19
train_gpt_conv.py                         +6  -4
dialog_ctrl/dialog_dataset.py → dialogctrl/dialog_dataset.py

@@ -6,6 +6,7 @@ import torch
 import numpy as np
 from megatron import get_tokenizer
+from megatron import print_rank_0


 def read_data(tokenizer, data_path, train_module):
     """read and tokenize dialog data"""
@@ -24,10 +25,17 @@ def read_data(tokenizer, data_path, train_module):
             # only take the last three turns in the dialog context
             turns = dialog_context.split(" [SEP] ")
             turns = turns[-3:]
-            context = " [SEP] ".join(turns)
-            input_ids = tokenizer.tokenize(context)
+            # input_ids
+            for idx, turn in enumerate(turns):
+                if idx == 0:
+                    input_ids = tokenizer.tokenize(turn)
+                else:
+                    input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
+            # output_ids
             output_ids = tokenizer.tokenize(response)
             data_list.append({"input_ids": input_ids, "output_ids": output_ids})
         elif train_module == "control":
@@ -40,14 +48,19 @@ def read_data(tokenizer, data_path, train_module):
             turns = dialog_context.split(" [SEP] ")
             last_turn = turns[-1]
+            # input_ids
             if ctrl_code:
-                inputs = last_turn + " [CTRL] " + ctrl_code
+                input_ids = tokenizer.tokenize(last_turn)
+                ctrl_code_list = ctrl_code.split(" [CTRL] ")
+                for code in ctrl_code_list:
+                    input_ids.extend([tokenizer.ctrl_id] + tokenizer.tokenize(code))
             else:
-                inputs = last_turn
-            outputs = ctrl_sent
-            input_ids = tokenizer.tokenize(inputs)
+                input_ids = tokenizer.tokenize(last_turn)
+            # output_ids
+            outputs = ctrl_sent
             output_ids = tokenizer.tokenize(outputs)
             data_list.append({"input_ids": input_ids, "output_ids": output_ids})
         else:
@@ -68,7 +81,7 @@ class ControlDialogDataset(torch.utils.data.Dataset):
     def __init__(self, data, max_seq_len, pad_id, eod_id):
         # need to deal with padding, label masking
         self.data = data
-        self.max_seq_len
+        self.max_seq_len = max_seq_len
         self.pad_id = pad_id
         self.eod_id = eod_id
@@ -79,7 +92,7 @@ class ControlDialogDataset(torch.utils.data.Dataset):
         data_dict = self.data[idx]
         input_ids, output_ids = data_dict["input_ids"], data_dict["output_ids"]
-        assert len(input_ids) < self.max_seq_len, "Set a larger max_seq_len!"
+        assert len(input_ids) < self.max_seq_len, "Set a larger max-seq-len!"
         # length_of_loss_mask == length_of_text - 1
         text = input_ids + [self.pad_id] + output_ids + [self.eod_id]
@@ -118,4 +131,4 @@ def build_train_valid_test_datasets(data_folder, dataset_name, train_module, max
     valid_dataset = ControlDialogDataset(valid_data_list, max_seq_len, tokenizer.pad_id, tokenizer.eod_id)
     test_dataset = ControlDialogDataset(test_data_list, max_seq_len, tokenizer.pad_id, tokenizer.eod_id)
-    return (train_dataset, valid_dataset, test_dataset)
+    return train_dataset, valid_dataset, test_dataset
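The reworked read_data above no longer re-joins turns into one string before tokenizing; it tokenizes turn by turn and splices in the tokenizer's sep_id between pieces (the control branch does the same with ctrl_id for control codes). The following is a minimal, self-contained sketch of that joining logic, not part of the commit: a toy whitespace tokenizer stands in for Megatron's GPT-2 BPE tokenizer, and only .tokenize() and .sep_id are assumed, as in the diff.

# Toy stand-in for the real tokenizer; ids are assigned on the fly instead of BPE.
class ToyTokenizer:
    def __init__(self):
        self.vocab = {"[SEP]": 0, "[CTRL]": 1}
        self.sep_id = self.vocab["[SEP]"]
        self.ctrl_id = self.vocab["[CTRL]"]

    def tokenize(self, text):
        # Add unseen tokens with the next free id.
        return [self.vocab.setdefault(tok, len(self.vocab)) for tok in text.split()]


def encode_context(tokenizer, dialog_context):
    """Keep the last three turns and join them with the [SEP] id, as in the diff."""
    turns = dialog_context.split(" [SEP] ")[-3:]
    input_ids = []
    for idx, turn in enumerate(turns):
        if idx == 0:
            input_ids = tokenizer.tokenize(turn)
        else:
            input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
    return input_ids


tok = ToyTokenizer()
print(encode_context(tok, "hi there [SEP] how are you [SEP] fine thanks [SEP] good to hear"))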
dialog_ctrl/ner/gen_entityctrl_data.py → dialogctrl/ner/gen_entityctrl_data.py  (file moved)
dialog_ctrl/ner/ner_demo.py → dialogctrl/ner/ner_demo.py  (file moved)
dialog_ctrl/ner/run_command.sh → dialogctrl/ner/run_command.sh  (file moved)
dialog_ctrl/ner/src/config.py → dialogctrl/ner/src/config.py  (file moved)
dialog_ctrl/ner/src/dataloader.py → dialogctrl/ner/src/dataloader.py  (file moved)
dialog_ctrl/ner/src/metrics.py → dialogctrl/ner/src/metrics.py  (file moved)
dialog_ctrl/ner/src/model.py → dialogctrl/ner/src/model.py  (file moved)
dialog_ctrl/ner/src/trainer.py → dialogctrl/ner/src/trainer.py  (file moved)
dialog_ctrl/ner/src/utils.py → dialogctrl/ner/src/utils.py  (file moved)
dialog_ctrl/ner/train_ner.py → dialogctrl/ner/train_ner.py  (file moved)

dialog_ctrl/utils.py → dialogctrl/utils.py
@@ -16,20 +16,20 @@ def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
     # reset attentino mask and position ids
     # Loop through the batches:
-    for b in range(micro_batch_size):
-        # Find indecies where EOD token is.
-        eod_index = position_ids[b, data[b] == eod_token_id]
-        eod_index = eod_index.clone()
-        # Loop through EOD indecies:
-        prev_index = 0
-        for j in range(eod_index.size()[0]):
-            i = eod_index[j]
-            # Mask attention loss.
-            attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
-            # Reset positions.
-            position_ids[b, (i + 1):] -= (i + 1 - prev_index)
-            prev_index = i + 1
+    # for b in range(micro_batch_size):
+    #     # Find indecies where EOD token is.
+    #     eod_index = position_ids[b, data[b] == eod_token_id]
+    #     eod_index = eod_index.clone()
+    #     # Loop through EOD indecies:
+    #     prev_index = 0
+    #     for j in range(eod_index.size()[0]):
+    #         i = eod_index[j]
+    #         # Mask attention loss.
+    #         attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
+    #         # Reset positions.
+    #         position_ids[b, (i + 1):] -= (i + 1 - prev_index)
+    #         prev_index = i + 1

     # Convert attention mask to binary:
     attention_mask = (attention_mask < 0.5)
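The hunk above comments out the per-EOD reset of the attention mask and position ids, leaving only the binary conversion; a plausible reading (not stated in the commit) is that each dialog sample is a single sequence, so there is no document packing to undo. Below is a small PyTorch sketch, not part of the commit, showing the mask shape Megatron uses and the kept `< 0.5` conversion with toy sizes.

import torch

micro_batch_size, seq_length = 1, 6

# Lower-triangular (left-to-right) mask in Megatron's [b, 1, s, s] layout.
attention_mask = torch.tril(
    torch.ones(micro_batch_size, 1, seq_length, seq_length))

# With the EOD reset commented out, position ids stay a plain arange per sample.
position_ids = torch.arange(seq_length).unsqueeze(0).repeat(micro_batch_size, 1)

# Kept line from the diff: True marks positions that must NOT be attended to.
attention_mask = (attention_mask < 0.5)
print(attention_mask[0, 0])
print(position_ids)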
megatron/arguments.py

@@ -41,6 +41,7 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_biencoder_args(parser)
     parser = _add_vit_args(parser)
     parser = _add_logging_args(parser)
+    parser = _add_dialog_ctrl_args(parser)

     # Custom arguments.
     if extra_args_provider is not None:
@@ -757,6 +758,8 @@ def _add_vit_args(parser):
 def _add_dialog_ctrl_args(parser):
     group = parser.add_argument_group(title="dialog control")

+    group.add_argument('--run-dialog', action='store_true',
+                       help='run dialog modeling')
     group.add_argument('--train-module', type=str, default="",
                        help='either control module or dialogue model (control or dialog)')
     group.add_argument('--data-folder', type=str, default="",
@@ -765,7 +768,7 @@ def _add_dialog_ctrl_args(parser):
                        help='dataset name (e.g., wizard_of_wikipedia)')
     group.add_argument('--max-seq-len', type=int, default=1024,
                        help='maximum sequence length')
-    group.add_argument('--spec_toks', type=str, default="[SEP],[CTRL],[PAD]",
+    group.add_argument('--spec-toks', type=str, default="[SEP],[CTRL],[PAD]",
                        help='additional special tokens')
     return parser
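For reference, the flags added or renamed above behave like ordinary argparse options; the standalone sketch below (not part of the commit) mirrors the visible definitions of the "dialog control" group so they can be exercised outside Megatron. Only the flags whose definitions appear in the diff are included.

import argparse

# Standalone mirror of the group registered via _add_dialog_ctrl_args(parser).
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="dialog control")
group.add_argument('--run-dialog', action='store_true',
                   help='run dialog modeling')
group.add_argument('--train-module', type=str, default="",
                   help='either control module or dialogue model (control or dialog)')
group.add_argument('--max-seq-len', type=int, default=1024,
                   help='maximum sequence length')
group.add_argument('--spec-toks', type=str, default="[SEP],[CTRL],[PAD]",
                   help='additional special tokens')

# Example invocation: dashes become underscores on the parsed namespace.
args = parser.parse_args(['--run-dialog', '--train-module', 'dialog'])
print(args.run_dialog, args.train_module, args.max_seq_len, args.spec_toks.split(','))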
megatron/tokenizer/tokenizer.py

@@ -272,13 +272,14 @@ class _GPT2BPETokenizer(AbstractTokenizer):
         self.tokenizer = GPT2Tokenizer(vocab_file, merge_file, errors='replace',
                                        special_tokens=special_tokens, max_len=None)
         self.eod_id = self.tokenizer.encoder['<|endoftext|>']

         if len(special_tokens) > 0:
-            if "[PAD]" in special_tokens:
-                self.pad_id = self.tokenizer.encoder['[PAD]']
             if "[SEP]" in special_tokens:
-                self.sep_id = self.tokenizer.encoder['[SEP]']
+                self.sep_id = self.tokenizer.special_tokens['[SEP]']
             if "[CTRL]" in special_tokens:
-                self.ctrl_id = self.tokenizer.encoder['[CTRL]']
+                self.ctrl_id = self.tokenizer.special_tokens['[CTRL]']
+            if "[PAD]" in special_tokens:
+                self.pad_id = self.tokenizer.special_tokens['[PAD]']

     @property
     def vocab_size(self):
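The lookups for [SEP], [CTRL] and [PAD] move from tokenizer.encoder to tokenizer.special_tokens. A toy illustration of why, not part of the commit and using a simplified stand-in rather than the real GPT2Tokenizer: tokens passed via special_tokens= live in a separate table appended after the base BPE vocabulary, so an encoder[...] lookup would not find them.

# Simplified stand-in for the bundled GPT-2 tokenizer's bookkeeping (assumption:
# added specials are kept in .special_tokens, outside the base .encoder vocab).
class MiniGPT2Tokenizer:
    def __init__(self, special_tokens=()):
        self.encoder = {'<|endoftext|>': 50256}     # base BPE vocab (truncated)
        # Added special tokens get fresh ids after the base vocabulary.
        self.special_tokens = {
            tok: 50257 + i for i, tok in enumerate(special_tokens)
        }


tok = MiniGPT2Tokenizer(special_tokens=["[SEP]", "[CTRL]", "[PAD]"])
print(tok.encoder.get('[SEP]'))        # None: the old encoder[...] lookup would fail
print(tok.special_tokens['[SEP]'])     # 50257: the id the new code picks up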
megatron/training.py

@@ -53,7 +53,6 @@ from megatron.schedules import forward_backward_pipelining_with_interleaving
 from megatron.utils import report_memory

-
 def print_datetime(string):
     """Note that this call will sync across all ranks."""
     torch.distributed.barrier()
@@ -325,6 +324,8 @@ def setup_model_and_optimizer(model_provider_func):
         torch.distributed.barrier()
         timers('load-checkpoint').start()
         args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
+        # need to set train_samples to None
+        args.train_samples = None
         torch.distributed.barrier()
         timers('load-checkpoint').stop()
         timers.log(['load-checkpoint'])
@@ -792,9 +793,31 @@ def build_train_valid_test_data_iterators(
         args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
             args.eval_iters * args.global_batch_size

+    if args.run_dialog:
+        args.consumed_train_samples = 0
+        args.consumed_valid_samples = 0
+        args.iteration = 0
+
     # Data loader only on rank 0 of each model parallel group.
     if mpu.get_tensor_model_parallel_rank() == 0:
+        if args.run_dialog:
+            # Build the datasets.
+            train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider()
+            print_rank_0(' > datasets target sizes:')
+            train_size = len(train_ds)
+            valid_size = len(valid_ds)
+            test_size = len(test_ds)
+            print_rank_0('    train:      {}'.format(train_size))
+            print_rank_0('    validation: {}'.format(valid_size))
+            print_rank_0('    test:       {}'.format(test_size))
+            args.train_iters = train_size // args.global_batch_size
+            args.eval_iters = valid_size // args.global_batch_size
+            args.test_iters = test_size // args.global_batch_size
+        else:
-        # Number of train/valid/test samples.
-        if args.train_samples:
-            train_samples = args.train_samples
+            # Number of train/valid/test samples.
+            if args.train_samples:
+                train_samples = args.train_samples
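With --run-dialog, the iteration counts come from the dataset sizes rather than from --train-samples / --train-iters, and the consumed-sample counters and iteration are zeroed so dialog fine-tuning starts fresh even when a pretrained checkpoint is loaded. A worked example of the integer-division bookkeeping, with made-up sizes:

# Made-up dataset sizes, just to illustrate the arithmetic added above.
train_size, valid_size, test_size = 68_000, 3_500, 3_800
global_batch_size = 64

train_iters = train_size // global_batch_size   # 1062
eval_iters = valid_size // global_batch_size    # 54
test_iters = test_size // global_batch_size     # 59
print(train_iters, eval_iters, test_iters)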
train_dialog_gpt.py → train_gpt_conv.py

@@ -9,11 +9,11 @@ from megatron import get_timers
 from megatron import get_tokenizer
 from megatron import mpu
 # from megatron.data.gpt_dataset import build_train_valid_test_datasets
-from dialog_ctrl.dialog_dataset import build_train_valid_test_datasets
+from dialogctrl.dialog_dataset import build_train_valid_test_datasets
 from megatron.model import GPTModel
 from megatron.training import pretrain
 # from megatron.utils import get_ltor_masks_and_position_ids
-from dialog_ctrl.utils import get_ltor_attention_masks_and_position_ids
+from dialogctrl.utils import get_ltor_attention_masks_and_position_ids
 from megatron.utils import average_losses_across_data_parallel_group


 def model_provider(pre_process=True, post_process=True):
@@ -52,7 +52,7 @@ def get_batch(data_iterator):
     loss_mask = data_b['loss_mask'].float()

     # Get the attention_mask and postition ids.
-    attention_masks, position_ids = get_ltor_attention_masks_and_position_ids(tokens, tokenizer.eod_id)
+    attention_mask, position_ids = get_ltor_attention_masks_and_position_ids(tokens, tokenizer.eod_id)

     return tokens, labels, loss_mask, attention_mask, position_ids
@@ -86,7 +86,7 @@ def forward_step(data_iterator, model):
 def train_valid_test_datasets_provider():
-    """Build train, valid, and test datasets for control module"""
+    """Build train, valid, and test datasets for dialog/control module"""
     args = get_args()

     print_rank_0('> building train, validation, and test datasets for %s module ...' % args.train_module)
@@ -99,6 +99,8 @@ def train_valid_test_datasets_provider():
         seed=args.seed)
     print_rank_0("> finished creating datasets for %s module ..." % args.train_module)

+    return train_ds, valid_ds, test_ds
+

 if __name__ == "__main__":
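The renamed entry script now imports from the dialogctrl package, fixes the attention_mask name so get_batch returns what it declares, and has its datasets provider return the three datasets directly, which is what the new --run-dialog branch in training.py consumes. A minimal sketch of that provider contract, not part of the commit, with plain lists standing in for ControlDialogDataset:

# Plain lists stand in for ControlDialogDataset; the only contract assumed here
# is that the provider returns (train_ds, valid_ds, test_ds) and each supports len().
def train_valid_test_datasets_provider():
    train_ds = [{"input_ids": [1, 2], "output_ids": [3]}] * 8
    valid_ds = [{"input_ids": [4, 5], "output_ids": [6]}] * 2
    test_ds = [{"input_ids": [7, 8], "output_ids": [9]}] * 2
    return train_ds, valid_ds, test_ds


train_ds, valid_ds, test_ds = train_valid_test_datasets_provider()
print(len(train_ds), len(valid_ds), len(test_ds))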