Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
fa97a9ca
Unverified
Commit
fa97a9ca
authored
Mar 23, 2023
by
Fazzie-Maqianli
Committed by
GitHub
Mar 23, 2023
Browse files
[chatgpt] unify datasets (#3218)
parent
4fd4bd9d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
12 additions
and
5 deletions
+12
-5
applications/ChatGPT/chatgpt/dataset/sft_dataset.py
applications/ChatGPT/chatgpt/dataset/sft_dataset.py
+8
-3
applications/ChatGPT/chatgpt/trainer/sft.py
applications/ChatGPT/chatgpt/trainer/sft.py
+4
-2
No files found.
applications/ChatGPT/chatgpt/dataset/sft_dataset.py
View file @
fa97a9ca
...
@@ -54,7 +54,8 @@ class SFTDataset(Dataset):
...
@@ -54,7 +54,8 @@ class SFTDataset(Dataset):
def
__init__
(
self
,
dataset
,
tokenizer
:
Callable
,
max_length
:
int
=
512
)
->
None
:
def
__init__
(
self
,
dataset
,
tokenizer
:
Callable
,
max_length
:
int
=
512
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
prompts
=
[]
# self.prompts = []
self
.
input_ids
=
[]
for
data
in
tqdm
(
dataset
,
disable
=
not
is_rank_0
()):
for
data
in
tqdm
(
dataset
,
disable
=
not
is_rank_0
()):
prompt
=
data
[
'prompt'
]
+
data
[
'completion'
]
+
"<|endoftext|>"
prompt
=
data
[
'prompt'
]
+
data
[
'completion'
]
+
"<|endoftext|>"
...
@@ -64,14 +65,18 @@ class SFTDataset(Dataset):
...
@@ -64,14 +65,18 @@ class SFTDataset(Dataset):
truncation
=
True
,
truncation
=
True
,
return_tensors
=
"pt"
)
return_tensors
=
"pt"
)
self
.
prompts
.
append
(
prompt_token
)
# self.prompts.append(prompt_token)s
self
.
input_ids
.
append
(
prompt_token
)
self
.
labels
=
copy
.
deepcopy
(
self
.
input_ids
)
def __len__(self):
    """Return the number of tokenized examples in the dataset.

    Uses ``self.input_ids`` (populated in ``__init__``) rather than
    ``self.prompts``: this commit commented out the ``prompts`` list
    (``# self.prompts = []``), so referencing it here would raise
    ``AttributeError`` on every ``len(dataset)`` / DataLoader call.
    """
    return len(self.input_ids)
def __getitem__(self, idx):
    """Return one training example as a dict of ``input_ids`` and ``labels``.

    Fix: the new body indexed with an undefined name ``i``
    (``self.input_ids[i]``) instead of the ``idx`` parameter, which would
    raise ``NameError`` on the first element access.
    """
    return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
def
_tokenize_fn
(
strings
:
Sequence
[
str
],
tokenizer
:
transformers
.
PreTrainedTokenizer
)
->
Dict
:
def
_tokenize_fn
(
strings
:
Sequence
[
str
],
tokenizer
:
transformers
.
PreTrainedTokenizer
)
->
Dict
:
...
...
applications/ChatGPT/chatgpt/trainer/sft.py
View file @
fa97a9ca
...
@@ -63,11 +63,13 @@ class SFTTrainer(ABC):
...
@@ -63,11 +63,13 @@ class SFTTrainer(ABC):
for batch_id, batch in enumerate(self.train_dataloader):
    # Unpack one SFT batch; tensors arrive with an extra dim of size 1
    # from the per-sample tokenizer call, hence the squeeze(1) below.
    prompt_ids = batch["input_ids"]
    p_mask = batch["attention_mask"]
    labels = batch["labels"]
    prompt_ids = prompt_ids.squeeze(1).cuda()
    p_mask = p_mask.squeeze(1).cuda()
    # Fix: move labels to the same device/shape as the inputs — the diff
    # left them on CPU with the extra dim, which breaks the model's
    # internal loss computation when passed as ``labels=``.
    labels = labels.squeeze(1).cuda()

    # The HF-style model returns (loss, logits) when ``labels`` is given,
    # so no separate self.loss_fn call is needed here.
    loss, prompt_logits = self.model(prompt_ids, attention_mask=p_mask, labels=labels)

    self.strategy.backward(loss, self.model, self.optimizer)
    self.strategy.optimizer_step(self.optimizer)
    self.optimizer.zero_grad()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment