Commit 4fd4bd9d in OpenDAS / ColossalAI

[chatgpt] support instuct training (#3216)

Unverified commit, authored Mar 23, 2023 by Fazzie-Maqianli, committed by GitHub on Mar 23, 2023.
Parent: cd142fbe

Showing 9 changed files with 313 additions and 39 deletions (+313 / -39):
applications/ChatGPT/chatgpt/dataset/__init__.py        +2    -2
applications/ChatGPT/chatgpt/dataset/sft_dataset.py     +120  -2
applications/ChatGPT/chatgpt/dataset/utils.py           +15   -0
applications/ChatGPT/chatgpt/models/llama/__init__.py   +2    -1
applications/ChatGPT/chatgpt/models/llama/llama_lm.py   +38   -0
applications/ChatGPT/chatgpt/trainer/sft.py             +24   -26
applications/ChatGPT/chatgpt/utils/__init__.py          +3    -0
applications/ChatGPT/chatgpt/utils/tokenizer_utils.py   +74   -0
applications/ChatGPT/examples/train_sft.py              +35   -8
applications/ChatGPT/chatgpt/dataset/__init__.py (view file @ 4fd4bd9d)

 from .reward_dataset import RmStaticDataset, HhRlhfDataset
 from .utils import is_rank_0
-from .sft_dataset import SFTDataset
+from .sft_dataset import SFTDataset, AlpacaDataset, AlpacaDataCollator

-__all__ = ['RmStaticDataset', 'HhRlhfDataset', 'is_rank_0', 'SFTDataset']
+__all__ = ['RmStaticDataset', 'HhRlhfDataset', 'is_rank_0', 'SFTDataset', 'AlpacaDataset', 'AlpacaDataCollator']
applications/ChatGPT/chatgpt/dataset/sft_dataset.py (view file @ 4fd4bd9d)

+# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+from dataclasses import dataclass, field
-from typing import Callable
+from typing import Callable, Dict, Sequence

 import random
 from torch.utils.data import Dataset
 import torch.distributed as dist
 from tqdm import tqdm

 import torch
-from .utils import is_rank_0
+from .utils import is_rank_0, jload
+
+import transformers
+from colossalai.logging import get_dist_logger
+
+logger = get_dist_logger()
+
+IGNORE_INDEX = -100
+PROMPT_DICT = {
+    "prompt_input": (
+        "Below is an instruction that describes a task, paired with an input that provides further context. "
+        "Write a response that appropriately completes the request.\n\n"
+        "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
+    ),
+    "prompt_no_input": (
+        "Below is an instruction that describes a task. "
+        "Write a response that appropriately completes the request.\n\n"
+        "### Instruction:\n{instruction}\n\n### Response:"
+    ),
+}


 class SFTDataset(Dataset):
     """
     ...

@@ -38,3 +72,87 @@ class SFTDataset(Dataset):

     def __getitem__(self, idx):
         return self.prompts[idx]
+
+
+def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
+    """Tokenize a list of strings."""
+    tokenized_list = [
+        tokenizer(
+            text,
+            return_tensors="pt",
+            padding="longest",
+            max_length=tokenizer.model_max_length,
+            truncation=True,
+        ) for text in strings
+    ]
+    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
+    input_ids_lens = labels_lens = [
+        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
+    ]
+    return dict(
+        input_ids=input_ids,
+        labels=labels,
+        input_ids_lens=input_ids_lens,
+        labels_lens=labels_lens,
+    )
+
+
+def preprocess(
+    sources: Sequence[str],
+    targets: Sequence[str],
+    tokenizer: transformers.PreTrainedTokenizer,
+) -> Dict:
+    """Preprocess the data by tokenizing."""
+    examples = [s + t for s, t in zip(sources, targets)]
+    examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
+    input_ids = examples_tokenized["input_ids"]
+    labels = copy.deepcopy(input_ids)
+    for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
+        label[:source_len] = IGNORE_INDEX
+    return dict(input_ids=input_ids, labels=labels)
+
+
+class AlpacaDataset(Dataset):
+    """Dataset for supervised fine-tuning."""
+
+    def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
+        super(AlpacaDataset, self).__init__()
+        logger.info("Loading data...")
+        list_data_dict = jload(data_path)
+
+        logger.info("Formatting inputs...")
+        prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
+        sources = [
+            prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
+            for example in list_data_dict
+        ]
+        targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
+
+        logger.info("Tokenizing inputs... This may take some time...")
+        data_dict = preprocess(sources, targets, tokenizer)
+
+        self.input_ids = data_dict["input_ids"]
+        self.labels = data_dict["labels"]
+
+    def __len__(self):
+        return len(self.input_ids)
+
+    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+        return dict(input_ids=self.input_ids[i], labels=self.labels[i])
+
+
+@dataclass
+class AlpacaDataCollator(object):
+    """Collate examples for supervised fine-tuning."""
+
+    tokenizer: transformers.PreTrainedTokenizer
+
+    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
+        input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
+        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids,
+                                                    batch_first=True,
+                                                    padding_value=self.tokenizer.pad_token_id)
+        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
+        return dict(
+            input_ids=input_ids,
+            labels=labels,
+            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+        )
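For reference, a minimal usage sketch of the new Alpaca pieces (not part of the commit): it assumes a local alpaca_data.json in the Stanford Alpaca instruction/input/output format and uses the GPT-2 tokenizer as a stand-in.

# Hedged usage sketch: ./alpaca_data.json is a placeholder path, "gpt2" a stand-in tokenizer.
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from chatgpt.dataset import AlpacaDataset, AlpacaDataCollator

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token          # AlpacaDataCollator pads with pad_token_id

dataset = AlpacaDataset(data_path="./alpaca_data.json", tokenizer=tokenizer)
loader = DataLoader(dataset,
                    batch_size=4,
                    shuffle=True,
                    collate_fn=AlpacaDataCollator(tokenizer=tokenizer))

batch = next(iter(loader))
# input_ids/labels are padded to the longest sequence in the batch; prompt tokens in
# labels are masked with IGNORE_INDEX (-100), so loss is only computed on the responses.
print(batch["input_ids"].shape, batch["labels"].shape, batch["attention_mask"].shape)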
applications/ChatGPT/chatgpt/dataset/utils.py (view file @ 4fd4bd9d)

+import io
+import json
+
 import torch.distributed as dist


 def is_rank_0() -> bool:
     return not dist.is_initialized() or dist.get_rank() == 0
+
+
+def _make_r_io_base(f, mode: str):
+    if not isinstance(f, io.IOBase):
+        f = open(f, mode=mode)
+    return f
+
+
+def jload(f, mode="r"):
+    """Load a .json file into a dictionary."""
+    f = _make_r_io_base(f, mode)
+    jdict = json.load(f)
+    f.close()
+    return jdict
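A quick note on the new helper (sketch, not from the commit): _make_r_io_base lets jload accept either a filesystem path or an already-open file object.

# Minimal sketch; the JSON path is a placeholder.
import io

from chatgpt.dataset.utils import jload

data = jload("./alpaca_data.json")        # a path string is opened internally
data = jload(io.StringIO('{"a": 1}'))     # an io object is used as-is
print(type(data))                         # dict (or list, for a JSON array like the Alpaca file)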
applications/ChatGPT/chatgpt/models/llama/__init__.py (view file @ 4fd4bd9d)

 from .llama_actor import LlamaActor
 from .llama_critic import LlamaCritic
 from .llama_rm import LlamaRM
+from .llama_lm import LlamaLM

-__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM']
+__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM', 'LlamaLM']
applications/ChatGPT/chatgpt/models/llama/llama_lm.py (new file, 0 → 100644; view file @ 4fd4bd9d)

from typing import Optional

from transformers import LlamaConfig, LlamaForCausalLM

from ..base import LM


class LlamaLM(LM):
    """
    Llama language model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (LlamaConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.
        lora_rank (int): LoRA rank.
        lora_train_bias (str): LoRA bias training mode.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 config: Optional[LlamaConfig] = None,
                 checkpoint: bool = False,
                 lora_rank: int = 0,
                 lora_train_bias: str = 'none') -> None:
        if pretrained is not None:
            model = LlamaForCausalLM.from_pretrained(pretrained)
        elif config is not None:
            model = LlamaForCausalLM(config)
        else:
            model = LlamaForCausalLM(LlamaConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model, lora_rank, lora_train_bias)
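For reference, a small construction sketch (not part of the commit) showing the three initialization paths; the pretrained path is a placeholder and the tiny config is only for illustration.

# Sketch of the three ways LlamaLM can be built (paths and sizes are placeholders).
from transformers import LlamaConfig

from chatgpt.models.llama import LlamaLM

# 1. From pretrained weights, with LoRA rank 8 handled by the LM base class.
model = LlamaLM(pretrained="/path/to/llama-7b", lora_rank=8)

# 2. From an explicit config, with gradient checkpointing to save activation memory.
tiny_cfg = LlamaConfig(hidden_size=256, intermediate_size=688,
                       num_hidden_layers=2, num_attention_heads=4)
model = LlamaLM(config=tiny_cfg, checkpoint=True)

# 3. With no arguments a default LlamaConfig() is used (7B-sized, so mainly for illustration).
model = LlamaLM()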
applications/ChatGPT/chatgpt/trainer/sft.py (view file @ 4fd4bd9d)

@@ -2,7 +2,6 @@ from abc import ABC
 from typing import Optional

 import loralib as lora
 import torch
-from chatgpt.dataset import SFTDataset
 from chatgpt.models.loss import GPTLMLoss
 from torch.optim import Adam, Optimizer
 from torch.utils.data import DataLoader
@@ -22,8 +21,8 @@ class SFTTrainer(ABC):
         model (torch.nn.Module): the model to train
         strategy (Strategy): the strategy to use for training
         optim(Optimizer): the optimizer to use for training
-        train_dataset (SFTDataset or SFTDistributedDataset): the dataset to use for training
-        eval_dataset (SFTDataset or SFTDistributedDataset): the dataset to use for evaluation
+        train_dataloader: the dataloader to use for training
+        eval_dataloader: the dataloader to use for evaluation
         batch_size (int, defaults to 1): the batch size while training
         max_epochs (int, defaults to 2): the number of epochs to train
         optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
@@ -34,8 +33,8 @@ class SFTTrainer(ABC):
                  model,
                  strategy: Strategy,
                  optim: Optimizer,
-                 train_dataset: SFTDataset,
-                 eval_dataset: SFTDataset,
+                 train_dataloader: DataLoader,
+                 eval_dataloader: DataLoader = None,
                  sampler: Optional[DistributedSampler] = None,
                  batch_size: int = 1,
                  max_epochs: int = 2,
@@ -43,13 +42,10 @@ class SFTTrainer(ABC):
         super().__init__()
         self.strategy = strategy
         self.epochs = max_epochs
-        self.train_dataset = train_dataset
-        self.eval_dataset = eval_dataset
         self.sampler = sampler
-        self.train_dataloader = DataLoader(self.train_dataset, shuffle=(sampler is None),
-                                           sampler=sampler, batch_size=batch_size)
-        self.eval_dataloader = DataLoader(self.eval_dataset, batch_size=batch_size)
+        self.train_dataloader = train_dataloader
+        self.eval_dataloader = eval_dataloader
         self.model = strategy.setup_model(model)
         if "DDP" in str(self.strategy):
@@ -79,23 +75,25 @@ class SFTTrainer(ABC):
                     logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}')

             # eval
-            self.model.eval()
-            with torch.no_grad():
-                loss_sum = 0
-                num_seen = 0
-                for batch in self.eval_dataloader:
-                    prompt_ids = batch["input_ids"]
-                    p_mask = batch["attention_mask"]
-                    prompt_ids = prompt_ids.squeeze(1).cuda()
-                    p_mask = p_mask.squeeze(1).cuda()
-                    prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
-                    loss = self.loss_fn(prompt_logits, prompt_ids)
-                    loss_sum += loss.item()
-                    num_seen += prompt_ids.size(0)
-
-                loss_mean = loss_sum / num_seen
-                if dist.get_rank() == 0:
-                    logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}')
+            if self.eval_dataloader is not None:
+                self.model.eval()
+                with torch.no_grad():
+                    loss_sum = 0
+                    num_seen = 0
+                    for batch in self.eval_dataloader:
+                        prompt_ids = batch["input_ids"]
+                        p_mask = batch["attention_mask"]
+                        prompt_ids = prompt_ids.squeeze(1).cuda()
+                        p_mask = p_mask.squeeze(1).cuda()
+                        prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
+                        loss = self.loss_fn(prompt_logits, prompt_ids)
+                        loss_sum += loss.item()
+                        num_seen += prompt_ids.size(0)
+
+                    loss_mean = loss_sum / num_seen
+                    if dist.get_rank() == 0:
+                        logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}')

             epoch_bar.update()
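The constructor change means dataset construction and sampling now live in the caller. A hedged sketch of the new calling convention follows (model choice, strategy, and hyperparameters are illustrative, not from this hunk).

# Sketch: callers now build the dataloaders and hand them to SFTTrainer, instead of datasets.
from torch.optim import Adam
from torch.utils.data import DataLoader

from chatgpt.models.gpt import GPTLM
from chatgpt.trainer import SFTTrainer
from chatgpt.trainer.strategies import NaiveStrategy

strategy = NaiveStrategy()
model = GPTLM(pretrained='gpt2').cuda()
optim = Adam(model.parameters(), lr=5e-5)

train_dataset = ...   # e.g. an SFTDataset or AlpacaDataset built as in the dataset files above
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=4)

trainer = SFTTrainer(model=model,
                     strategy=strategy,
                     optim=optim,
                     train_dataloader=train_dataloader,
                     eval_dataloader=None,   # with None, the eval block above is skipped entirely
                     batch_size=4,
                     max_epochs=2)
# trainer.fit(...)    # see applications/ChatGPT/examples/train_sft.py for the full training call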
applications/ChatGPT/chatgpt/utils/__init__.py (new file, 0 → 100644; view file @ 4fd4bd9d)

from .tokenizer_utils import smart_tokenizer_and_embedding_resize, prepare_llama_tokenizer_and_embedding

__all__ = ['smart_tokenizer_and_embedding_resize', 'prepare_llama_tokenizer_and_embedding']
applications/ChatGPT/chatgpt/utils/tokenizer_utils.py (new file, 0 → 100644; view file @ 4fd4bd9d)

# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict

import transformers

DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "</s>"


def prepare_llama_tokenizer_and_embedding(
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
):
    """prepare llama tokenizer and embedding.
    """

    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )

    tokenizer.add_special_tokens({
        "eos_token": DEFAULT_EOS_TOKEN,
        "bos_token": DEFAULT_BOS_TOKEN,
        "unk_token": DEFAULT_UNK_TOKEN,
    })

    return tokenizer


def smart_tokenizer_and_embedding_resize(
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
    special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    if tokenizer.pad_token is None:
        num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
        model.resize_token_embeddings(len(tokenizer))

        if num_new_tokens > 0:
            input_embeddings = model.get_input_embeddings().weight.data
            output_embeddings = model.get_output_embeddings().weight.data

            input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
            output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

            input_embeddings[-num_new_tokens:] = input_embeddings_avg
            output_embeddings[-num_new_tokens:] = output_embeddings_avg
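A hedged usage sketch for the new helpers (not part of the commit; the checkpoint path is a placeholder). The idea: if the LLaMA tokenizer has no pad token, add "[PAD]", resize the model's embedding matrix, and initialize the new rows with the mean of the existing embeddings.

# Sketch: give a LLaMA tokenizer a pad token and resize the model embeddings to match
# (mirrors the call added to examples/train_sft.py below; path is a placeholder).
from transformers import AutoTokenizer, LlamaForCausalLM

from chatgpt.utils import prepare_llama_tokenizer_and_embedding

model = LlamaForCausalLM.from_pretrained("/path/to/llama-7b")
tokenizer = AutoTokenizer.from_pretrained("/path/to/llama-7b", padding_side="right", use_fast=False)

tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model)
# "[PAD]" is now a real token; the embedding matrix gained one row initialized to the
# average of the pre-existing embeddings.
print(tokenizer.pad_token, len(tokenizer), model.get_input_embeddings().weight.shape)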
applications/ChatGPT/examples/train_sft.py (view file @ 4fd4bd9d)

@@ -4,15 +4,18 @@ import loralib as lora
 import torch
 import torch.distributed as dist
 from torch.utils.data.distributed import DistributedSampler
-from chatgpt.dataset import SFTDataset
+from chatgpt.dataset import SFTDataset, AlpacaDataset, AlpacaDataCollator
 from chatgpt.models.base import RewardModel
 from chatgpt.models.bloom import BLOOMLM
 from chatgpt.models.gpt import GPTLM
 from chatgpt.models.opt import OPTLM
+from chatgpt.models.llama import LlamaLM
 from chatgpt.trainer import SFTTrainer
 from chatgpt.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
+from chatgpt.utils import prepare_llama_tokenizer_and_embedding
 from datasets import load_dataset
 from torch.optim import Adam
+from torch.utils.data import DataLoader
 from transformers import AutoTokenizer, BloomTokenizerFast
 from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
@@ -41,6 +44,8 @@ def train(args):
         model = OPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
     elif args.model == 'gpt2':
         model = GPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
+    elif args.model == 'llama':
+        model = LlamaLM(pretrained=args.pretrain, lora_rank=args.lora_rank).cuda()
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
@@ -53,9 +58,19 @@ def train(args):
         tokenizer.pad_token = tokenizer.eos_token
     elif args.model == 'opt':
         tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
+    elif args.model == 'llama':
+        tokenizer = AutoTokenizer.from_pretrained(
+            args.pretrain,
+            padding_side="right",
+            use_fast=False,
+        )
     else:
         raise ValueError(f'Unsupported model "{args.model}"')
-    tokenizer.pad_token = tokenizer.eos_token
+
+    if args.model == 'llama':
+        tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model)
+    else:
+        tokenizer.pad_token = tokenizer.eos_token

     max_len = 512
@@ -67,11 +82,19 @@ def train(args):
     logger = get_dist_logger()

-    train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train')
-    eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test')
-    train_dataset = SFTDataset(train_data, tokenizer, max_len)
-    eval_dataset = SFTDataset(eval_data, tokenizer, max_len)
+    # configure dataset
+    if args.dataset == 'yizhongw/self_instruct':
+        train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train')
+        eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test')
+        train_dataset = SFTDataset(train_data, tokenizer, max_len)
+        eval_dataset = SFTDataset(eval_data, tokenizer, max_len)
+
+    elif 'alpaca' in args.dataset:
+        train_dataset = AlpacaDataset(tokenizer=tokenizer, data_path=args.dataset)
+        eval_dataset = None
+        data_collator = AlpacaDataCollator(tokenizer=tokenizer)

     if dist.is_initialized() and dist.get_world_size() > 1:
         sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True)
@@ -79,11 +102,15 @@ def train(args):
     else:
         sampler = None

+    train_dataloader = DataLoader(train_dataset, shuffle=(sampler is None),
+                                  sampler=sampler, batch_size=args.batch_size)
+    if eval_dataset is not None:
+        eval_dataloader = DataLoader(eval_dataset, batch_size=args.batch_size)
+
     trainer = SFTTrainer(model=model,
                          strategy=strategy,
                          optim=optim,
-                         train_dataset=train_dataset,
-                         eval_dataset=eval_dataset,
+                         train_dataloader=train_dataloader,
+                         eval_dataloader=eval_dataloader,
                          sampler=sampler,
                          batch_size=args.batch_size,
                          max_epochs=args.max_epochs)
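To tie the pieces together, a hypothetical invocation sketch (not part of the commit): it assumes the script's command-line arguments map one-to-one onto the args.* attributes read above, and every path and name is a placeholder.

# Hypothetical sketch: drive train() with a namespace mirroring the args.* attributes
# visible in this diff. train_sft.py also reads attributes not shown here (e.g. strategy, lr).
from argparse import Namespace

args = Namespace(
    model='llama',                        # takes the new LlamaLM / LLaMA-tokenizer branches
    pretrain='/path/to/llama-7b',         # placeholder checkpoint path
    dataset='/path/to/alpaca_data.json',  # any value containing 'alpaca' selects AlpacaDataset
    lora_rank=0,
    batch_size=4,
    max_epochs=2,
)
# from train_sft import train   # when run from applications/ChatGPT/examples
# train(args)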