Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
7788e0b0
Unverified
Commit
7788e0b0
authored
Apr 17, 2023
by
tingfeng cao
Committed by
GitHub
Apr 17, 2023
Browse files
fix: fix sft (#3568)
parent
6e7e43c6
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
15 deletions
+12
-15
applications/Chat/coati/dataset/sft_dataset.py
applications/Chat/coati/dataset/sft_dataset.py
+4
-8
applications/Chat/coati/trainer/sft.py
applications/Chat/coati/trainer/sft.py
+8
-7
No files found.
applications/Chat/coati/dataset/sft_dataset.py
View file @
7788e0b0
...
...
@@ -53,29 +53,25 @@ class SFTDataset(Dataset):
def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None:
    """Tokenize a supervised fine-tuning dataset of prompt/completion pairs.

    Args:
        dataset: iterable of dicts with ``'prompt'`` and ``'completion'`` keys.
        tokenizer: HuggingFace-style tokenizer callable exposing
            ``eos_token`` — presumably ``transformers.PreTrainedTokenizer``;
            TODO confirm against callers.
        max_length: every example is padded/truncated to this many tokens.
    """
    super().__init__()
    self.input_ids = []
    # Progress bar only on rank 0 so distributed runs don't print N bars.
    for data in tqdm(dataset, disable=not is_rank_0()):
        # Use the tokenizer's own EOS token instead of a hard-coded
        # "<|endoftext|>" so non-GPT tokenizers terminate sequences correctly.
        prompt = data['prompt'] + data['completion'] + tokenizer.eos_token
        prompt_token = tokenizer(prompt,
                                 max_length=max_length,
                                 padding="max_length",
                                 truncation=True,
                                 return_tensors="pt")
        # Store the 1-D token tensor — drop the batch dimension that
        # return_tensors="pt" adds — rather than the whole encoding dict.
        self.input_ids.append(prompt_token['input_ids'][0])
    # Causal-LM SFT objective: labels are simply a copy of the inputs.
    self.labels = copy.deepcopy(self.input_ids)
def __len__(self):
    """Return the number of tokenized examples in the dataset."""
    # Count self.input_ids — the old self.prompts list no longer exists.
    return len(self.input_ids)
def __getitem__(self, idx):
    """Return example *idx* as a dict with 'input_ids' and 'labels' tensors."""
    return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
def
_tokenize_fn
(
strings
:
Sequence
[
str
],
tokenizer
:
transformers
.
PreTrainedTokenizer
,
max_length
:
int
)
->
Dict
:
...
...
applications/Chat/coati/trainer/sft.py
View file @
7788e0b0
...
...
@@ -96,7 +96,7 @@ class SFTTrainer(ABC):
loss
=
outputs
.
loss
prompt_logits
=
outputs
.
logits
if
loss
>=
2.5
:
if
loss
>=
2.5
and
is_rank_0
()
:
logger
.
warning
(
f
"batch_id:
{
batch_id
}
, abnormal loss:
{
loss
}
"
)
loss
=
loss
/
self
.
accimulation_steps
...
...
@@ -110,6 +110,7 @@ class SFTTrainer(ABC):
self
.
strategy
.
optimizer_step
(
self
.
optimizer
)
self
.
optimizer
.
zero_grad
()
self
.
scheduler
.
step
()
if
is_rank_0
():
wandb
.
log
({
"loss"
:
total_loss
/
self
.
accimulation_steps
,
"lr"
:
self
.
scheduler
.
get_last_lr
()[
0
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment