"docs/source/ja/main_classes/model.md" did not exist on "9093b19b13f642ed63e0fa49f4091fc0283a84e3"
Commit b5f06d6c (unverified), authored by Connor Boyle, committed by GitHub

Raise error if `stride` is too high in `TokenClassificationPipeline` (#22942)

* Raise error if `stride` is too high

* Clarify use of `stride`
parent 3f6a4b5b
@@ -69,7 +69,8 @@ class AggregationStrategy(ExplicitEnum):
  stride (`int`, *optional*):
      If stride is provided, the pipeline is applied on all the text. The text is split into chunks of size
      model_max_length. Works only with fast tokenizers and `aggregation_strategy` different from `NONE`. The
-     value of this argument defines the number of overlapping tokens between chunks.
+     value of this argument defines the number of overlapping tokens between chunks. In other words, the model
+     will shift forward by `tokenizer.model_max_length - stride` tokens each step.
  aggregation_strategy (`str`, *optional*, defaults to `"none"`):
      The strategy to fuse (or not) tokens based on the model prediction.
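
The shift described in the updated docstring can be sketched in plain Python. This is a minimal illustration, not the pipeline's own code: `chunk_spans` is a hypothetical helper, and it ignores special tokens, which in practice shrink the usable window.

```python
def chunk_spans(num_tokens: int, max_length: int, stride: int) -> list[tuple[int, int]]:
    """Return (start, end) token index pairs for overlapping windows.

    Hypothetical helper for illustration only; special tokens are ignored.
    """
    if not 0 <= stride < max_length:
        raise ValueError("stride must satisfy 0 <= stride < max_length")
    # Each window starts `max_length - stride` tokens after the previous one,
    # so consecutive windows share `stride` tokens.
    step = max_length - stride
    spans = []
    start = 0
    while True:
        end = min(start + max_length, num_tokens)
        spans.append((start, end))
        if end >= num_tokens:
            break
        start += step
    return spans


print(chunk_spans(num_tokens=10, max_length=4, stride=1))
# [(0, 4), (3, 7), (6, 10)]
```

With `stride` equal to `max_length` the step would be zero and the window would never advance, which is the case the second hunk rejects.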
@@ -191,6 +192,10 @@ class TokenClassificationPipeline(ChunkPipeline):
  if ignore_labels is not None:
      postprocess_params["ignore_labels"] = ignore_labels
  if stride is not None:
+     if stride >= self.tokenizer.model_max_length:
+         raise ValueError(
+             "`stride` must be less than `tokenizer.model_max_length` (or even lower if the tokenizer adds special tokens)"
+         )
      if aggregation_strategy == AggregationStrategy.NONE:
          raise ValueError(
              "`stride` was provided to process all the text but `aggregation_strategy="
......
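
The order of the two checks in this hunk can be mirrored in a standalone function, which may be handy for reasoning about which inputs are rejected. This is a simplified sketch under stated assumptions: `check_stride` is a hypothetical name, the aggregation strategy is passed as a plain string instead of the `AggregationStrategy` enum, and the second error message is illustrative wording, not the library's.

```python
from typing import Optional


def check_stride(stride: Optional[int], model_max_length: int, aggregation_strategy: str) -> None:
    """Hypothetical stand-in for the validation shown in the diff."""
    if stride is None:
        # No stride requested, so nothing to validate.
        return
    if stride >= model_max_length:
        # New check from this commit: the window could never move forward.
        raise ValueError("`stride` must be less than `tokenizer.model_max_length`")
    if aggregation_strategy == "none":
        # Pre-existing check; message text here is illustrative only.
        raise ValueError("`stride` requires an aggregation strategy other than 'none'")


check_stride(128, 512, "simple")  # accepted
check_stride(None, 512, "none")   # accepted: stride not used
```

Calling `check_stride(512, 512, "simple")` raises the first error, and `check_stride(128, 512, "none")` raises the second.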