Unverified Commit a6d178e2 authored by Arthur's avatar Arthur Committed by GitHub
Browse files

[`DocString`] Support a revision in the docstring `add_code_sample_docstrings`...


[`DocString`] Support a revision in the docstring `add_code_sample_docstrings` to facilitate integrations (#27645)

* initial commit

* dummy changes

* style

* Update src/transformers/utils/doc.py
Co-authored-by: default avatarAlex McKinney <44398246+vvvm23@users.noreply.github.com>

* nits

* nit use ` if re.match(r'^refs/pr/\d*', revision):`

* restrict

* nit

* test the doc vuilder

* wow

* oke the order was wrong

---------
Co-authored-by: default avatarAlex McKinney <44398246+vvvm23@users.noreply.github.com>
parent 2098d343
...@@ -1267,13 +1267,14 @@ def overwrite_call_docstring(model_class, docstring): ...@@ -1267,13 +1267,14 @@ def overwrite_call_docstring(model_class, docstring):
model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__) model_class.__call__ = add_start_docstrings_to_model_forward(docstring)(model_class.__call__)
def append_call_sample_docstring(model_class, checkpoint, output_type, config_class, mask=None): def append_call_sample_docstring(model_class, checkpoint, output_type, config_class, mask=None, revision=None):
model_class.__call__ = copy_func(model_class.__call__) model_class.__call__ = copy_func(model_class.__call__)
model_class.__call__ = add_code_sample_docstrings( model_class.__call__ = add_code_sample_docstrings(
checkpoint=checkpoint, checkpoint=checkpoint,
output_type=output_type, output_type=output_type,
config_class=config_class, config_class=config_class,
model_cls=model_class.__name__, model_cls=model_class.__name__,
revision=revision,
)(model_class.__call__) )(model_class.__call__)
......
...@@ -829,7 +829,9 @@ class FlaxAlbertForMaskedLM(FlaxAlbertPreTrainedModel): ...@@ -829,7 +829,9 @@ class FlaxAlbertForMaskedLM(FlaxAlbertPreTrainedModel):
module_class = FlaxAlbertForMaskedLMModule module_class = FlaxAlbertForMaskedLMModule
append_call_sample_docstring(FlaxAlbertForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC) append_call_sample_docstring(
FlaxAlbertForMaskedLM, _CHECKPOINT_FOR_DOC, FlaxMaskedLMOutput, _CONFIG_FOR_DOC, revision="refs/pr/11"
)
class FlaxAlbertForSequenceClassificationModule(nn.Module): class FlaxAlbertForSequenceClassificationModule(nn.Module):
......
...@@ -1075,6 +1075,7 @@ def add_code_sample_docstrings( ...@@ -1075,6 +1075,7 @@ def add_code_sample_docstrings(
expected_output=None, expected_output=None,
expected_loss=None, expected_loss=None,
real_checkpoint=None, real_checkpoint=None,
revision=None,
): ):
def docstring_decorator(fn): def docstring_decorator(fn):
# model_class defaults to function's class if not specified otherwise # model_class defaults to function's class if not specified otherwise
...@@ -1143,6 +1144,15 @@ def add_code_sample_docstrings( ...@@ -1143,6 +1144,15 @@ def add_code_sample_docstrings(
func_doc = (fn.__doc__ or "") + "".join(docstr) func_doc = (fn.__doc__ or "") + "".join(docstr)
output_doc = "" if output_type is None else _prepare_output_docstrings(output_type, config_class) output_doc = "" if output_type is None else _prepare_output_docstrings(output_type, config_class)
built_doc = code_sample.format(**doc_kwargs) built_doc = code_sample.format(**doc_kwargs)
if revision is not None:
if re.match(r"^refs/pr/\\d+", revision):
raise ValueError(
f"The provided revision '{revision}' is incorrect. It should point to"
" a pull request reference on the hub like 'refs/pr/6'"
)
built_doc = built_doc.replace(
f'from_pretrained("{checkpoint}")', f'from_pretrained("{checkpoint}", revision="{revision}")'
)
fn.__doc__ = func_doc + output_doc + built_doc fn.__doc__ = func_doc + output_doc + built_doc
return fn return fn
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment