OpenDAS / Megatron-LM

Commit d7509658, authored Jun 29, 2021 by zihanl
change folder name and add dialog training
Parent: 6f72a285

Showing 16 changed files with 93 additions and 51 deletions (+93 -51)
dialogctrl/dialog_dataset.py            +22   -9
dialogctrl/ner/gen_entityctrl_data.py    +0   -0
dialogctrl/ner/ner_demo.py               +0   -0
dialogctrl/ner/run_command.sh            +0   -0
dialogctrl/ner/src/config.py             +0   -0
dialogctrl/ner/src/dataloader.py         +0   -0
dialogctrl/ner/src/metrics.py            +0   -0
dialogctrl/ner/src/model.py              +0   -0
dialogctrl/ner/src/trainer.py            +0   -0
dialogctrl/ner/src/utils.py              +0   -0
dialogctrl/ner/train_ner.py              +0   -0
dialogctrl/utils.py                     +14  -14
megatron/arguments.py                    +4   -1
megatron/tokenizer/tokenizer.py          +5   -4
megatron/training.py                    +42  -19
train_gpt_conv.py                        +6   -4
dialog_ctrl/dialog_dataset.py → dialogctrl/dialog_dataset.py

@@ -6,6 +6,7 @@ import torch
 import numpy as np
 from megatron import get_tokenizer
+from megatron import print_rank_0

@@ -24,10 +25,17 @@ def read_data(tokenizer, data_path, train_module):
             # only take the last three turns in the dialog context
             turns = dialog_context.split(" [SEP] ")
             turns = turns[-3:]
-            context = " [SEP] ".join(turns)
-            input_ids = tokenizer.tokenize(context)
+            # input_ids
+            for idx, turn in enumerate(turns):
+                if idx == 0:
+                    input_ids = tokenizer.tokenize(turn)
+                else:
+                    input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
+            # output_ids
             output_ids = tokenizer.tokenize(response)
             data_list.append({"input_ids": input_ids, "output_ids": output_ids})
     elif train_module == "control":

@@ -40,14 +48,19 @@ def read_data(tokenizer, data_path, train_module):
             turns = dialog_context.split(" [SEP] ")
             last_turn = turns[-1]
+            # input_ids
             if ctrl_code:
-                inputs = last_turn + " [CTRL] " + ctrl_code
+                input_ids = tokenizer.tokenize(last_turn)
+                ctrl_code_list = ctrl_code.split(" [CTRL] ")
+                for code in ctrl_code_list:
+                    input_ids.extend([tokenizer.ctrl_id] + tokenizer.tokenize(code))
             else:
-                inputs = last_turn
-            outputs = ctrl_sent
-            input_ids = tokenizer.tokenize(inputs)
+                input_ids = tokenizer.tokenize(last_turn)
+            # output_ids
+            outputs = ctrl_sent
             output_ids = tokenizer.tokenize(outputs)
             data_list.append({"input_ids": input_ids, "output_ids": output_ids})
     else:

@@ -68,7 +81,7 @@ class ControlDialogDataset(torch.utils.data.Dataset):
     def __init__(self, data, max_seq_len, pad_id, eod_id):
         # need to deal with padding, label masking
         self.data = data
-        self.max_seq_len
+        self.max_seq_len = max_seq_len
         self.pad_id = pad_id
         self.eod_id = eod_id

@@ -79,7 +92,7 @@ class ControlDialogDataset(torch.utils.data.Dataset):
         data_dict = self.data[idx]
         input_ids, output_ids = data_dict["input_ids"], data_dict["output_ids"]
-        assert len(input_ids) < self.max_seq_len, "Set a larger max_seq_len!"
+        assert len(input_ids) < self.max_seq_len, "Set a larger max-seq-len!"
         # length_of_loss_mask == length_of_text - 1
         text = input_ids + [self.pad_id] + output_ids + [self.eod_id]

@@ -118,4 +131,4 @@ def build_train_valid_test_datasets(data_folder, dataset_name, train_module, max
     valid_dataset = ControlDialogDataset(valid_data_list, max_seq_len, tokenizer.pad_id, tokenizer.eod_id)
     test_dataset = ControlDialogDataset(test_data_list, max_seq_len, tokenizer.pad_id, tokenizer.eod_id)
-    return (train_dataset, valid_dataset, test_dataset)
+    return train_dataset, valid_dataset, test_dataset
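The new read_data logic tokenizes each of the last three context turns and joins them with the tokenizer's sep_id (the control branch likewise appends each [CTRL] code after the last turn via ctrl_id), and ControlDialogDataset then packs every example as input ids, a pad id, output ids, and a trailing eod id. Below is a minimal, self-contained sketch of that construction with a toy whitespace tokenizer; ToyTokenizer and build_dialog_sample are illustrative names only, while the commit itself uses Megatron's GPT-2 BPE tokenizer.

# Illustrative sketch only: a toy stand-in for the tokenization and packing
# added in dialogctrl/dialog_dataset.py. ToyTokenizer and build_dialog_sample
# are hypothetical names; the commit relies on Megatron's GPT-2 BPE tokenizer.
class ToyTokenizer:
    def __init__(self):
        self.vocab = {}
        # Reserve a few ids for the special tokens used by the dataset code.
        self.sep_id, self.ctrl_id, self.pad_id, self.eod_id = 0, 1, 2, 3

    def tokenize(self, text):
        ids = []
        for word in text.split():
            if word not in self.vocab:
                self.vocab[word] = len(self.vocab) + 4
            ids.append(self.vocab[word])
        return ids


def build_dialog_sample(tokenizer, dialog_context, response):
    # Keep only the last three turns, as in the diff above.
    turns = dialog_context.split(" [SEP] ")[-3:]
    input_ids = []
    for idx, turn in enumerate(turns):
        if idx == 0:
            input_ids = tokenizer.tokenize(turn)
        else:
            input_ids.extend([tokenizer.sep_id] + tokenizer.tokenize(turn))
    output_ids = tokenizer.tokenize(response)
    # ControlDialogDataset packs each example as: input [PAD] output [EOD]
    text = input_ids + [tokenizer.pad_id] + output_ids + [tokenizer.eod_id]
    return {"input_ids": input_ids, "output_ids": output_ids, "text": text}


if __name__ == "__main__":
    tok = ToyTokenizer()
    sample = build_dialog_sample(
        tok,
        "hi there [SEP] how are you [SEP] fine thanks [SEP] what are you up to",
        "just reading a book",
    )
    print(sample["text"])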
File moves (no content changes):

dialog_ctrl/ner/gen_entityctrl_data.py → dialogctrl/ner/gen_entityctrl_data.py
dialog_ctrl/ner/ner_demo.py            → dialogctrl/ner/ner_demo.py
dialog_ctrl/ner/run_command.sh         → dialogctrl/ner/run_command.sh
dialog_ctrl/ner/src/config.py          → dialogctrl/ner/src/config.py
dialog_ctrl/ner/src/dataloader.py      → dialogctrl/ner/src/dataloader.py
dialog_ctrl/ner/src/metrics.py         → dialogctrl/ner/src/metrics.py
dialog_ctrl/ner/src/model.py           → dialogctrl/ner/src/model.py
dialog_ctrl/ner/src/trainer.py         → dialogctrl/ner/src/trainer.py
dialog_ctrl/ner/src/utils.py           → dialogctrl/ner/src/utils.py
dialog_ctrl/ner/train_ner.py           → dialogctrl/ner/train_ner.py
dialog_ctrl/utils.py → dialogctrl/utils.py

@@ -16,20 +16,20 @@ def get_ltor_attention_masks_and_position_ids(data, eod_token_id):
     # reset attentino mask and position ids
     # Loop through the batches:
-    for b in range(micro_batch_size):
-        # Find indecies where EOD token is.
-        eod_index = position_ids[b, data[b] == eod_token_id]
-        eod_index = eod_index.clone()
-        # Loop through EOD indecies:
-        prev_index = 0
-        for j in range(eod_index.size()[0]):
-            i = eod_index[j]
-            # Mask attention loss.
-            attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
-            # Reset positions.
-            position_ids[b, (i + 1):] -= (i + 1 - prev_index)
-            prev_index = i + 1
+    # for b in range(micro_batch_size):
+    #     # Find indecies where EOD token is.
+    #     eod_index = position_ids[b, data[b] == eod_token_id]
+    #     eod_index = eod_index.clone()
+    #     # Loop through EOD indecies:
+    #     prev_index = 0
+    #     for j in range(eod_index.size()[0]):
+    #         i = eod_index[j]
+    #         # Mask attention loss.
+    #         attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
+    #         # Reset positions.
+    #         position_ids[b, (i + 1):] -= (i + 1 - prev_index)
+    #         prev_index = i + 1

     # Convert attention mask to binary:
     attention_mask = (attention_mask < 0.5)
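For reference, the block commented out above is the per-sample reset that keeps packed documents from attending across an EOD boundary and restarts position ids after it. A self-contained sketch of that behaviour follows; the function name and tensor shapes are assumptions based on Megatron's usual conventions, not part of this commit.

import torch

# Sketch of the document-boundary reset that this commit disables.
# Assumed shapes: data is [batch, seq_len], attention_mask is
# [batch, 1, seq_len, seq_len], position_ids is [batch, seq_len].
def reset_mask_at_eod(data, attention_mask, position_ids, eod_token_id):
    micro_batch_size = data.size(0)
    for b in range(micro_batch_size):
        # Indices where the EOD token appears in this sample.
        eod_index = position_ids[b, data[b] == eod_token_id].clone()
        prev_index = 0
        for j in range(eod_index.size(0)):
            i = eod_index[j]
            # Tokens after an EOD may not attend to tokens before it.
            attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
            # Restart position ids after each document boundary.
            position_ids[b, (i + 1):] -= (i + 1 - prev_index)
            prev_index = i + 1
    return attention_mask, position_ids


if __name__ == "__main__":
    data = torch.tensor([[5, 6, 9, 7, 8]])          # 9 marks the EOD between two documents
    seq_len = data.size(1)
    attention_mask = torch.tril(torch.ones(1, 1, seq_len, seq_len))
    position_ids = torch.arange(seq_len).unsqueeze(0).clone()
    reset_mask_at_eod(data, attention_mask, position_ids, eod_token_id=9)
    print(position_ids)   # positions restart after the EOD: [[0, 1, 2, 0, 1]]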
megatron/arguments.py

@@ -41,6 +41,7 @@ def parse_args(extra_args_provider=None, defaults={},
     parser = _add_biencoder_args(parser)
     parser = _add_vit_args(parser)
     parser = _add_logging_args(parser)
+    parser = _add_dialog_ctrl_args(parser)

     # Custom arguments.
     if extra_args_provider is not None:

@@ -757,6 +758,8 @@ def _add_vit_args(parser):
 def _add_dialog_ctrl_args(parser):
     group = parser.add_argument_group(title="dialog control")

+    group.add_argument('--run-dialog', action='store_true',
+                       help='run dialog modeling')
     group.add_argument('--train-module', type=str, default="",
                        help='either control module or dialogue model (control or dialog)')
     group.add_argument('--data-folder', type=str, default="",

@@ -765,7 +768,7 @@ def _add_dialog_ctrl_args(parser):
                        help='dataset name (e.g., wizard_of_wikipedia)')
     group.add_argument('--max-seq-len', type=int, default=1024,
                        help='maximum sequence length')
-    group.add_argument('--spec_toks', type=str, default="[SEP],[CTRL],[PAD]",
+    group.add_argument('--spec-toks', type=str, default="[SEP],[CTRL],[PAD]",
                        help='additional special tokens')
     return parser
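The new "dialog control" group is wired into parse_args just before the custom-arguments hook, and --spec_toks is renamed to --spec-toks so it matches the hyphenated style of the other flags. For a quick sanity check of how the options resolve, here is a standalone argparse sketch of the same group; the --dataset-name flag and the help texts marked "assumed" are inferred from the surrounding context rather than shown verbatim in the hunks above.

import argparse

# Standalone sketch of the "dialog control" argument group added by this commit.
# --dataset-name and the "assumed" help strings are inferences, not diff content.
parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="dialog control")
group.add_argument('--run-dialog', action='store_true',
                   help='run dialog modeling')
group.add_argument('--train-module', type=str, default="",
                   help='either control module or dialogue model (control or dialog)')
group.add_argument('--data-folder', type=str, default="",
                   help='folder holding the dialog data (assumed)')
group.add_argument('--dataset-name', type=str, default="",
                   help='dataset name (e.g., wizard_of_wikipedia)')
group.add_argument('--max-seq-len', type=int, default=1024,
                   help='maximum sequence length')
group.add_argument('--spec-toks', type=str, default="[SEP],[CTRL],[PAD]",
                   help='additional special tokens')

args = parser.parse_args(['--run-dialog',
                          '--train-module', 'dialog',
                          '--dataset-name', 'wizard_of_wikipedia'])

# argparse exposes hyphenated flags with underscores on the namespace.
print(args.run_dialog, args.train_module, args.max_seq_len)   # True dialog 1024

# A comma split like this is presumably how the spec-toks string becomes the
# special-token list handed to the tokenizer.
special_tokens = args.spec_toks.split(',')
print(special_tokens)   # ['[SEP]', '[CTRL]', '[PAD]']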
megatron/tokenizer/tokenizer.py

@@ -272,13 +272,14 @@ class _GPT2BPETokenizer(AbstractTokenizer):
         self.tokenizer = GPT2Tokenizer(vocab_file, merge_file, errors='replace',
                                        special_tokens=special_tokens, max_len=None)
         self.eod_id = self.tokenizer.encoder['<|endoftext|>']
         if len(special_tokens) > 0:
-            if "[PAD]" in special_tokens:
-                self.pad_id = self.tokenizer.encoder['[PAD]']
             if "[SEP]" in special_tokens:
-                self.sep_id = self.tokenizer.encoder['[SEP]']
+                self.sep_id = self.tokenizer.special_tokens['[SEP]']
             if "[CTRL]" in special_tokens:
-                self.ctrl_id = self.tokenizer.encoder['[CTRL]']
+                self.ctrl_id = self.tokenizer.special_tokens['[CTRL]']
+            if "[PAD]" in special_tokens:
+                self.pad_id = self.tokenizer.special_tokens['[PAD]']

     @property
     def vocab_size(self):
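The lookup change matters because tokens registered through the special_tokens argument are not part of the base BPE vocabulary: Megatron's bundled GPT2Tokenizer appears to assign them ids after the base vocab and keep them in a separate special_tokens mapping, so encoder['[SEP]'] would raise a KeyError while the built-in <|endoftext|> still lives in encoder. A toy illustration of that split (not the real tokenizer class):

# Toy illustration, not the real GPT2Tokenizer: tokens passed in via
# special_tokens get ids appended after the base vocabulary and live in a
# separate mapping, so looking them up in the original encoder fails.
class ToyBPETokenizer:
    def __init__(self, base_vocab, special_tokens=()):
        self.encoder = {tok: i for i, tok in enumerate(base_vocab)}
        self.special_tokens = {tok: len(self.encoder) + i
                               for i, tok in enumerate(special_tokens)}

tok = ToyBPETokenizer(["hello", "world", "<|endoftext|>"],
                      special_tokens=["[SEP]", "[CTRL]", "[PAD]"])

assert "<|endoftext|>" in tok.encoder   # built-in token: the old lookup still works
assert "[SEP]" not in tok.encoder       # added token: encoder lookup would KeyError
print(tok.special_tokens["[SEP]"])      # new-style lookup used by this commit -> 3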
megatron/training.py

@@ -53,7 +53,6 @@ from megatron.schedules import forward_backward_pipelining_with_interleaving
 from megatron.utils import report_memory
-
 def print_datetime(string):
     """Note that this call will sync across all ranks."""
     torch.distributed.barrier()

@@ -325,6 +324,8 @@ def setup_model_and_optimizer(model_provider_func):
         torch.distributed.barrier()
         timers('load-checkpoint').start()
         args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
+        # need to set train_samples to None
+        args.train_samples = None
         torch.distributed.barrier()
         timers('load-checkpoint').stop()
         timers.log(['load-checkpoint'])

@@ -792,28 +793,50 @@ def build_train_valid_test_data_iterators(
         args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
                                       args.eval_iters * args.global_batch_size
+    if args.run_dialog:
+        args.consumed_train_samples = 0
+        args.consumed_valid_samples = 0
+        args.iteration = 0

     # Data loader only on rank 0 of each model parallel group.
     if mpu.get_tensor_model_parallel_rank() == 0:
-        # Number of train/valid/test samples.
-        if args.train_samples:
-            train_samples = args.train_samples
-        else:
-            train_samples = args.train_iters * args.global_batch_size
-        eval_iters = (args.train_iters // args.eval_interval + 1) * \
-                     args.eval_iters
-        test_iters = args.eval_iters
-        train_val_test_num_samples = [train_samples,
-                                      eval_iters * args.global_batch_size,
-                                      test_iters * args.global_batch_size]
-        print_rank_0(' > datasets target sizes (minimum size):')
-        print_rank_0('    train: {}'.format(train_val_test_num_samples[0]))
-        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
-        print_rank_0('    test: {}'.format(train_val_test_num_samples[2]))
-        # Build the datasets.
-        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
-            train_val_test_num_samples)
+        if args.run_dialog:
+            # Build the datasets.
+            train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider()
+            print_rank_0(' > datasets target sizes:')
+            train_size = len(train_ds)
+            valid_size = len(valid_ds)
+            test_size = len(test_ds)
+            print_rank_0('    train: {}'.format(train_size))
+            print_rank_0('    validation: {}'.format(valid_size))
+            print_rank_0('    test: {}'.format(test_size))
+            args.train_iters = train_size // args.global_batch_size
+            args.eval_iters = valid_size // args.global_batch_size
+            args.test_iters = test_size // args.global_batch_size
+        else:
+            # Number of train/valid/test samples.
+            if args.train_samples:
+                train_samples = args.train_samples
+            else:
+                train_samples = args.train_iters * args.global_batch_size
+            eval_iters = (args.train_iters // args.eval_interval + 1) * \
+                         args.eval_iters
+            test_iters = args.eval_iters
+            train_val_test_num_samples = [train_samples,
+                                          eval_iters * args.global_batch_size,
+                                          test_iters * args.global_batch_size]
+            print_rank_0(' > datasets target sizes (minimum size):')
+            print_rank_0('    train: {}'.format(train_val_test_num_samples[0]))
+            print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
+            print_rank_0('    test: {}'.format(train_val_test_num_samples[2]))
+            # Build the datasets.
+            train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
+                train_val_test_num_samples)

         # Build dataloders.
         train_dataloader = build_pretraining_data_loader(
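In the new --run-dialog branch the iteration counts are derived from the dialog dataset sizes by integer division with the global batch size, instead of from --train-iters or --train-samples. A small worked example of that arithmetic, with made-up sizes:

# Worked example of the iteration counts computed in the --run-dialog branch.
# The dataset sizes and batch size below are made up for illustration.
train_size, valid_size, test_size = 100_000, 8_000, 8_000
global_batch_size = 64

train_iters = train_size // global_batch_size   # 1562 full global batches
eval_iters = valid_size // global_batch_size    # 125
test_iters = test_size // global_batch_size     # 125

print(train_iters, eval_iters, test_iters)
# Any remainder smaller than a full global batch is dropped by the integer
# division, so at most global_batch_size - 1 samples per split go unused.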
train_dialog_gpt.py → train_gpt_conv.py

@@ -9,11 +9,11 @@ from megatron import get_timers
 from megatron import get_tokenizer
 from megatron import mpu
 # from megatron.data.gpt_dataset import build_train_valid_test_datasets
-from dialog_ctrl.dialog_dataset import build_train_valid_test_datasets
+from dialogctrl.dialog_dataset import build_train_valid_test_datasets
 from megatron.model import GPTModel
 from megatron.training import pretrain
 # from megatron.utils import get_ltor_masks_and_position_ids
-from dialog_ctrl.utils import get_ltor_attention_masks_and_position_ids
+from dialogctrl.utils import get_ltor_attention_masks_and_position_ids
 from megatron.utils import average_losses_across_data_parallel_group

 def model_provider(pre_process=True, post_process=True):

@@ -52,7 +52,7 @@ def get_batch(data_iterator):
     loss_mask = data_b['loss_mask'].float()

     # Get the attention_mask and postition ids.
-    attention_masks, position_ids = get_ltor_attention_masks_and_position_ids(tokens, tokenizer.eod_id)
+    attention_mask, position_ids = get_ltor_attention_masks_and_position_ids(tokens, tokenizer.eod_id)

     return tokens, labels, loss_mask, attention_mask, position_ids

@@ -86,7 +86,7 @@ def forward_step(data_iterator, model):
 def train_valid_test_datasets_provider():
-    """Build train, valid, and test datasets for control module"""
+    """Build train, valid, and test datasets for dialog/control module"""
     args = get_args()
     print_rank_0('> building train, validation, and test datasets for %s module ...' % args.train_module)

@@ -99,6 +99,8 @@ def train_valid_test_datasets_provider():
         seed=args.seed)
     print_rank_0("> finished creating datasets for %s module ..." % args.train_module)
+
+    return train_ds, valid_ds, test_ds

 if __name__ == "__main__":