Unverified Commit 802984ad authored by Omar Sanseviero's avatar Omar Sanseviero Committed by GitHub
Browse files

Fix and document Zero Shot Image Classification (#16079)

parent 6e1e88fd
...@@ -39,6 +39,7 @@ There are two categories of pipeline abstractions to be aware about: ...@@ -39,6 +39,7 @@ There are two categories of pipeline abstractions to be aware about:
- [`TokenClassificationPipeline`] - [`TokenClassificationPipeline`]
- [`TranslationPipeline`] - [`TranslationPipeline`]
- [`ZeroShotClassificationPipeline`] - [`ZeroShotClassificationPipeline`]
- [`ZeroShotImageClassificationPipeline`]
## The pipeline abstraction ## The pipeline abstraction
......
...@@ -245,7 +245,7 @@ SUPPORTED_TASKS = { ...@@ -245,7 +245,7 @@ SUPPORTED_TASKS = {
"impl": ZeroShotImageClassificationPipeline, "impl": ZeroShotImageClassificationPipeline,
"tf": (TFAutoModel,) if is_tf_available() else (), "tf": (TFAutoModel,) if is_tf_available() else (),
"pt": (AutoModel,) if is_torch_available() else (), "pt": (AutoModel,) if is_torch_available() else (),
"default": {"pt": "openai/clip-vit-base-patch32", "tf": "openai/clip-vit-base-patch32"}, "default": {"model": {"pt": "openai/clip-vit-base-patch32", "tf": "openai/clip-vit-base-patch32"}},
"type": "multimodal", "type": "multimodal",
}, },
"conversational": { "conversational": {
...@@ -346,6 +346,7 @@ def check_task(task: str) -> Tuple[Dict, Any]: ...@@ -346,6 +346,7 @@ def check_task(task: str) -> Tuple[Dict, Any]:
- `"translation_xx_to_yy"` - `"translation_xx_to_yy"`
- `"summarization"` - `"summarization"`
- `"zero-shot-classification"` - `"zero-shot-classification"`
- `"zero-shot-image-classification"`
Returns: Returns:
(task_defaults`dict`, task_options: (`tuple`, None)) The actual dictionary required to initialize the pipeline (task_defaults`dict`, task_options: (`tuple`, None)) The actual dictionary required to initialize the pipeline
......
...@@ -35,7 +35,7 @@ class ZeroShotImageClassificationPipeline(ChunkPipeline): ...@@ -35,7 +35,7 @@ class ZeroShotImageClassificationPipeline(ChunkPipeline):
`"zero-shot-image-classification"`. `"zero-shot-image-classification"`.
See the list of available models on See the list of available models on
[huggingface.co/models](https://huggingface.co/models?filter=zer-shot-image-classification). [huggingface.co/models](https://huggingface.co/models?filter=zero-shot-image-classification).
""" """
def __init__(self, **kwargs): def __init__(self, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment