Commit a00ddd57 authored by Mark Daoust's avatar Mark Daoust Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 444922482
parent dbc24083
......@@ -29,7 +29,10 @@ from absl import logging
import tensorflow as tf
from tensorflow_docs.api_generator import doc_controls
from tensorflow_docs.api_generator import generate_lib
from tensorflow_docs.api_generator import parser
from tensorflow_docs.api_generator import public_api
from tensorflow_docs.api_generator.pretty_docs import base_page
from tensorflow_docs.api_generator.pretty_docs import function_page
import tensorflow_models as tfm
......@@ -52,6 +55,44 @@ PROJECT_SHORT_NAME = 'tfm'
PROJECT_FULL_NAME = 'TensorFlow Modeling Library'
class ExpFactoryInfo(function_page.FunctionPageInfo):
"""Customize the page for the experiment factory."""
def collect_docs(self):
super().collect_docs()
self.doc.docstring_parts.append(self.make_factory_options_table())
def make_factory_options_table(self):
lines = [
'',
'Allowed values for `exp_name`:',
'',
# The indent is important here, it keeps the site's markdown parser
# from switching to HTML mode.
' <table>\n',
'<th><code>exp_name</code></th><th>Description</th>',
]
reference_resolver = self.parser_config.reference_resolver
api_tree = self.parser_config.api_tree
for name, fn in sorted(tfm.core.exp_factory._REGISTERED_CONFIGS.items()): # pylint: disable=protected-access
fn_api_node = api_tree.node_for_object(fn)
if fn_api_node is None:
location = parser.get_defined_in(self.py_object, self.parser_config)
link = base_page.small_source_link(location, name)
else:
link = reference_resolver.python_link(name, fn_api_node.full_name)
doc = fn.__doc__
if doc:
doc = doc.splitlines()[0]
else:
doc = ''
lines.append(f'<tr><td>{link}</td><td>{doc}</td></tr>')
lines.append('</table>')
return '\n'.join(lines)
def hide_module_model_and_layer_methods():
"""Hide methods and properties defined in the base classes of Keras layers.
......@@ -103,6 +144,9 @@ def gen_api_docs(code_url_prefix, site_path, output_dir, project_short_name,
del tfm.nlp.layers.MultiHeadAttention
del tfm.nlp.layers.EinsumDense
doc_controls.set_custom_page_builder_cls(tfm.core.exp_factory.get_exp_config,
ExpFactoryInfo)
url_parts = code_url_prefix.strip('/').split('/')
url_parts = url_parts[:url_parts.index('tensorflow_models')]
url_parts.append('official')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment