OpenDAS / ColossalAI · Commits

Commit 7788e0b0 (unverified)
fix: fix sft (#3568)
Authored Apr 17, 2023 by tingfeng cao; committed by GitHub on Apr 17, 2023
Parent: 6e7e43c6
Showing 3 changed files with 13 additions and 16 deletions:

- applications/Chat/coati/dataset/sft_dataset.py (+4, -8)
- applications/Chat/coati/trainer/sft.py (+8, -7)
- applications/Chat/examples/train_sft.py (+1, -1)
applications/Chat/coati/dataset/sft_dataset.py

```diff
@@ -53,29 +53,25 @@ class SFTDataset(Dataset):
     def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None:
         super().__init__()
-        # self.prompts = []
         self.input_ids = []
 
         for data in tqdm(dataset, disable=not is_rank_0()):
-            prompt = data['prompt'] + data['completion'] + "<|endoftext|>"
+            prompt = data['prompt'] + data['completion'] + tokenizer.eos_token
             prompt_token = tokenizer(prompt,
                                      max_length=max_length,
                                      padding="max_length",
                                      truncation=True,
                                      return_tensors="pt")
-            # self.prompts.append(prompt_token)s
-            self.input_ids.append(prompt_token)
-            self.labels = copy.deepcopy(self.input_ids)
+            self.input_ids.append(prompt_token['input_ids'][0])
+        self.labels = copy.deepcopy(self.input_ids)
 
     def __len__(self):
-        length = len(self.prompts)
+        length = len(self.input_ids)
         return length
 
     def __getitem__(self, idx):
-        # dict(input_ids=self.input_ids[i], labels=self.labels[i])
         return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
-        # return dict(self.prompts[idx], self.prompts[idx])
 
 
 def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer,
                  max_length: int) -> Dict:
```
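Two of the fixes above are easy to misread without context. The following standalone sketch (using only standard Hugging Face `transformers` calls; the model names are illustrative and not part of the commit) shows what each corrects: a hardcoded `"<|endoftext|>"` only matches GPT-2-style vocabularies, while `tokenizer.eos_token` tracks whichever tokenizer is loaded; and `tokenizer(...)` returns a dict-like `BatchEncoding`, so the old code appended an entire encoding object where a 1-D tensor of token ids was expected.

```python
from transformers import AutoTokenizer

# 1) EOS strings differ by vocabulary, so tokenizer.eos_token is portable
#    where a hardcoded "<|endoftext|>" is GPT-2-specific.
for name in ["gpt2", "facebook/opt-125m"]:
    tok = AutoTokenizer.from_pretrained(name)
    print(name, repr(tok.eos_token))
# gpt2              -> '<|endoftext|>'
# facebook/opt-125m -> '</s>'

# 2) tokenizer(...) returns a BatchEncoding, not a tensor.
tok = AutoTokenizer.from_pretrained("gpt2")
tok.pad_token = tok.eos_token    # GPT-2 ships without a pad token
enc = tok("some prompt" + tok.eos_token,
          max_length=16, padding="max_length",
          truncation=True, return_tensors="pt")
print(type(enc).__name__)         # BatchEncoding  (what the old code stored)
print(enc['input_ids'].shape)     # torch.Size([1, 16])
print(enc['input_ids'][0].shape)  # torch.Size([16]) (what the fix stores)
```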
applications/Chat/coati/trainer/sft.py

```diff
@@ -96,7 +96,7 @@ class SFTTrainer(ABC):
             loss = outputs.loss
             prompt_logits = outputs.logits
 
-            if loss >= 2.5:
+            if loss >= 2.5 and is_rank_0():
                 logger.warning(f"batch_id:{batch_id}, abnormal loss:{loss}")
 
             loss = loss / self.accimulation_steps
```
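The `and is_rank_0()` guard keeps the abnormal-loss warning from being printed once per process in a data-parallel run. coati ships its own `is_rank_0` helper (already used by the dataset code above); as a sketch, such helpers typically reduce to the standard `torch.distributed` check:

```python
import torch.distributed as dist

def is_rank_0() -> bool:
    # With torch.distributed uninitialized there is a single process,
    # which counts as rank 0; otherwise only global rank 0 qualifies.
    return not dist.is_initialized() or dist.get_rank() == 0
```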
```diff
@@ -110,12 +110,13 @@ class SFTTrainer(ABC):
                 self.strategy.optimizer_step(self.optimizer)
                 self.optimizer.zero_grad()
                 self.scheduler.step()
-                wandb.log({
-                    "loss": total_loss / self.accimulation_steps,
-                    "lr": self.scheduler.get_last_lr()[0],
-                    "epoch": epoch,
-                    "batch_id": batch_id
-                })
+                if is_rank_0():
+                    wandb.log({
+                        "loss": total_loss / self.accimulation_steps,
+                        "lr": self.scheduler.get_last_lr()[0],
+                        "epoch": epoch,
+                        "batch_id": batch_id
+                    })
 
                 total_loss = 0
                 step_bar.update()
```
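The second hunk applies the same rank-0 gating to `wandb.log`, so a multi-process run produces a single metrics stream instead of `world_size` duplicates. The logged value also deserves a note: each micro-batch loss is divided by `self.accimulation_steps` (the misspelling is the source's own identifier) before backpropagation, so the summed gradients equal the gradient of the mean loss, and `total_loss / self.accimulation_steps` reports that same mean once per optimizer step. A minimal sketch of the accumulation arithmetic, with hypothetical names:

```python
def train_with_accumulation(model, optimizer, micro_batches, accumulation_steps):
    """Hypothetical reduction of the trainer's gradient-accumulation loop."""
    total_loss = 0.0
    for batch_id, batch in enumerate(micro_batches):
        loss = model(batch)                      # stand-in for outputs.loss
        (loss / accumulation_steps).backward()   # scale so grads sum to the mean
        total_loss += loss.item()
        if (batch_id + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            mean_loss = total_loss / accumulation_steps
            print(f"mean micro-batch loss this step: {mean_loss:.4f}")
            total_loss = 0.0
```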
applications/Chat/examples/train_sft.py

Per the hunk header and the file's +1, -1 stats, exactly one line changed here, and its old and new versions are token-for-token identical, which indicates a whitespace-only (re-indentation) change; the hunk is therefore shown as context without +/- markers.

```diff
@@ -111,7 +111,7 @@ def train(args):
                                           max_datasets_size=args.max_datasets_size,
                                           max_length=max_len)
     eval_dataset = None
     data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
 
     if dist.is_initialized() and dist.get_world_size() > 1:
         train_sampler = DistributedSampler(train_dataset,
```
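The context lines show the usual guard for distributed data loading: `DistributedSampler` requires an initialized process group, so it is only constructed when more than one process is running. A self-contained sketch of that pattern (the toy dataset and batch size are placeholders, not from the commit):

```python
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset, DistributedSampler

class ToyDataset(Dataset):
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return idx

def make_loader(dataset: Dataset, batch_size: int = 4) -> DataLoader:
    # Shard across ranks only in a real multi-process run; a single
    # process falls back to plain shuffling.
    if dist.is_initialized() and dist.get_world_size() > 1:
        sampler = DistributedSampler(dataset, shuffle=True)
        return DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

loader = make_loader(ToyDataset())
print([batch.tolist() for batch in loader])
```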