"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "3bac800e43741f1bd58411c67aa2b2377cfbb572"
Unverified Commit ab229663 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Fix QA sample (#16648)



* fix QA sample

* For TF_QUESTION_ANSWERING_SAMPLE
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent 9a24b97b
...@@ -207,7 +207,8 @@ PT_QUESTION_ANSWERING_SAMPLE = r""" ...@@ -207,7 +207,8 @@ PT_QUESTION_ANSWERING_SAMPLE = r"""
```python ```python
>>> # target is "nice puppet" >>> # target is "nice puppet"
>>> target_start_index, target_end_index = torch.tensor([14]), torch.tensor([15]) >>> target_start_index = torch.tensor([{qa_target_start_index}])
>>> target_end_index = torch.tensor([{qa_target_end_index}])
>>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index) >>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index)
>>> loss = outputs.loss >>> loss = outputs.loss
...@@ -667,7 +668,8 @@ TF_QUESTION_ANSWERING_SAMPLE = r""" ...@@ -667,7 +668,8 @@ TF_QUESTION_ANSWERING_SAMPLE = r"""
```python ```python
>>> # target is "nice puppet" >>> # target is "nice puppet"
>>> target_start_index, target_end_index = tf.constant([14]), tf.constant([15]) >>> target_start_index = tf.constant([{qa_target_start_index}])
>>> target_end_index = tf.constant([{qa_target_end_index}])
>>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index) >>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index)
>>> loss = tf.math.reduce_mean(outputs.loss) >>> loss = tf.math.reduce_mean(outputs.loss)
...@@ -1054,6 +1056,8 @@ def add_code_sample_docstrings( ...@@ -1054,6 +1056,8 @@ def add_code_sample_docstrings(
output_type=None, output_type=None,
config_class=None, config_class=None,
mask="[MASK]", mask="[MASK]",
qa_target_start_index=14,
qa_target_end_index=15,
model_cls=None, model_cls=None,
modality=None, modality=None,
expected_output="", expected_output="",
...@@ -1078,6 +1082,8 @@ def add_code_sample_docstrings( ...@@ -1078,6 +1082,8 @@ def add_code_sample_docstrings(
processor_class=processor_class, processor_class=processor_class,
checkpoint=checkpoint, checkpoint=checkpoint,
mask=mask, mask=mask,
qa_target_start_index=qa_target_start_index,
qa_target_end_index=qa_target_end_index,
expected_output=expected_output, expected_output=expected_output,
expected_loss=expected_loss, expected_loss=expected_loss,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment