OpenDAS / ColossalAI · Commit b2debdc0

[NFC] polish applications/Chat/coati/dataset/sft_dataset.py code style (#4259)

Authored Jul 18, 2023 by Zheng Zangwei (Alex Zheng); committed by binmakeswell on Jul 26, 2023
Parent: abe4f971

Showing 1 changed file with 9 additions and 11 deletions.

applications/Chat/coati/dataset/sft_dataset.py (+9, -11)
...
@@ -74,15 +74,10 @@ class SFTDataset(Dataset):
        return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])


def _tokenize_fn(strings: Sequence[str],
                 tokenizer: transformers.PreTrainedTokenizer,
                 max_length: int) -> Dict[str, torch.Tensor]:
    """Tokenize a list of strings."""
    tokenized_list = tokenizer(strings, return_tensors="pt", padding="longest", max_length=max_length, truncation=True)
    input_ids = labels = tokenized_list["input_ids"]
    input_ids_lens = labels_lens = \
        tokenized_list["input_ids"].ne(tokenizer.pad_token_id).sum(dim=-1)
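For context, the `.ne(tokenizer.pad_token_id).sum(dim=-1)` idiom in this hunk counts the non-padding tokens in each row of the padded batch. A minimal, self-contained sketch of that step (the toy tensor and pad id below are illustrative, not taken from the commit):

import torch

# Hypothetical padded batch: 0 stands in for tokenizer.pad_token_id.
pad_token_id = 0
input_ids = torch.tensor([
    [5, 6, 7, 0, 0],   # 3 real tokens
    [8, 9, 0, 0, 0],   # 2 real tokens
])

# Same idiom as in _tokenize_fn: count the entries that are not padding.
lens = input_ids.ne(pad_token_id).sum(dim=-1)
print(lens)  # tensor([3, 2])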
...
@@ -103,8 +98,7 @@ def preprocess(
    """Preprocess the data by tokenizing."""
    examples = [s + t for s, t in zip(sources, targets)]
    examples_tokenized, sources_tokenized = [
        _tokenize_fn(strings, tokenizer, max_length) for strings in (examples, sources)
    ]
    input_ids = examples_tokenized["input_ids"]
    labels = copy.deepcopy(input_ids)
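The first statement of this hunk concatenates each prompt with its response before tokenization, so one sequence covers both. A toy illustration of that step (the example strings are made up, not from any dataset shipped with the repository):

# Hypothetical prompt/response pairs.
sources = ["Question: 2 + 2 = ", "Translate 'hi' to French: "]
targets = ["4</s>", "salut</s>"]

# Same list comprehension as in preprocess: prompt and response become one training example.
examples = [s + t for s, t in zip(sources, targets)]
print(examples[0])  # "Question: 2 + 2 = 4</s>"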
...
@@ -116,7 +110,11 @@ def preprocess(
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self,
                 data_path: str,
                 tokenizer: transformers.PreTrainedTokenizer,
                 max_datasets_size: int = None,
                 max_length: int = 512):
        super(SupervisedDataset, self).__init__()
        logger.info("Loading data...")
        list_data_dict = jload(data_path)
...
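Based only on the __init__ signature shown above, constructing the dataset would look roughly like the sketch below. The model name and data path are placeholders, and jload is assumed to read a JSON file of instruction records; max_datasets_size is assumed to cap how many records are kept.

from transformers import AutoTokenizer

# Hypothetical usage: model name and data path are placeholders, not taken from the repository.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # _tokenize_fn pads, so a pad token must exist

dataset = SupervisedDataset(data_path="alpaca_data.json",
                            tokenizer=tokenizer,
                            max_datasets_size=1000,
                            max_length=512)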