OpenDAS / ColossalAI / Commits / 4fd4bd9d

Unverified commit 4fd4bd9d, authored Mar 23, 2023 by Fazzie-Maqianli, committed by GitHub on Mar 23, 2023.

[chatgpt] support instruct training (#3216)
Parent: cd142fbe
Showing 9 changed files with 313 additions and 39 deletions (+313 / -39).
applications/ChatGPT/chatgpt/dataset/__init__.py        +2    -2
applications/ChatGPT/chatgpt/dataset/sft_dataset.py     +120  -2
applications/ChatGPT/chatgpt/dataset/utils.py           +15   -0
applications/ChatGPT/chatgpt/models/llama/__init__.py   +2    -1
applications/ChatGPT/chatgpt/models/llama/llama_lm.py   +38   -0
applications/ChatGPT/chatgpt/trainer/sft.py             +24   -26
applications/ChatGPT/chatgpt/utils/__init__.py          +3    -0
applications/ChatGPT/chatgpt/utils/tokenizer_utils.py   +74   -0
applications/ChatGPT/examples/train_sft.py              +35   -8
applications/ChatGPT/chatgpt/dataset/__init__.py

  from .reward_dataset import RmStaticDataset, HhRlhfDataset
  from .utils import is_rank_0
- from .sft_dataset import SFTDataset
+ from .sft_dataset import SFTDataset, AlpacaDataset, AlpacaDataCollator

- __all__ = ['RmStaticDataset', 'HhRlhfDataset', 'is_rank_0', 'SFTDataset']
+ __all__ = ['RmStaticDataset', 'HhRlhfDataset', 'is_rank_0', 'SFTDataset', 'AlpacaDataset', 'AlpacaDataCollator']
applications/ChatGPT/chatgpt/dataset/sft_dataset.py

- from typing import Callable
+ # Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import copy
+ from dataclasses import dataclass, field
+ from typing import Callable, Dict, Sequence

  import random
  from torch.utils.data import Dataset
  import torch.distributed as dist
  from tqdm import tqdm
  import torch

- from .utils import is_rank_0
+ from .utils import is_rank_0, jload
+ import transformers
+ from colossalai.logging import get_dist_logger
+
+ logger = get_dist_logger()
+
+ IGNORE_INDEX = -100
+ PROMPT_DICT = {
+     "prompt_input": (
+         "Below is an instruction that describes a task, paired with an input that provides further context. "
+         "Write a response that appropriately completes the request.\n\n"
+         "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
+     ),
+     "prompt_no_input": (
+         "Below is an instruction that describes a task. "
+         "Write a response that appropriately completes the request.\n\n"
+         "### Instruction:\n{instruction}\n\n### Response:"
+     ),
+ }

  class SFTDataset(Dataset):
      """
...
@@ -38,3 +72,87 @@ class SFTDataset(Dataset):

      def __getitem__(self, idx):
          return self.prompts[idx]

+
+ def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
+     """Tokenize a list of strings."""
+     tokenized_list = [
+         tokenizer(
+             text,
+             return_tensors="pt",
+             padding="longest",
+             max_length=tokenizer.model_max_length,
+             truncation=True,
+         ) for text in strings
+     ]
+     input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
+     input_ids_lens = labels_lens = [
+         tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
+     ]
+     return dict(
+         input_ids=input_ids,
+         labels=labels,
+         input_ids_lens=input_ids_lens,
+         labels_lens=labels_lens,
+     )
+
+
+ def preprocess(
+     sources: Sequence[str],
+     targets: Sequence[str],
+     tokenizer: transformers.PreTrainedTokenizer,
+ ) -> Dict:
+     """Preprocess the data by tokenizing."""
+     examples = [s + t for s, t in zip(sources, targets)]
+     examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
+     input_ids = examples_tokenized["input_ids"]
+     labels = copy.deepcopy(input_ids)
+     for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
+         label[:source_len] = IGNORE_INDEX
+     return dict(input_ids=input_ids, labels=labels)
+
+
+ class AlpacaDataset(Dataset):
+     """Dataset for supervised fine-tuning."""
+
+     def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
+         super(AlpacaDataset, self).__init__()
+         logger.info("Loading data...")
+         list_data_dict = jload(data_path)
+
+         logger.info("Formatting inputs...")
+         prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
+         sources = [
+             prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
+             for example in list_data_dict
+         ]
+         targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
+
+         logger.info("Tokenizing inputs... This may take some time...")
+         data_dict = preprocess(sources, targets, tokenizer)
+
+         self.input_ids = data_dict["input_ids"]
+         self.labels = data_dict["labels"]
+
+     def __len__(self):
+         return len(self.input_ids)
+
+     def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+         return dict(input_ids=self.input_ids[i], labels=self.labels[i])
+
+
+ @dataclass
+ class AlpacaDataCollator(object):
+     """Collate examples for supervised fine-tuning."""
+
+     tokenizer: transformers.PreTrainedTokenizer
+
+     def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+         input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
+         input_ids = torch.nn.utils.rnn.pad_sequence(
+             input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
+         labels = torch.nn.utils.rnn.pad_sequence(
+             labels, batch_first=True, padding_value=IGNORE_INDEX)
+         return dict(
+             input_ids=input_ids,
+             labels=labels,
+             attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+         )
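For orientation, the following is a minimal sketch (not part of the commit) of what AlpacaDataset ultimately produces for one record: the Alpaca prompt template is filled from the instruction, the response plus EOS token is appended, and the prompt span of the labels is overwritten with IGNORE_INDEX so that only response tokens contribute to the loss. The "gpt2" tokenizer is an arbitrary stand-in for the real one.

    import copy
    import transformers

    IGNORE_INDEX = -100
    PROMPT_NO_INPUT = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        "### Instruction:\n{instruction}\n\n### Response:"
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")   # stand-in tokenizer
    tokenizer.pad_token = tokenizer.eos_token

    example = {"instruction": "Name three primary colors.", "input": "", "output": "Red, yellow and blue."}
    source = PROMPT_NO_INPUT.format_map(example)                      # prompt text
    target = f"{example['output']}{tokenizer.eos_token}"              # response + EOS

    # Tokenize prompt+response together, then mask the prompt prefix in the labels,
    # mirroring what preprocess() does with the input_ids_lens from _tokenize_fn.
    input_ids = tokenizer(source + target, return_tensors="pt").input_ids[0]
    source_len = tokenizer(source, return_tensors="pt").input_ids.ne(tokenizer.pad_token_id).sum().item()
    labels = copy.deepcopy(input_ids)
    labels[:source_len] = IGNORE_INDEX
    print(input_ids.shape, labels[:source_len])   # every prompt position is -100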
applications/ChatGPT/chatgpt/dataset/utils.py

+ import io
+ import json
+
  import torch.distributed as dist


  def is_rank_0() -> bool:
      return not dist.is_initialized() or dist.get_rank() == 0
+
+
+ def _make_r_io_base(f, mode: str):
+     if not isinstance(f, io.IOBase):
+         f = open(f, mode=mode)
+     return f
+
+
+ def jload(f, mode="r"):
+     """Load a .json file into a dictionary."""
+     f = _make_r_io_base(f, mode)
+     jdict = json.load(f)
+     f.close()
+     return jdict
\ No newline at end of file
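A small usage sketch (the data path is a placeholder): jload accepts either a path or an open file object and returns the parsed JSON, which for Alpaca-style data is a list of instruction records.

    from chatgpt.dataset.utils import is_rank_0, jload

    records = jload("/path/to/alpaca_data.json")   # hypothetical path
    if is_rank_0():
        print(f"loaded {len(records)} records; first keys: {list(records[0].keys())}")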
applications/ChatGPT/chatgpt/models/llama/__init__.py

  from .llama_actor import LlamaActor
  from .llama_critic import LlamaCritic
  from .llama_rm import LlamaRM
+ from .llama_lm import LlamaLM

- __all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM']
+ __all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM', 'LlamaLM']
applications/ChatGPT/chatgpt/models/llama/llama_lm.py (new file, 0 → 100644)

from typing import Optional

from transformers import LlamaConfig, LlamaForCausalLM

from ..base import LM


class LlamaLM(LM):
    """
    Llama language model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (LlamaConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[LlamaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = LlamaForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = LlamaForCausalLM(config)
        else:
            model = LlamaForCausalLM(LlamaConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model, lora_rank, lora_train_bias)
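A construction sketch (the checkpoint path is a placeholder): LlamaLM can wrap a pretrained HuggingFace-format LLaMA checkpoint, a user-supplied LlamaConfig, or a default config, optionally with gradient checkpointing and LoRA.

    from chatgpt.models.llama import LlamaLM

    # Wrap a locally converted HuggingFace-format LLaMA checkpoint (placeholder path).
    model = LlamaLM(pretrained="/path/to/llama-7b-hf",
                    checkpoint=True,           # enable gradient checkpointing to save memory
                    lora_rank=8,               # 0 disables LoRA
                    lora_train_bias='none').cuda()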
applications/ChatGPT/chatgpt/trainer/sft.py

...
@@ -2,7 +2,6 @@ from abc import ABC
  from typing import Optional

  import loralib as lora
  import torch
- from chatgpt.dataset import SFTDataset
  from chatgpt.models.loss import GPTLMLoss
  from torch.optim import Adam, Optimizer
  from torch.utils.data import DataLoader
...
@@ -22,8 +21,8 @@ class SFTTrainer(ABC):
          model (torch.nn.Module): the model to train
          strategy (Strategy): the strategy to use for training
          optim(Optimizer): the optimizer to use for training
-         train_dataset (SFTDataset or SFTDistributedDataset): the dataset to use for training
-         eval_dataset (SFTDataset or SFTDistributedDataset): the dataset to use for evaluation
+         train_dataloader: the dataloader to use for training
+         eval_dataloader: the dataloader to use for evaluation
          batch_size (int, defaults to 1): the batch size while training
          max_epochs (int, defaults to 2): the number of epochs to train
          optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
...
@@ -34,8 +33,8 @@ class SFTTrainer(ABC):
                   model,
                   strategy: Strategy,
                   optim: Optimizer,
-                  train_dataset: SFTDataset,
-                  eval_dataset: SFTDataset,
+                  train_dataloader: DataLoader,
+                  eval_dataloader: DataLoader = None,
                   sampler: Optional[DistributedSampler] = None,
                   batch_size: int = 1,
                   max_epochs: int = 2,
...
@@ -43,13 +42,10 @@ class SFTTrainer(ABC):
          super().__init__()
          self.strategy = strategy
          self.epochs = max_epochs
-         self.train_dataset = train_dataset
-         self.eval_dataset = eval_dataset
-         self.sampler = sampler
-         self.train_dataloader = DataLoader(self.train_dataset, shuffle=(sampler is None),
-                                            sampler=sampler, batch_size=batch_size)
-         self.eval_dataloader = DataLoader(self.eval_dataset, batch_size=batch_size)
+         self.train_dataloader = train_dataloader
+         self.eval_dataloader = eval_dataloader

          self.model = strategy.setup_model(model)
          if "DDP" in str(self.strategy):
...
@@ -79,23 +75,25 @@ class SFTTrainer(ABC):
                      logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}')

              # eval
-             self.model.eval()
-             with torch.no_grad():
-                 loss_sum = 0
-                 num_seen = 0
-                 for batch in self.eval_dataloader:
-                     prompt_ids = batch["input_ids"]
-                     p_mask = batch["attention_mask"]
-                     prompt_ids = prompt_ids.squeeze(1).cuda()
-                     p_mask = p_mask.squeeze(1).cuda()
-                     prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
-                     loss = self.loss_fn(prompt_logits, prompt_ids)
-                     loss_sum += loss.item()
-                     num_seen += prompt_ids.size(0)
-             loss_mean = loss_sum / num_seen
-             if dist.get_rank() == 0:
-                 logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}')
+             if self.eval_dataloader is not None:
+                 self.model.eval()
+                 with torch.no_grad():
+                     loss_sum = 0
+                     num_seen = 0
+                     for batch in self.eval_dataloader:
+                         prompt_ids = batch["input_ids"]
+                         p_mask = batch["attention_mask"]
+                         prompt_ids = prompt_ids.squeeze(1).cuda()
+                         p_mask = p_mask.squeeze(1).cuda()
+                         prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
+                         loss = self.loss_fn(prompt_logits, prompt_ids)
+                         loss_sum += loss.item()
+                         num_seen += prompt_ids.size(0)
+                 loss_mean = loss_sum / num_seen
+                 if dist.get_rank() == 0:
+                     logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}')

              epoch_bar.update()
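The net effect of this change is that SFTTrainer no longer builds DataLoaders from datasets: the caller constructs them (plus an optional DistributedSampler) and passes them in, and the evaluation pass runs only when eval_dataloader is not None. A sketch of the new calling convention, assuming model, strategy, optim, train_dataloader, sampler and args are already set up as in examples/train_sft.py below:

    trainer = SFTTrainer(model=model,
                         strategy=strategy,
                         optim=optim,
                         train_dataloader=train_dataloader,
                         eval_dataloader=None,      # None skips the evaluation pass entirely
                         sampler=sampler,
                         batch_size=args.batch_size,
                         max_epochs=args.max_epochs)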
applications/ChatGPT/chatgpt/utils/__init__.py (new file, 0 → 100644)

from .tokenizer_utils import smart_tokenizer_and_embedding_resize, prepare_llama_tokenizer_and_embedding

__all__ = ['smart_tokenizer_and_embedding_resize', 'prepare_llama_tokenizer_and_embedding']
\ No newline at end of file
applications/ChatGPT/chatgpt/utils/tokenizer_utils.py (new file, 0 → 100644)

# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict

import transformers

DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "</s>"


def prepare_llama_tokenizer_and_embedding(
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
):
    """prepare llama tokenizer and embedding."""

    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )

    tokenizer.add_special_tokens({
        "eos_token": DEFAULT_EOS_TOKEN,
        "bos_token": DEFAULT_BOS_TOKEN,
        "unk_token": DEFAULT_UNK_TOKEN,
    })

    return tokenizer


def smart_tokenizer_and_embedding_resize(
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    if tokenizer.pad_token is None:
        num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
        model.resize_token_embeddings(len(tokenizer))

        if num_new_tokens > 0:
            input_embeddings = model.get_input_embeddings().weight.data
            output_embeddings = model.get_output_embeddings().weight.data

            input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
            output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

            input_embeddings[-num_new_tokens:] = input_embeddings_avg
            output_embeddings[-num_new_tokens:] = output_embeddings_avg
\ No newline at end of file
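A usage sketch (the checkpoint path is a placeholder): give a LLaMA tokenizer that lacks a pad token the "[PAD]" token and resize the model's embedding matrices to match, with the new rows initialized to the mean of the existing embeddings.

    from transformers import AutoTokenizer, LlamaForCausalLM

    from chatgpt.utils import prepare_llama_tokenizer_and_embedding

    model = LlamaForCausalLM.from_pretrained("/path/to/llama-7b-hf")          # placeholder path
    tokenizer = AutoTokenizer.from_pretrained("/path/to/llama-7b-hf", use_fast=False)
    tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model)
    print(tokenizer.pad_token, len(tokenizer), model.get_input_embeddings().weight.shape)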
applications/ChatGPT/examples/train_sft.py

...
@@ -4,15 +4,18 @@ import loralib as lora
  import torch
  import torch.distributed as dist
  from torch.utils.data.distributed import DistributedSampler
- from chatgpt.dataset import SFTDataset
+ from chatgpt.dataset import SFTDataset, AlpacaDataset, AlpacaDataCollator
  from chatgpt.models.base import RewardModel
  from chatgpt.models.bloom import BLOOMLM
  from chatgpt.models.gpt import GPTLM
  from chatgpt.models.opt import OPTLM
+ from chatgpt.models.llama import LlamaLM
  from chatgpt.trainer import SFTTrainer
  from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
+ from chatgpt.utils import prepare_llama_tokenizer_and_embedding
  from datasets import load_dataset
  from torch.optim import Adam
  from torch.utils.data import DataLoader
  from transformers import AutoTokenizer, BloomTokenizerFast
  from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
...
@@ -41,6 +44,8 @@ def train(args):
          model = OPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
      elif args.model == 'gpt2':
          model = GPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
+     elif args.model == 'llama':
+         model = LlamaLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
      else:
          raise ValueError(f'Unsupported model "{args.model}"')
...
@@ -53,9 +58,19 @@ def train(args):
          tokenizer.pad_token = tokenizer.eos_token
      elif args.model == 'opt':
          tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+     elif args.model == 'llama':
+         tokenizer = AutoTokenizer.from_pretrained(
+             args.pretrain,
+             padding_side="right",
+             use_fast=False,
+         )
      else:
          raise ValueError(f'Unsupported model "{args.model}"')
-     tokenizer.pad_token = tokenizer.eos_token
+
+     if args.model == 'llama':
+         tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model)
+     else:
+         tokenizer.pad_token = tokenizer.eos_token

      max_len = 512
...
@@ -67,11 +82,19 @@ def train(args):
      logger = get_dist_logger()

-     train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train')
-     eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test')
-     train_dataset = SFTDataset(train_data, tokenizer, max_len)
-     eval_dataset = SFTDataset(eval_data, tokenizer, max_len)
+     # configure dataset
+     if args.dataset == 'yizhongw/self_instruct':
+         train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train')
+         eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test')
+         train_dataset = SFTDataset(train_data, tokenizer, max_len)
+         eval_dataset = SFTDataset(eval_data, tokenizer, max_len)
+     elif 'alpaca' in args.dataset:
+         train_dataset = AlpacaDataset(tokenizer=tokenizer, data_path=args.dataset)
+         eval_dataset = None
+         data_collator = AlpacaDataCollator(tokenizer=tokenizer)

      if dist.is_initialized() and dist.get_world_size() > 1:
          sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True)
...
@@ -79,11 +102,15 @@ def train(args):
      else:
          sampler = None

+     train_dataloader = DataLoader(train_dataset, shuffle=(sampler is None),
+                                   sampler=sampler, batch_size=args.batch_size)
+     if eval_dataset is not None:
+         eval_dataloader = DataLoader(eval_dataset, batch_size=args.batch_size)
+
      trainer = SFTTrainer(model=model,
                           strategy=strategy,
                           optim=optim,
-                          train_dataset=train_dataset,
-                          eval_dataset=eval_dataset,
+                          train_dataloader=train_dataloader,
+                          eval_dataloader=eval_dataloader,
                           sampler=sampler,
                           batch_size=args.batch_size,
                           max_epochs=args.max_epochs)
...
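The diff above does not show the AlpacaDataCollator being attached to the DataLoader. Below is a sketch of how the Alpaca pieces would typically compose into padded batches; the data path is a placeholder, tokenizer is assumed to be already prepared, and wiring the collator in via collate_fn is an assumption rather than something this commit shows.

    from torch.utils.data import DataLoader

    from chatgpt.dataset import AlpacaDataset, AlpacaDataCollator

    train_dataset = AlpacaDataset(tokenizer=tokenizer, data_path="/path/to/alpaca_data.json")
    data_collator = AlpacaDataCollator(tokenizer=tokenizer)
    train_dataloader = DataLoader(train_dataset,
                                  shuffle=True,
                                  batch_size=4,
                                  collate_fn=data_collator)   # pads input_ids with pad_token_id, labels with -100
    batch = next(iter(train_dataloader))
    print(batch["input_ids"].shape, batch["labels"].shape, batch["attention_mask"].shape)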