Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e0e2da11
Unverified
Commit
e0e2da11
authored
Nov 22, 2021
by
Daniel Stancl
Committed by
GitHub
Nov 22, 2021
Browse files
Improve a add-new-pipeline docs a bit (#14485)
parent
a4553e6c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
7 deletions
+7
-7
docs/source/add_new_pipeline.rst
docs/source/add_new_pipeline.rst
+7
-7
No files found.
docs/source/add_new_pipeline.rst
View file @
e0e2da11
...
@@ -29,23 +29,23 @@ Start by inheriting the base class :obj:`Pipeline`. with the 4 methods needed to
...
@@ -29,23 +29,23 @@ Start by inheriting the base class :obj:`Pipeline`. with the 4 methods needed to
from transformers import Pipeline
from transformers import Pipeline
class MyPipeline(Pipeline):
class MyPipeline(Pipeline):
def _sanitize_parameters(self, **kwargs)
def _sanitize_parameters(self, **kwargs)
:
preprocess_kwargs = {}
preprocess_kwargs = {}
if "maybe_arg" in kwargs:
if "maybe_arg" in kwargs:
preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
return preprocess_kwargs, {}, {}
return preprocess_kwargs, {}, {}
def preprocess(self, inputs, maybe_arg=2)
def preprocess(self, inputs, maybe_arg=2)
:
model_input = Tensor(....)
model_input = Tensor(....)
return {"model_input": model_input}
return {"model_input": model_input}
def _forward(self, model_inputs)
def _forward(self, model_inputs)
:
# model_inputs == {"model_input": model_input}
# model_inputs == {"model_input": model_input}
oututs = self.model(**model_inputs)
out
p
uts = self.model(**model_inputs)
# Maybe {"logits": Tensor(...)}
# Maybe {"logits": Tensor(...)}
return outputs
return outputs
def postprocess(self, model_outputs)
def postprocess(self, model_outputs)
:
best_class = model_outputs["logits"].softmax(-1)
best_class = model_outputs["logits"].softmax(-1)
return best_class
return best_class
...
@@ -89,12 +89,12 @@ In order to achieve that, we'll update our :obj:`postprocess` method with a defa
...
@@ -89,12 +89,12 @@ In order to achieve that, we'll update our :obj:`postprocess` method with a defa
.. code-block::
.. code-block::
def postprocess(self, model_outputs, top_k=5)
def postprocess(self, model_outputs, top_k=5)
:
best_class = model_outputs["logits"].softmax(-1)
best_class = model_outputs["logits"].softmax(-1)
# Add logic to handle top_k
# Add logic to handle top_k
return best_class
return best_class
def _sanitize_parameters(self, **kwargs)
def _sanitize_parameters(self, **kwargs)
:
preprocess_kwargs = {}
preprocess_kwargs = {}
if "maybe_arg" in kwargs:
if "maybe_arg" in kwargs:
preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment