You need to sign in or sign up before continuing.
Unverified Commit 4139054b authored by OlivierDehaene's avatar OlivierDehaene Committed by GitHub
Browse files

v1.4.1 (#1568)

parent 0f2daad8
This diff is collapsed.
...@@ -9,7 +9,7 @@ members = [ ...@@ -9,7 +9,7 @@ members = [
resolver = "2" resolver = "2"
[workspace.package] [workspace.package]
version = "1.4.0" version = "1.4.1"
edition = "2021" edition = "2021"
authors = ["Olivier Dehaene"] authors = ["Olivier Dehaene"]
homepage = "https://github.com/huggingface/text-generation-inference" homepage = "https://github.com/huggingface/text-generation-inference"
......
...@@ -225,7 +225,7 @@ COPY server/Makefile server/Makefile ...@@ -225,7 +225,7 @@ COPY server/Makefile server/Makefile
RUN cd server && \ RUN cd server && \
make gen-server && \ make gen-server && \
pip install -r requirements_cuda.txt && \ pip install -r requirements_cuda.txt && \
pip install ".[bnb, accelerate, quantize, peft]" --no-cache-dir pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir
# Install benchmarker # Install benchmarker
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
......
...@@ -150,7 +150,7 @@ COPY server/Makefile server/Makefile ...@@ -150,7 +150,7 @@ COPY server/Makefile server/Makefile
RUN cd server && \ RUN cd server && \
make gen-server && \ make gen-server && \
pip install -r requirements_rocm.txt && \ pip install -r requirements_rocm.txt && \
pip install ".[accelerate, peft]" --no-cache-dir pip install ".[accelerate, peft, outlines]" --no-cache-dir
# Install benchmarker # Install benchmarker
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
"name": "Apache 2.0", "name": "Apache 2.0",
"url": "https://www.apache.org/licenses/LICENSE-2.0" "url": "https://www.apache.org/licenses/LICENSE-2.0"
}, },
"version": "1.4.0" "version": "1.4.1"
}, },
"paths": { "paths": {
"/": { "/": {
...@@ -590,8 +590,11 @@ ...@@ -590,8 +590,11 @@
"minimum": 0 "minimum": 0
}, },
"logprobs": { "logprobs": {
"type": "number", "allOf": [
"format": "float", {
"$ref": "#/components/schemas/ChatCompletionLogprobs"
}
],
"nullable": true "nullable": true
} }
} }
...@@ -710,7 +713,7 @@ ...@@ -710,7 +713,7 @@
"presence_penalty": { "presence_penalty": {
"type": "number", "type": "number",
"format": "float", "format": "float",
"description": "UNUSED\nNumber between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,\nincreasing the model's likelihood to talk about new topics", "description": "Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far,\nincreasing the model's likelihood to talk about new topics",
"example": 0.1, "example": 0.1,
"nullable": true "nullable": true
}, },
...@@ -734,7 +737,7 @@ ...@@ -734,7 +737,7 @@
"top_logprobs": { "top_logprobs": {
"type": "integer", "type": "integer",
"format": "int32", "format": "int32",
"description": "UNUSED\nAn integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with\nan associated log probability. logprobs must be set to true if this parameter is used.", "description": "An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with\nan associated log probability. logprobs must be set to true if this parameter is used.",
"example": "5", "example": "5",
"nullable": true, "nullable": true,
"minimum": 0 "minimum": 0
...@@ -870,6 +873,22 @@ ...@@ -870,6 +873,22 @@
"default": "false", "default": "false",
"example": true "example": true
}, },
"frequency_penalty": {
"type": "number",
"format": "float",
"default": "null",
"example": 0.1,
"nullable": true,
"exclusiveMinimum": -2
},
"grammar": {
"allOf": [
{
"$ref": "#/components/schemas/GrammarType"
}
],
"nullable": true
},
"max_new_tokens": { "max_new_tokens": {
"type": "integer", "type": "integer",
"format": "int32", "format": "int32",
...@@ -1026,6 +1045,12 @@ ...@@ -1026,6 +1045,12 @@
"example": "null", "example": "null",
"nullable": true "nullable": true
}, },
"max_batch_size": {
"type": "integer",
"example": "null",
"nullable": true,
"minimum": 0
},
"max_batch_total_tokens": { "max_batch_total_tokens": {
"type": "integer", "type": "integer",
"format": "int32", "format": "int32",
...@@ -1119,6 +1144,11 @@ ...@@ -1119,6 +1144,11 @@
"type": "string", "type": "string",
"example": "My name is David and I" "example": "My name is David and I"
}, },
"name": {
"type": "string",
"example": "\"David\"",
"nullable": true
},
"role": { "role": {
"type": "string", "type": "string",
"example": "user" "example": "user"
......
[tool.poetry] [tool.poetry]
name = "text-generation-integration-tests" name = "text-generation-integration-tests"
version = "1.4.0" version = "1.4.1"
description = "Text Generation Inference integration tests" description = "Text Generation Inference integration tests"
authors = ["Nicolas Patry <nicolas@huggingface.co>"] authors = ["Nicolas Patry <nicolas@huggingface.co>"]
......
...@@ -23,7 +23,7 @@ install-megablocks: ...@@ -23,7 +23,7 @@ install-megablocks:
install: gen-server install: gen-server
pip install pip --upgrade pip install pip --upgrade
pip install -r requirements_cuda.txt pip install -r requirements_cuda.txt
pip install -e ".[bnb, accelerate, quantize, peft]" pip install -e ".[bnb, accelerate, quantize, peft, outlines]"
run-dev: run-dev:
SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded SAFETENSORS_FAST_GPU=1 python -m torch.distributed.run --nproc_per_node=2 text_generation_server/cli.py serve bigscience/bloom-560m --sharded
......
This source diff could not be displayed because it is too large. You can view the blob instead.
[tool.poetry] [tool.poetry]
name = "text-generation-server" name = "text-generation-server"
version = "1.4.0" version = "1.4.1"
description = "Text Generation Inference Python gRPC Server" description = "Text Generation Inference Python gRPC Server"
authors = ["Olivier Dehaene <olivier@huggingface.co>"] authors = ["Olivier Dehaene <olivier@huggingface.co>"]
...@@ -34,7 +34,7 @@ peft = { version = "^0.8.2", optional = true } ...@@ -34,7 +34,7 @@ peft = { version = "^0.8.2", optional = true }
torch = { version = "^2.1.1", optional = true } torch = { version = "^2.1.1", optional = true }
scipy = "^1.11.1" scipy = "^1.11.1"
pillow = "^10.0.0" pillow = "^10.0.0"
outlines="^0.0.27" outlines= { version = "^0.0.27", optional = true }
[tool.poetry.extras] [tool.poetry.extras]
torch = ["torch"] torch = ["torch"]
...@@ -42,6 +42,7 @@ accelerate = ["accelerate"] ...@@ -42,6 +42,7 @@ accelerate = ["accelerate"]
bnb = ["bitsandbytes"] bnb = ["bitsandbytes"]
peft = ["peft"] peft = ["peft"]
quantize = ["texttable", "datasets", "accelerate"] quantize = ["texttable", "datasets", "accelerate"]
outlines = ["outlines"]
[tool.poetry.group.dev.dependencies] [tool.poetry.group.dev.dependencies]
grpcio-tools = "^1.51.1" grpcio-tools = "^1.51.1"
......
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13" backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
bitsandbytes==0.41.3.post2 ; python_version >= "3.9" and python_version < "3.13" bitsandbytes==0.41.3.post2 ; python_version >= "3.9" and python_version < "3.13"
certifi==2023.11.17 ; python_version >= "3.9" and python_version < "3.13" certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13" charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
...@@ -10,14 +10,14 @@ filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13" ...@@ -10,14 +10,14 @@ filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13" fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.62.0 ; python_version >= "3.9" and python_version < "3.13" googleapis-common-protos==1.62.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.60.0 ; python_version >= "3.9" and python_version < "3.13" grpcio-reflection==1.60.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.60.0 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.60.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.60.0 ; python_version >= "3.9" and python_version < "3.13" grpcio==1.60.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.5 ; python_version >= "3.9" and python_version < "3.13" hf-transfer==0.1.5 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.6 ; python_version >= "3.9" and python_version < "3.13" idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.3 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
...@@ -29,19 +29,19 @@ opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" ...@@ -29,19 +29,19 @@ opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.2 ; python_version >= "3.9" and python_version < "3.13" packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13" pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.2 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13" regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13" scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.0.3 ; python_version >= "3.9" and python_version < "3.13" setuptools==69.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.15.1 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.37.1 ; python_version >= "3.9" and python_version < "3.13" transformers==4.37.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.2.0 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13" backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
certifi==2023.11.17 ; python_version >= "3.9" and python_version < "3.13" certifi==2024.2.2 ; python_version >= "3.9" and python_version < "3.13"
charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13" charset-normalizer==3.3.2 ; python_version >= "3.9" and python_version < "3.13"
click==8.1.7 ; python_version >= "3.9" and python_version < "3.13" click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows") colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
...@@ -9,14 +9,14 @@ filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13" ...@@ -9,14 +9,14 @@ filelock==3.13.1 ; python_version >= "3.9" and python_version < "3.13"
fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13" fsspec==2023.10.0 ; python_version >= "3.9" and python_version < "3.13"
googleapis-common-protos==1.62.0 ; python_version >= "3.9" and python_version < "3.13" googleapis-common-protos==1.62.0 ; python_version >= "3.9" and python_version < "3.13"
grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13" grpc-interceptor==0.15.4 ; python_version >= "3.9" and python_version < "3.13"
grpcio-reflection==1.60.0 ; python_version >= "3.9" and python_version < "3.13" grpcio-reflection==1.60.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio-status==1.60.0 ; python_version >= "3.9" and python_version < "3.13" grpcio-status==1.60.1 ; python_version >= "3.9" and python_version < "3.13"
grpcio==1.60.0 ; python_version >= "3.9" and python_version < "3.13" grpcio==1.60.1 ; python_version >= "3.9" and python_version < "3.13"
hf-transfer==0.1.5 ; python_version >= "3.9" and python_version < "3.13" hf-transfer==0.1.5 ; python_version >= "3.9" and python_version < "3.13"
huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13" huggingface-hub==0.19.4 ; python_version >= "3.9" and python_version < "3.13"
idna==3.6 ; python_version >= "3.9" and python_version < "3.13" idna==3.6 ; python_version >= "3.9" and python_version < "3.13"
loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13" loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
numpy==1.26.3 ; python_version >= "3.9" and python_version < "3.13" numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
...@@ -28,19 +28,19 @@ opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13" ...@@ -28,19 +28,19 @@ opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13" opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
packaging==23.2 ; python_version >= "3.9" and python_version < "3.13" packaging==23.2 ; python_version >= "3.9" and python_version < "3.13"
pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13" pillow==10.2.0 ; python_version >= "3.9" and python_version < "3.13"
protobuf==4.25.2 ; python_version >= "3.9" and python_version < "3.13" protobuf==4.25.3 ; python_version >= "3.9" and python_version < "3.13"
pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13" pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13" regex==2023.12.25 ; python_version >= "3.9" and python_version < "3.13"
requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13" requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13" safetensors==0.3.3 ; python_version >= "3.9" and python_version < "3.13"
scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13" scipy==1.12.0 ; python_version >= "3.9" and python_version < "3.13"
sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13" sentencepiece==0.1.99 ; python_version >= "3.9" and python_version < "3.13"
setuptools==69.0.3 ; python_version >= "3.9" and python_version < "3.13" setuptools==69.1.0 ; python_version >= "3.9" and python_version < "3.13"
tokenizers==0.15.1 ; python_version >= "3.9" and python_version < "3.13" tokenizers==0.15.2 ; python_version >= "3.9" and python_version < "3.13"
tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13" tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
transformers==4.37.1 ; python_version >= "3.9" and python_version < "3.13" transformers==4.37.1 ; python_version >= "3.9" and python_version < "3.13"
typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13" typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13" typing-extensions==4.9.0 ; python_version >= "3.9" and python_version < "3.13"
urllib3==2.1.0 ; python_version >= "3.9" and python_version < "3.13" urllib3==2.2.0 ; python_version >= "3.9" and python_version < "3.13"
win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32" win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13" wrapt==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
import math import math
import torch import torch
import json
from loguru import logger from loguru import logger
from functools import lru_cache from typing import Dict, Union
from typing import Optional, List, Dict, Union
from text_generation_server.pb.generate_pb2 import GrammarType from text_generation_server.pb.generate_pb2 import GrammarType
from outlines.fsm.fsm import RegexFSM from outlines.fsm.fsm import RegexFSM
...@@ -492,7 +490,7 @@ class GrammarLogitProcessor(LogitsProcessor): ...@@ -492,7 +490,7 @@ class GrammarLogitProcessor(LogitsProcessor):
if fsm_grammar_state == -1 or self.fsm is None: if fsm_grammar_state == -1 or self.fsm is None:
return logits return logits
allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state) allowed_tokens = self.fsm.allowed_token_ids(fsm_grammar_state)
mask = torch.full((logits.shape[-1],), -math.inf, device=self.device) mask = torch.full_like(logits, -math.inf)
mask[allowed_tokens] = 0 mask[allowed_tokens] = 0
biased_scores = logits + mask biased_scores = logits + mask
return biased_scores return biased_scores
...@@ -550,22 +548,15 @@ class GrammarLogitProcessor(LogitsProcessor): ...@@ -550,22 +548,15 @@ class GrammarLogitProcessor(LogitsProcessor):
logger.debug(f"Adapted tokenizer in {time.time() - start_time:.2f}s") logger.debug(f"Adapted tokenizer in {time.time() - start_time:.2f}s")
return tokenizer return tokenizer
def filter(self, indices):
new_fsms = []
for i in indices:
new_fsms.append(self.fsms[i])
self.fsms = new_fsms
return self
class HeterogeneousGrammarLogitProcessor(LogitsProcessor): class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
def __init__(self, tokenizer, device, grammars, grammar_type): def __init__(self, tokenizer, device, grammars, grammar_types):
self.device = device self.device = device
self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer) self.tokenizer = GrammarLogitProcessor._cached_adapt_tokenizer(tokenizer)
self.fsms = [] self.fsms = []
for i in range(len(grammars)): for grammar, grammar_type in zip(grammars, grammar_types):
fsm = GrammarLogitProcessor._cached_compile_fsm( fsm = GrammarLogitProcessor._cached_compile_fsm(
grammar_type[i], grammars[i], self.tokenizer grammar_type, grammar, self.tokenizer
) )
self.fsms.append(fsm) self.fsms.append(fsm)
...@@ -573,7 +564,6 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor): ...@@ -573,7 +564,6 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
self, self,
logits: torch.Tensor, logits: torch.Tensor,
fsm_grammar_states: List[int], fsm_grammar_states: List[int],
mask: torch.Tensor,
): ):
mask = torch.full_like(logits, -math.inf) mask = torch.full_like(logits, -math.inf)
for i in range(logits.shape[0]): for i in range(logits.shape[0]):
...@@ -585,7 +575,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor): ...@@ -585,7 +575,7 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
logits += mask logits += mask
return logits return logits
def advance_batch(self, next_token_ids, fsm_grammar_states, grammars): def advance_batch(self, next_token_ids, fsm_grammar_states):
return [ return [
GrammarLogitProcessor._advance( GrammarLogitProcessor._advance(
next_token_ids[i], fsm_grammar_states[i], self.fsms[i] next_token_ids[i], fsm_grammar_states[i], self.fsms[i]
...@@ -599,4 +589,8 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor): ...@@ -599,4 +589,8 @@ class HeterogeneousGrammarLogitProcessor(LogitsProcessor):
) )
def filter(self, indices): def filter(self, indices):
return GrammarLogitProcessor.filter(self, indices) new_fsms = []
for i in indices:
new_fsms.append(self.fsms[i])
self.fsms = new_fsms
return self
...@@ -341,7 +341,7 @@ class HeterogeneousNextTokenChooser: ...@@ -341,7 +341,7 @@ class HeterogeneousNextTokenChooser:
for warper in self.warpers: for warper in self.warpers:
_scores = warper(input_ids, _scores) _scores = warper(input_ids, _scores)
if self.grammar_processor is not None: if self.grammar_processor is not None:
_scores = self.grammar_processor(_scores, self.fsm_grammar_states, mask) _scores = self.grammar_processor(_scores, self.fsm_grammar_states)
_next_ids = self.choice(_scores) _next_ids = self.choice(_scores)
scores[:, j] = _scores scores[:, j] = _scores
next_ids[:, j] = _next_ids next_ids[:, j] = _next_ids
...@@ -402,7 +402,7 @@ class HeterogeneousNextTokenChooser: ...@@ -402,7 +402,7 @@ class HeterogeneousNextTokenChooser:
def advance_grammar(self, next_ids: List[int]): def advance_grammar(self, next_ids: List[int]):
if self.grammar_processor is not None: if self.grammar_processor is not None:
other_new_states = self.grammar_processor.advance_batch( other_new_states = self.grammar_processor.advance_batch(
next_ids, self.fsm_grammar_states, self.grammars next_ids, self.fsm_grammar_states
) )
self.fsm_grammar_states = other_new_states self.fsm_grammar_states = other_new_states
return self return self
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment