Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
400e76ef
Unverified
Commit
400e76ef
authored
Jul 27, 2023
by
Sylvain Gugger
Committed by
GitHub
Jul 27, 2023
Browse files
Add new model in doc table of content (#25148)
parent
e9310363
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
53 additions
and
0 deletions
+53
-0
src/transformers/commands/add_new_model_like.py
src/transformers/commands/add_new_model_like.py
+53
-0
No files found.
src/transformers/commands/add_new_model_like.py
View file @
400e76ef
...
@@ -23,6 +23,8 @@ from itertools import chain
...
@@ -23,6 +23,8 @@ from itertools import chain
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Pattern
,
Tuple
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Pattern
,
Tuple
,
Union
import
yaml
from
..models
import
auto
as
auto_module
from
..models
import
auto
as
auto_module
from
..models.auto.configuration_auto
import
model_type_to_module_name
from
..models.auto.configuration_auto
import
model_type_to_module_name
from
..utils
import
is_flax_available
,
is_tf_available
,
is_torch_available
,
logging
from
..utils
import
is_flax_available
,
is_tf_available
,
is_torch_available
,
logging
...
@@ -1268,6 +1270,56 @@ def duplicate_doc_file(
...
@@ -1268,6 +1270,56 @@ def duplicate_doc_file(
f
.
write
(
"
\n
"
.
join
(
new_blocks
))
f
.
write
(
"
\n
"
.
join
(
new_blocks
))
def
insert_model_in_doc_toc
(
old_model_patterns
,
new_model_patterns
):
"""
Insert the new model in the doc TOC, in the same section as the old model.
Args:
old_model_patterns (`ModelPatterns`): The patterns for the old model.
new_model_patterns (`ModelPatterns`): The patterns for the new model.
"""
toc_file
=
REPO_PATH
/
"docs"
/
"source"
/
"en"
/
"_toctree.yml"
with
open
(
toc_file
,
"r"
,
encoding
=
"utf8"
)
as
f
:
content
=
yaml
.
safe_load
(
f
)
# Get to the model API doc
api_idx
=
0
while
content
[
api_idx
][
"title"
]
!=
"API"
:
api_idx
+=
1
api_doc
=
content
[
api_idx
][
"sections"
]
model_idx
=
0
while
api_doc
[
model_idx
][
"title"
]
!=
"Models"
:
model_idx
+=
1
model_doc
=
api_doc
[
model_idx
][
"sections"
]
# Find the base model in the Toc
old_model_type
=
old_model_patterns
.
model_type
section_idx
=
0
while
section_idx
<
len
(
model_doc
):
sections
=
[
entry
[
"local"
]
for
entry
in
model_doc
[
section_idx
][
"sections"
]]
if
f
"model_doc/
{
old_model_type
}
"
in
sections
:
break
section_idx
+=
1
if
section_idx
==
len
(
model_doc
):
old_model
=
old_model_patterns
.
model_name
new_model
=
new_model_patterns
.
model_name
print
(
f
"Did not find
{
old_model
}
in the table of content, so you will need to add
{
new_model
}
manually."
)
return
# Add the new model in the same toc
toc_entry
=
{
"local"
:
f
"model_doc/
{
new_model_patterns
.
model_type
}
"
,
"title"
:
new_model_patterns
.
model_name
}
model_doc
[
section_idx
][
"sections"
].
append
(
toc_entry
)
model_doc
[
section_idx
][
"sections"
]
=
sorted
(
model_doc
[
section_idx
][
"sections"
],
key
=
lambda
s
:
s
[
"title"
].
lower
())
api_doc
[
model_idx
][
"sections"
]
=
model_doc
content
[
api_idx
][
"sections"
]
=
api_doc
with
open
(
toc_file
,
"w"
,
encoding
=
"utf-8"
)
as
f
:
f
.
write
(
yaml
.
dump
(
content
,
allow_unicode
=
True
))
def
create_new_model_like
(
def
create_new_model_like
(
model_type
:
str
,
model_type
:
str
,
new_model_patterns
:
ModelPatterns
,
new_model_patterns
:
ModelPatterns
,
...
@@ -1407,6 +1459,7 @@ def create_new_model_like(
...
@@ -1407,6 +1459,7 @@ def create_new_model_like(
# 5. Add doc file
# 5. Add doc file
doc_file
=
REPO_PATH
/
"docs"
/
"source"
/
"en"
/
"model_doc"
/
f
"
{
old_model_patterns
.
model_type
}
.md"
doc_file
=
REPO_PATH
/
"docs"
/
"source"
/
"en"
/
"model_doc"
/
f
"
{
old_model_patterns
.
model_type
}
.md"
duplicate_doc_file
(
doc_file
,
old_model_patterns
,
new_model_patterns
,
frameworks
=
frameworks
)
duplicate_doc_file
(
doc_file
,
old_model_patterns
,
new_model_patterns
,
frameworks
=
frameworks
)
insert_model_in_doc_toc
(
old_model_patterns
,
new_model_patterns
)
# 6. Warn the user for duplicate patterns
# 6. Warn the user for duplicate patterns
if
old_model_patterns
.
model_type
==
old_model_patterns
.
checkpoint
:
if
old_model_patterns
.
model_type
==
old_model_patterns
.
checkpoint
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment