Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
6afeb120
Unverified
Commit
6afeb120
authored
Apr 06, 2023
by
Fazzie-Maqianli
Committed by
GitHub
Apr 06, 2023
Browse files
add community example dictionary (#3465)
parent
80eba05b
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
94 additions
and
92 deletions
+94
-92
applications/Chat/examples/community/README.md
applications/Chat/examples/community/README.md
+1
-0
applications/Chat/examples/community/peft/README.md
applications/Chat/examples/community/peft/README.md
+2
-2
applications/Chat/examples/community/peft/easy_dataset.py
applications/Chat/examples/community/peft/easy_dataset.py
+62
-64
applications/Chat/examples/community/peft/easy_models.py
applications/Chat/examples/community/peft/easy_models.py
+6
-7
applications/Chat/examples/community/peft/train_peft_prompts.py
...ations/Chat/examples/community/peft/train_peft_prompts.py
+9
-8
applications/Chat/examples/community/peft/train_peft_sft.py
applications/Chat/examples/community/peft/train_peft_sft.py
+14
-11
No files found.
applications/Chat/examples/community/README.md
0 → 100644
View file @
6afeb120
# Community Examples
applications/Chat/examples/community/
EasyPeftModel
.md
→
applications/Chat/examples/community/
peft/README
.md
View file @
6afeb120
applications/Chat/examples/community/easy_dataset.py
→
applications/Chat/examples/community/
peft/
easy_dataset.py
View file @
6afeb120
import
copy
import
json
from
typing
import
Dict
,
Sequence
import
torch
from
datasets
import
load_dataset
from
torch.utils.data
import
Dataset
from
transformers
import
AutoTokenizer
import
torch
from
tqdm
import
tqdm
import
json
from
tqdm
import
tqdm
import
json
from
transformers
import
AutoTokenizer
IGNORE_INDEX
=
-
100
def
_tokenize_fn
(
strings
:
Sequence
[
str
],
tokenizer
:
AutoTokenizer
,
max_length
:
int
=
512
)
->
Dict
:
def
_tokenize_fn
(
strings
:
Sequence
[
str
],
tokenizer
:
AutoTokenizer
,
max_length
:
int
=
512
)
->
Dict
:
"""Tokenize a list of strings."""
tokenized_list
=
[
tokenizer
(
...
...
@@ -36,15 +34,12 @@ def _tokenize_fn(strings: Sequence[str], tokenizer: AutoTokenizer,max_length :in
)
def
preprocess
(
sources
:
Sequence
[
str
],
targets
:
Sequence
[
str
],
tokenizer
:
AutoTokenizer
,
max_length
:
int
=
512
)
->
Dict
:
def
preprocess
(
sources
:
Sequence
[
str
],
targets
:
Sequence
[
str
],
tokenizer
:
AutoTokenizer
,
max_length
:
int
=
512
)
->
Dict
:
"""Preprocess the data by tokenizing."""
examples
=
[
s
+
t
for
s
,
t
in
zip
(
sources
,
targets
)]
examples_tokenized
,
sources_tokenized
=
[
_tokenize_fn
(
strings
,
tokenizer
,
max_length
)
for
strings
in
(
examples
,
sources
)]
examples_tokenized
,
sources_tokenized
=
[
_tokenize_fn
(
strings
,
tokenizer
,
max_length
)
for
strings
in
(
examples
,
sources
)
]
input_ids
=
examples_tokenized
[
"input_ids"
]
labels
=
copy
.
deepcopy
(
input_ids
)
for
label
,
source_len
in
zip
(
labels
,
sources_tokenized
[
"input_ids_lens"
]):
...
...
@@ -53,21 +48,22 @@ def preprocess(
class
EasySupervisedDataset
(
Dataset
):
def
__init__
(
self
,
data_file
:
str
,
tokenizer
:
AutoTokenizer
,
max_length
:
int
=
512
)
->
None
:
super
(
EasySupervisedDataset
,
self
).
__init__
()
with
open
(
data_file
,
"r"
,
encoding
=
"UTF-8"
)
as
f
:
def
__init__
(
self
,
data_file
:
str
,
tokenizer
:
AutoTokenizer
,
max_length
:
int
=
512
)
->
None
:
super
(
EasySupervisedDataset
,
self
).
__init__
()
with
open
(
data_file
,
"r"
,
encoding
=
"UTF-8"
)
as
f
:
all_lines
=
f
.
readlines
()
#split to source and target ,source the characters before "回答:" including "回答:", target the characters after "回答:"
sources
,
targets
=
[],[]
sources
,
targets
=
[],
[]
for
line
in
all_lines
:
if
"回答:"
in
line
:
sep_index
=
line
.
index
(
"回答:"
)
sources
.
append
(
line
[:
sep_index
+
3
])
targets
.
append
(
line
[
sep_index
+
3
:]
+
tokenizer
.
eos_token
)
sources
.
append
(
line
[:
sep_index
+
3
])
targets
.
append
(
line
[
sep_index
+
3
:]
+
tokenizer
.
eos_token
)
else
:
sources
.
append
(
line
)
targets
.
append
(
""
+
tokenizer
.
eos_token
)
data_dict
=
preprocess
(
sources
,
targets
,
tokenizer
,
max_length
)
targets
.
append
(
""
+
tokenizer
.
eos_token
)
data_dict
=
preprocess
(
sources
,
targets
,
tokenizer
,
max_length
)
self
.
input_ids
=
data_dict
[
"input_ids"
]
self
.
labels
=
data_dict
[
"labels"
]
...
...
@@ -85,21 +81,21 @@ class EasySupervisedDataset(Dataset):
def
__str__
(
self
):
return
f
"LawSupervisedDataset(data_file=
{
self
.
data_file
}
, input_ids_len=
{
len
(
self
.
input_ids
)
}
, labels_len=
{
len
(
self
.
labels
)
}
)"
class
EasyPromptsDataset
(
Dataset
):
def
__init__
(
self
,
data_file
:
str
,
tokenizer
:
AutoTokenizer
,
max_length
:
int
=
96
)
->
None
:
super
(
EasyPromptsDataset
,
self
).
__init__
()
with
open
(
data_file
,
"r"
,
encoding
=
"UTF-8"
)
as
f
:
def
__init__
(
self
,
data_file
:
str
,
tokenizer
:
AutoTokenizer
,
max_length
:
int
=
96
)
->
None
:
super
(
EasyPromptsDataset
,
self
).
__init__
()
with
open
(
data_file
,
"r"
,
encoding
=
"UTF-8"
)
as
f
:
all_lines
=
f
.
readlines
()
all_lines
=
[
line
if
"回答:"
not
in
line
else
line
[:
line
.
index
(
"回答:"
)
+
3
]
for
line
in
all_lines
]
all_lines
=
[
line
if
"回答:"
not
in
line
else
line
[:
line
.
index
(
"回答:"
)
+
3
]
for
line
in
all_lines
]
self
.
prompts
=
[
tokenizer
(
line
,
return_tensors
=
'pt'
,
max_length
=
max_length
,
padding
=
'max_length'
,
tokenizer
(
line
,
return_tensors
=
'pt'
,
max_length
=
max_length
,
padding
=
'max_length'
,
truncation
=
True
)[
'input_ids'
].
to
(
torch
.
cuda
.
current_device
()).
squeeze
(
0
)
for
line
in
tqdm
(
all_lines
)
]
self
.
data_file
=
data_file
def
__len__
(
self
):
return
len
(
self
.
prompts
)
...
...
@@ -114,8 +110,9 @@ class EasyPromptsDataset(Dataset):
class
EasyRewardDataset
(
Dataset
):
def
__init__
(
self
,
train_file
:
str
,
tokenizer
:
AutoTokenizer
,
special_token
=
None
,
max_length
=
512
)
->
None
:
super
(
EasyRewardDataset
,
self
).
__init__
()
def
__init__
(
self
,
train_file
:
str
,
tokenizer
:
AutoTokenizer
,
special_token
=
None
,
max_length
=
512
)
->
None
:
super
(
EasyRewardDataset
,
self
).
__init__
()
self
.
chosen
=
[]
self
.
reject
=
[]
if
special_token
is
None
:
...
...
@@ -124,11 +121,11 @@ class EasyRewardDataset(Dataset):
self
.
end_token
=
special_token
print
(
self
.
end_token
)
#read all lines in the train_file to a list
with
open
(
train_file
,
"r"
,
encoding
=
"UTF-8"
)
as
f
:
with
open
(
train_file
,
"r"
,
encoding
=
"UTF-8"
)
as
f
:
all_lines
=
f
.
readlines
()
for
line
in
tqdm
(
all_lines
):
data
=
json
.
loads
(
line
)
prompt
=
"提问:"
+
data
[
'prompt'
]
+
" 回答:"
prompt
=
"提问:"
+
data
[
'prompt'
]
+
" 回答:"
chosen
=
prompt
+
data
[
'chosen'
]
+
self
.
end_token
chosen_token
=
tokenizer
(
chosen
,
...
...
@@ -167,24 +164,27 @@ class EasyRewardDataset(Dataset):
def
__str__
(
self
):
return
f
"LawRewardDataset(chosen_len=
{
len
(
self
.
chosen
)
}
, reject_len=
{
len
(
self
.
reject
)
}
)"
'''
Easy SFT just accept a text file which can be read line by line. However the datasest will group texts together to max_length so LLM will learn the texts meaning better.
If individual lines are not related, just set is_group_texts to False.
'''
class
EasySFTDataset
(
Dataset
):
def
__init__
(
self
,
data_file
:
str
,
tokenizer
:
AutoTokenizer
,
max_length
=
512
,
is_group_texts
=
True
)
->
None
:
def
__init__
(
self
,
data_file
:
str
,
tokenizer
:
AutoTokenizer
,
max_length
=
512
,
is_group_texts
=
True
)
->
None
:
super
().
__init__
()
#read the data_file line by line
with
open
(
data_file
,
"r"
,
encoding
=
"UTF-8"
)
as
f
:
with
open
(
data_file
,
"r"
,
encoding
=
"UTF-8"
)
as
f
:
#encode the text data line by line and put raw python list input_ids only to raw_input_ids list
raw_input_ids
=
[]
for
line
in
f
:
encoded_ids
=
tokenizer
.
encode
(
line
)
#if the encoded_ids is longer than max_length, then split it into several parts
if
len
(
encoded_ids
)
>
max_length
:
for
i
in
range
(
0
,
len
(
encoded_ids
),
max_length
):
raw_input_ids
.
append
(
encoded_ids
[
i
:
i
+
max_length
])
for
i
in
range
(
0
,
len
(
encoded_ids
),
max_length
):
raw_input_ids
.
append
(
encoded_ids
[
i
:
i
+
max_length
])
else
:
raw_input_ids
.
append
(
encoded_ids
)
...
...
@@ -199,23 +199,26 @@ class EasySFTDataset(Dataset):
#pad the current_input_ids to max_length with tokenizer.pad_token_id
padded_length
=
max_length
-
len
(
current_input_ids
)
current_input_ids
.
extend
([
tokenizer
.
pad_token_id
]
*
padded_length
)
grouped_inpup_ids
.
append
(
torch
.
tensor
(
current_input_ids
,
dtype
=
torch
.
long
))
attention_mask
.
append
(
torch
.
tensor
([
1
]
*
(
max_length
-
padded_length
)
+
[
0
]
*
padded_length
,
dtype
=
torch
.
long
))
grouped_inpup_ids
.
append
(
torch
.
tensor
(
current_input_ids
,
dtype
=
torch
.
long
))
attention_mask
.
append
(
torch
.
tensor
([
1
]
*
(
max_length
-
padded_length
)
+
[
0
]
*
padded_length
,
dtype
=
torch
.
long
))
current_input_ids
=
[]
else
:
current_input_ids
.
extend
(
input_ids
)
if
len
(
current_input_ids
)
>
0
:
padded_length
=
max_length
-
len
(
current_input_ids
)
current_input_ids
.
extend
([
tokenizer
.
pad_token_id
]
*
padded_length
)
grouped_inpup_ids
.
append
(
torch
.
tensor
(
current_input_ids
,
dtype
=
torch
.
long
))
attention_mask
.
append
(
torch
.
tensor
([
1
]
*
(
max_length
-
padded_length
)
+
[
0
]
*
padded_length
,
dtype
=
torch
.
long
))
grouped_inpup_ids
.
append
(
torch
.
tensor
(
current_input_ids
,
dtype
=
torch
.
long
))
attention_mask
.
append
(
torch
.
tensor
([
1
]
*
(
max_length
-
padded_length
)
+
[
0
]
*
padded_length
,
dtype
=
torch
.
long
))
else
:
#just append the raw_input_ids to max_length
for
input_ids
in
raw_input_ids
:
padded_length
=
max_length
-
len
(
input_ids
)
input_ids
.
extend
([
tokenizer
.
pad_token_id
]
*
padded_length
)
attention_mask
.
append
(
torch
.
tensor
([
1
]
*
(
max_length
-
padded_length
)
+
[
0
]
*
padded_length
,
dtype
=
torch
.
long
))
grouped_inpup_ids
.
append
(
torch
.
tensor
(
input_ids
,
dtype
=
torch
.
long
))
attention_mask
.
append
(
torch
.
tensor
([
1
]
*
(
max_length
-
padded_length
)
+
[
0
]
*
padded_length
,
dtype
=
torch
.
long
))
grouped_inpup_ids
.
append
(
torch
.
tensor
(
input_ids
,
dtype
=
torch
.
long
))
self
.
input_ids
=
grouped_inpup_ids
self
.
labels
=
copy
.
deepcopy
(
self
.
input_ids
)
self
.
file_name
=
data_file
...
...
@@ -225,8 +228,8 @@ class EasySFTDataset(Dataset):
return
len
(
self
.
input_ids
)
#get item from dataset
def
__getitem__
(
self
,
idx
):
return
dict
(
input_ids
=
self
.
input_ids
[
idx
],
labels
=
self
.
labels
[
idx
],
attention_mask
=
self
.
attention_mask
[
idx
])
def
__getitem__
(
self
,
idx
):
return
dict
(
input_ids
=
self
.
input_ids
[
idx
],
labels
=
self
.
labels
[
idx
],
attention_mask
=
self
.
attention_mask
[
idx
])
#generate the dataset description to be printed by print in python
def
__repr__
(
self
):
...
...
@@ -235,8 +238,3 @@ class EasySFTDataset(Dataset):
#generate the dataset description to be printed by print in python
def
__str__
(
self
):
return
f
"EasySFTDataset(len=
{
len
(
self
)
}
,
\n
file_name is
{
self
.
file_name
}
)"
\ No newline at end of file
applications/Chat/examples/community/easy_models.py
→
applications/Chat/examples/community/
peft/
easy_models.py
View file @
6afeb120
...
...
@@ -3,12 +3,12 @@ from typing import Optional, Tuple, Union
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
torch.nn.modules
import
Module
from
coati.models.generation
import
generate
from
coati.models.utils
import
log_probs_from_logits
,
masked_mean
from
transformers
import
BloomConfig
,
BloomForCausalLM
from
coati.models.utils
import
log_probs_from_logits
,
masked_mean
from
peft
import
PeftModel
from
torch.nn.modules
import
Module
from
transformers
import
BloomConfig
,
BloomForCausalLM
class
Actor
(
Module
):
"""
...
...
@@ -87,11 +87,10 @@ class BLOOMActor(Actor):
else
:
model
=
BloomForCausalLM
(
BloomConfig
())
if
lora_path
is
not
None
:
model
=
PeftModel
.
from_pretrained
(
model
,
lora_path
)
model
=
PeftModel
.
from_pretrained
(
model
,
lora_path
)
if
checkpoint
:
model
.
gradient_checkpointing_enable
()
super
().
__init__
(
model
)
def
print_trainable_parameters
(
self
):
self
.
get_base_model
().
print_trainable_parameters
()
applications/Chat/examples/community/train_peft_prompts.py
→
applications/Chat/examples/community/
peft/
train_peft_prompts.py
View file @
6afeb120
...
...
@@ -5,21 +5,22 @@ import torch
import
torch.distributed
as
dist
from
coati.dataset
import
DataCollatorForSupervisedDataset
,
PromptDataset
,
SupervisedDataset
from
coati.models.bloom
import
BLOOMRM
,
BLOOMCritic
from
easy_models
import
BLOOMActor
from
coati.models.gpt
import
GPTRM
,
GPTActor
,
GPTCritic
from
coati.models.llama
import
LlamaActor
,
LlamaCritic
,
LlamaRM
from
coati.models.opt
import
OPTRM
,
OPTActor
,
OPTCritic
from
coati.trainer
import
PPOTrainer
from
coati.trainer.strategies
import
ColossalAIStrategy
,
DDPStrategy
,
NaiveStrategy
from
coati.utils
import
prepare_llama_tokenizer_and_embedding
from
easy_dataset
import
EasyPromptsDataset
,
EasySupervisedDataset
from
easy_models
import
BLOOMActor
from
peft
import
PeftModel
from
torch.optim
import
Adam
from
torch.utils.data
import
DataLoader
from
torch.utils.data.distributed
import
DistributedSampler
from
transformers
import
AutoTokenizer
,
BloomTokenizerFast
,
GPT2Tokenizer
,
LlamaTokenizer
from
colossalai.nn.optimizer
import
HybridAdam
from
peft
import
PeftModel
from
easy_dataset
import
EasyPromptsDataset
,
EasySupervisedDataset
def
main
(
args
):
# configure strategy
...
...
@@ -41,7 +42,7 @@ def main(args):
if
args
.
model
==
'bloom'
:
# initial_model = BLOOMActor(pretrained=args.pretrain)
print
(
'Using peft lora to load Bloom model as inital_model'
)
initial_model
=
BLOOMActor
(
pretrained
=
args
.
pretrain
,
lora_path
=
args
.
sft_lora_path
)
initial_model
=
BLOOMActor
(
pretrained
=
args
.
pretrain
,
lora_path
=
args
.
sft_lora_path
)
print
(
'Using peft lora to load Bloom model as initial_model (Done)'
)
else
:
raise
ValueError
(
f
'Unsupported actor model "
{
args
.
model
}
"'
)
...
...
@@ -54,7 +55,7 @@ def main(args):
if
rm_model_name
==
'gpt2'
:
reward_model
=
GPTRM
(
pretrained
=
args
.
rm_pretrain
)
elif
rm_model_name
==
'bloom'
:
print
(
"load bloom reward model "
,
args
.
rm_pretrain
)
print
(
"load bloom reward model "
,
args
.
rm_pretrain
)
reward_model
=
BLOOMRM
(
pretrained
=
args
.
rm_pretrain
)
elif
rm_model_name
==
'opt'
:
reward_model
=
OPTRM
(
pretrained
=
args
.
rm_pretrain
)
...
...
@@ -75,7 +76,7 @@ def main(args):
if
args
.
model
==
'bloom'
:
# actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
print
(
'Using peft lora to load Bloom model as Actor'
)
actor
=
BLOOMActor
(
pretrained
=
args
.
pretrain
,
lora_path
=
args
.
sft_lora_path
)
actor
=
BLOOMActor
(
pretrained
=
args
.
pretrain
,
lora_path
=
args
.
sft_lora_path
)
print
(
'Using peft lora to load Bloom model as Actor (Done)'
)
else
:
raise
ValueError
(
f
'Unsupported actor model "
{
args
.
model
}
"'
)
...
...
@@ -83,7 +84,7 @@ def main(args):
if
rm_model_name
==
'gpt2'
:
critic
=
GPTCritic
(
pretrained
=
args
.
rm_pretrain
,
lora_rank
=
args
.
lora_rank
,
use_action_mask
=
True
)
elif
rm_model_name
==
'bloom'
:
print
(
"load bloom critic "
,
args
.
rm_pretrain
,
" lora_rank "
,
args
.
lora_rank
,
" use_action_mask "
,
True
)
print
(
"load bloom critic "
,
args
.
rm_pretrain
,
" lora_rank "
,
args
.
lora_rank
,
" use_action_mask "
,
True
)
critic
=
BLOOMCritic
(
pretrained
=
args
.
rm_pretrain
,
lora_rank
=
args
.
lora_rank
,
use_action_mask
=
True
)
print
(
"load bloom critic (Done) "
)
elif
rm_model_name
==
'opt'
:
...
...
@@ -130,7 +131,7 @@ def main(args):
data_collator
=
DataCollatorForSupervisedDataset
(
tokenizer
=
tokenizer
)
prompt_dataset
=
EasyPromptsDataset
(
args
.
prompt_path
,
tokenizer
)
prompt_dataset
=
EasyPromptsDataset
(
args
.
prompt_path
,
tokenizer
)
if
dist
.
is_initialized
()
and
dist
.
get_world_size
()
>
1
:
prompt_sampler
=
DistributedSampler
(
prompt_dataset
,
shuffle
=
True
,
seed
=
42
,
drop_last
=
True
)
else
:
...
...
applications/Chat/examples/community/train_peft_sft.py
→
applications/Chat/examples/community/
peft/
train_peft_sft.py
View file @
6afeb120
...
...
@@ -14,19 +14,19 @@ from coati.trainer import SFTTrainer
from
coati.trainer.strategies
import
ColossalAIStrategy
,
DDPStrategy
,
NaiveStrategy
from
coati.utils
import
prepare_llama_tokenizer_and_embedding
from
datasets
import
load_dataset
from
easy_dataset
import
EasyDataset
from
peft
import
LoraConfig
,
PeftModel
,
TaskType
,
get_peft_model
from
torch.optim
import
Adam
from
torch.utils.data
import
DataLoader
from
torch.utils.data.dataloader
import
default_collate
from
torch.utils.data.distributed
import
DistributedSampler
from
transformers
import
AutoTokenizer
,
BloomTokenizerFast
,
AutoModelForCausalLM
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
,
BloomTokenizerFast
from
transformers.models.gpt2.tokenization_gpt2
import
GPT2Tokenizer
from
colossalai.logging
import
get_dist_logger
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.tensor
import
ColoParameter
from
torch.utils.data.dataloader
import
default_collate
from
peft
import
LoraConfig
,
TaskType
,
get_peft_model
,
PeftModel
from
easy_dataset
import
EasyDataset
def
train
(
args
):
# configure strategy
...
...
@@ -48,17 +48,20 @@ def train(args):
#if the args.save_path exists and args.save_path+'/adapter_config.json' exists, we'll load the adapter_config.json
if
os
.
path
.
exists
(
args
.
save_path
)
and
os
.
path
.
exists
(
args
.
save_path
+
'/adapter_config.json'
)
\
and
os
.
path
.
exists
(
args
.
save_path
+
'/adapter_model.bin'
):
print
(
"loading from saved peft model "
,
args
.
save_path
)
print
(
"loading from saved peft model "
,
args
.
save_path
)
model
=
PeftModel
.
from_pretrained
(
model
,
args
.
save_path
)
else
:
#we'll use peft lora library to do the lora
lora_rank
=
args
.
lora_rank
if
args
.
lora_rank
>
0
else
32
#config lora with rank of lora_rank
lora_config
=
LoraConfig
(
task_type
=
TaskType
.
CAUSAL_LM
,
inference_mode
=
False
,
r
=
lora_rank
,
lora_alpha
=
32
,
lora_dropout
=
0.1
)
lora_config
=
LoraConfig
(
task_type
=
TaskType
.
CAUSAL_LM
,
inference_mode
=
False
,
r
=
lora_rank
,
lora_alpha
=
32
,
lora_dropout
=
0.1
)
model
=
get_peft_model
(
model
,
lora_config
)
model
.
print_trainable_parameters
()
# configure tokenizer
if
args
.
model
==
'gpt2'
:
tokenizer
=
GPT2Tokenizer
.
from_pretrained
(
'gpt2'
)
...
...
@@ -103,12 +106,12 @@ def train(args):
logger
.
set_level
(
'WARNING'
)
# configure dataset
law_dataset
=
EasyDataset
(
args
.
dataset
,
tokenizer
=
tokenizer
,
is_group_texts
=
not
args
.
is_short_text
)
law_dataset
=
EasyDataset
(
args
.
dataset
,
tokenizer
=
tokenizer
,
is_group_texts
=
not
args
.
is_short_text
)
train_dataset
=
law_dataset
print
(
train_dataset
)
eval_dataset
=
None
if
args
.
eval_dataset
is
not
None
:
eval_dataset
=
EasyDataset
(
args
.
eval_dataset
,
tokenizer
=
tokenizer
,
is_group_texts
=
not
args
.
is_short_text
)
eval_dataset
=
EasyDataset
(
args
.
eval_dataset
,
tokenizer
=
tokenizer
,
is_group_texts
=
not
args
.
is_short_text
)
data_collator
=
default_collate
if
dist
.
is_initialized
()
and
dist
.
get_world_size
()
>
1
:
train_sampler
=
DistributedSampler
(
train_dataset
,
...
...
@@ -181,7 +184,7 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--log_interval'
,
type
=
int
,
default
=
100
,
help
=
"how many steps to log"
)
parser
.
add_argument
(
'--lr'
,
type
=
float
,
default
=
5e-6
)
parser
.
add_argument
(
'--accimulation_steps'
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
'--enable_peft_lora'
,
action
=
'store_true'
,
default
=
False
)
parser
.
add_argument
(
"--is_short_text"
,
action
=
'store_true'
,
default
=
False
)
parser
.
add_argument
(
'--enable_peft_lora'
,
action
=
'store_true'
,
default
=
False
)
parser
.
add_argument
(
"--is_short_text"
,
action
=
'store_true'
,
default
=
False
)
args
=
parser
.
parse_args
()
train
(
args
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment