Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
6a02e980
"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "f347d664acf61b00ca68ac5efcabcbdd4079534f"
Unverified
Commit
6a02e980
authored
Apr 06, 2023
by
Nicolas Patry
Committed by
GitHub
Apr 06, 2023
Browse files
LlamaTokenizerFast Fix (.., from_slow=True). (#22630)
parent
09a9888f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
37 additions
and
0 deletions
+37
-0
src/transformers/models/llama/tokenization_llama_fast.py
src/transformers/models/llama/tokenization_llama_fast.py
+37
-0
No files found.
src/transformers/models/llama/tokenization_llama_fast.py
View file @
6a02e980
...
...
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from shutil import copyfile
from typing import Optional, Tuple

from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ...utils import is_sentencepiece_available, logging
from ...utils.versions import require_version


# The fast LLaMA tokenizer relies on features introduced in tokenizers 0.13.3.
require_version("tokenizers>=0.13.3")

if is_sentencepiece_available():
    from .tokenization_llama import LlamaTokenizer
else:
    # Sentinel so `slow_tokenizer_class` can still be referenced even when
    # sentencepiece is not installed.
    LlamaTokenizer = None

logger = logging.get_logger(__name__)

# File names used when serializing the vocabulary: the sentencepiece model
# for the slow tokenizer and the JSON file for the fast tokenizer.
VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model", "tokenizer_file": "tokenizer.json"}
class
LlamaTokenizerFast
(
PreTrainedTokenizerFast
):
"""
...
...
@@ -59,6 +72,8 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
token instead.
"""
vocab_files_names
=
VOCAB_FILES_NAMES
slow_tokenizer_class
=
LlamaTokenizer
padding_side
=
"left"
def
__init__
(
...
...
@@ -80,3 +95,25 @@ class LlamaTokenizerFast(PreTrainedTokenizerFast):
eos_token
=
eos_token
,
**
kwargs
,
)
self
.
vocab_file
=
vocab_file
self
.
can_save_slow_tokenizer
=
False
if
not
self
.
vocab_file
else
True
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Optional[Tuple[str]]:
    """Save the sentencepiece vocabulary file (``tokenizer.model``) into `save_directory`.

    The fast tokenizer cannot regenerate the slow tokenizer's sentencepiece
    model, so this simply copies the original `self.vocab_file` it was
    initialized from.

    Args:
        save_directory (`str`):
            The directory in which to save the vocabulary file. Must already exist.
        filename_prefix (`str`, *optional*):
            Optional prefix prepended (with a ``-`` separator) to the saved file name.

    Returns:
        `Tuple[str]`: A one-element tuple with the path of the saved vocabulary
        file, or `None` if `save_directory` is not an existing directory
        (an error is logged in that case).

    Raises:
        ValueError: If this tokenizer was not initialized with a slow
            tokenizer's vocab file (`self.can_save_slow_tokenizer` is False),
            so there is nothing to copy.
    """
    if not self.can_save_slow_tokenizer:
        raise ValueError(
            "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
            "tokenizer."
        )

    if not os.path.isdir(save_directory):
        logger.error(f"Vocabulary path ({save_directory}) should be a directory")
        # NOTE: annotation is Optional[Tuple[str]] because of this early exit.
        return None

    out_vocab_file = os.path.join(
        save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
    )

    # Avoid copying the file onto itself when saving back into the directory
    # the vocab file already lives in.
    if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
        copyfile(self.vocab_file, out_vocab_file)

    return (out_vocab_file,)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment