Unverified Commit e0e2da11 authored by Daniel Stancl's avatar Daniel Stancl Committed by GitHub
Browse files

Improve a add-new-pipeline docs a bit (#14485)

parent a4553e6c
...@@ -29,23 +29,23 @@ Start by inheriting the base class :obj:`Pipeline`. with the 4 methods needed to ...@@ -29,23 +29,23 @@ Start by inheriting the base class :obj:`Pipeline`. with the 4 methods needed to
from transformers import Pipeline from transformers import Pipeline
class MyPipeline(Pipeline): class MyPipeline(Pipeline):
def _sanitize_parameters(self, **kwargs) def _sanitize_parameters(self, **kwargs):
preprocess_kwargs = {} preprocess_kwargs = {}
if "maybe_arg" in kwargs: if "maybe_arg" in kwargs:
preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"] preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
return preprocess_kwargs, {}, {} return preprocess_kwargs, {}, {}
def preprocess(self, inputs, maybe_arg=2) def preprocess(self, inputs, maybe_arg=2):
model_input = Tensor(....) model_input = Tensor(....)
return {"model_input": model_input} return {"model_input": model_input}
def _forward(self, model_inputs) def _forward(self, model_inputs):
# model_inputs == {"model_input": model_input} # model_inputs == {"model_input": model_input}
oututs = self.model(**model_inputs) outputs = self.model(**model_inputs)
# Maybe {"logits": Tensor(...)} # Maybe {"logits": Tensor(...)}
return outputs return outputs
def postprocess(self, model_outputs) def postprocess(self, model_outputs):
best_class = model_outputs["logits"].softmax(-1) best_class = model_outputs["logits"].softmax(-1)
return best_class return best_class
...@@ -89,12 +89,12 @@ In order to achieve that, we'll update our :obj:`postprocess` method with a defa ...@@ -89,12 +89,12 @@ In order to achieve that, we'll update our :obj:`postprocess` method with a defa
.. code-block:: .. code-block::
def postprocess(self, model_outputs, top_k=5) def postprocess(self, model_outputs, top_k=5):
best_class = model_outputs["logits"].softmax(-1) best_class = model_outputs["logits"].softmax(-1)
# Add logic to handle top_k # Add logic to handle top_k
return best_class return best_class
def _sanitize_parameters(self, **kwargs) def _sanitize_parameters(self, **kwargs):
preprocess_kwargs = {} preprocess_kwargs = {}
if "maybe_arg" in kwargs: if "maybe_arg" in kwargs:
preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"] preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment