Unverified commit 5a582261 authored by OlivierDehaene, committed by GitHub

fix(server): fix decode token (#334)



Fixes #333

---------
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
parent dbdc587d
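Context for the diff below: issue #333 concerns streamed token text being decoded incorrectly (the snapshot updates further down show missing leading spaces, e.g. "To" becoming " To"). This commit removes the old fixed `decode_buffer` heuristic and instead tracks two per-request offsets, `prefix_offsets` and `read_offsets`, initialized to `(0, input_length)` at prefill and threaded through `decode_token`. The sketch below is a minimal illustration of the general offset-based incremental-decoding technique that the new `decode_token(all_input_ids, prefix_offset, read_offset)` signature implies; it is an assumption-laden sketch, not the exact body of `Model.decode_token`, which is not part of the hunks shown here.

from typing import List, Tuple

from transformers import PreTrainedTokenizerBase


def decode_token_sketch(
    tokenizer: PreTrainedTokenizerBase,
    all_input_ids: List[int],
    prefix_offset: int,
    read_offset: int,
) -> Tuple[str, int, int]:
    """Illustrative sketch of offset-based streaming decode (not the actual implementation).

    Decode the previously emitted window and the window extended with the new
    ids; only emit the suffix once it no longer ends in the UTF-8 replacement
    character, i.e. once the trailing bytes form complete characters.
    """
    prefix_text = tokenizer.decode(
        all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
    )
    new_text = tokenizer.decode(
        all_input_ids[prefix_offset:], skip_special_tokens=False
    )
    if len(new_text) > len(prefix_text) and not new_text.endswith("\ufffd"):
        # Complete characters were produced: emit the delta and slide both
        # offsets forward.
        return new_text[len(prefix_text):], read_offset, len(all_input_ids)
    # Otherwise wait for more ids before emitting anything.
    return "", prefix_offset, read_offset

Because the emitted text is always the difference between two full decodes of a growing window, leading spaces and multi-byte characters survive streaming, which is what the new `test_decode_streaming_*` tests in this commit exercise.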
@@ -213,12 +213,13 @@ jobs:
           sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
       - name: Install
         run: |
+          pip install pytest-xdist
           make install-integration-tests
       - name: Run tests
         run: |
           export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
           export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-          pytest -s -vv integration-tests
+          pytest -s -vv -n 2 --dist loadfile integration-tests
 
   stop-runner:
     name: Stop self-hosted EC2 runner
...
@@ -66,7 +66,8 @@ jobs:
       - name: Run server tests
         run: |
           pip install pytest
-          make python-server-tests
+          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+          pytest -s -vv server/tests
       - name: Run Rust fmt
         run: |
           cargo fmt --check
...
@@ -31,7 +31,7 @@ update-integration-tests: install-integration-tests
 	pytest -s -vv --snapshot-update integration-tests
 
 python-server-tests:
-	HF_HUB_ENABLE_HF_TRANSFER=1 pytest server/tests
+	HF_HUB_ENABLE_HF_TRANSFER=1 pytest -s -vv -m "not private" server/tests
 
 python-client-tests:
 	pytest clients/python/tests
...
+import sys
 import subprocess
 import contextlib
 import pytest
@@ -7,6 +8,7 @@ import docker
 import json
 import math
 import time
+import random
 from docker.errors import NotFound
 from typing import Optional, List, Dict
@@ -205,10 +207,12 @@ def launcher(event_loop):
     def local_launcher(
         model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
     ):
-        port = 9999
-        master_port = 19999
-        shard_uds_path = f"/tmp/{model_id.replace('/', '--')}-server"
+        port = random.randint(8000, 10_000)
+        master_port = random.randint(10_000, 20_000)
+        shard_uds_path = (
+            f"/tmp/tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}-server"
+        )
 
         args = [
             "text-generation-launcher",
@@ -236,7 +240,7 @@ def launcher(event_loop):
             process.wait(60)
             launcher_output = process.stdout.read().decode("utf-8")
-            print(launcher_output)
+            print(launcher_output, file=sys.stderr)
 
             process.stdout.close()
             process.stderr.close()
@@ -245,7 +249,7 @@ def launcher(event_loop):
     def docker_launcher(
         model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
     ):
-        port = 9999
+        port = random.randint(8000, 10_000)
 
         args = ["--model-id", model_id, "--env"]
@@ -298,7 +302,7 @@ def launcher(event_loop):
             pass
 
         container_output = container.logs().decode("utf-8")
-        print(container_output)
+        print(container_output, file=sys.stderr)
 
         container.remove()
...
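The launcher fixture above switches from fixed ports (9999/19999) to random ones and to per-model UDS paths, presumably so the two pytest-xdist workers enabled in the CI workflow cannot collide on the same port or socket file. Random ports can still clash occasionally; a slightly more robust pattern, shown here purely as an illustration and not used in this commit, is to let the OS hand out a currently free ephemeral port by binding to port 0:

import socket


def get_free_port() -> int:
    """Ask the OS for a currently free ephemeral port.

    The port is released before the launcher binds it, so a small race
    window remains, but collisions are far less likely than with
    random.randint over a fixed range.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))
        return sock.getsockname()[1]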
@@ -25,25 +25,25 @@
   "tokens": [
     {
       "id": 363,
-      "logprob": -1.5322266,
+      "logprob": -1.5380859,
       "special": false,
       "text": " for"
     },
     {
       "id": 847,
-      "logprob": -2.5585938,
+      "logprob": -2.5859375,
       "special": false,
       "text": " /"
     },
     {
       "id": 2754,
-      "logprob": -2.265625,
+      "logprob": -2.2695312,
       "special": false,
       "text": "api"
     },
     {
       "id": 29914,
-      "logprob": -0.034088135,
+      "logprob": -0.03439331,
       "special": false,
       "text": "/"
     },
@@ -55,31 +55,31 @@
     },
     {
       "id": 29896,
-      "logprob": -0.36816406,
+      "logprob": -0.36694336,
       "special": false,
       "text": "1"
     },
     {
       "id": 29914,
-      "logprob": -0.013191223,
+      "logprob": -0.013114929,
       "special": false,
       "text": "/"
     },
     {
       "id": 16418,
-      "logprob": -3.15625,
+      "logprob": -3.1542969,
       "special": false,
       "text": "projects"
     },
     {
       "id": 29914,
-      "logprob": -0.43774414,
+      "logprob": -0.43847656,
       "special": false,
       "text": "/"
     },
     {
       "id": 29896,
-      "logprob": -1.9443359,
+      "logprob": -1.9433594,
       "special": false,
       "text": "1"
     }
@@ -113,25 +113,25 @@
   "tokens": [
     {
       "id": 363,
-      "logprob": -1.5380859,
+      "logprob": -1.5322266,
       "special": false,
       "text": " for"
     },
     {
       "id": 847,
-      "logprob": -2.5859375,
+      "logprob": -2.5585938,
       "special": false,
       "text": " /"
     },
     {
       "id": 2754,
-      "logprob": -2.2695312,
+      "logprob": -2.265625,
       "special": false,
       "text": "api"
     },
     {
       "id": 29914,
-      "logprob": -0.03439331,
+      "logprob": -0.034088135,
       "special": false,
       "text": "/"
     },
@@ -143,31 +143,31 @@
     },
     {
       "id": 29896,
-      "logprob": -0.36694336,
+      "logprob": -0.36816406,
       "special": false,
       "text": "1"
     },
     {
       "id": 29914,
-      "logprob": -0.013114929,
+      "logprob": -0.013191223,
       "special": false,
       "text": "/"
     },
     {
       "id": 16418,
-      "logprob": -3.1542969,
+      "logprob": -3.15625,
       "special": false,
       "text": "projects"
     },
     {
       "id": 29914,
-      "logprob": -0.43847656,
+      "logprob": -0.43774414,
       "special": false,
       "text": "/"
     },
     {
       "id": 29896,
-      "logprob": -1.9433594,
+      "logprob": -1.9443359,
       "special": false,
       "text": "1"
     }
...
@@ -16,7 +16,7 @@
       "id": 926,
       "logprob": -4.3554688,
       "special": false,
-      "text": "To"
+      "text": " To"
     },
     {
       "id": 18295,
...
@@ -16,7 +16,7 @@
       "id": 16017,
       "logprob": -1.3505859,
       "special": false,
-      "text": "blue"
+      "text": " blue"
     },
     {
       "id": 20495,
...
@@ -15,37 +15,37 @@
   "tokens": [
     {
       "id": 259,
-      "logprob": -1.3789062,
+      "logprob": -1.3798828,
       "special": false,
-      "text": ""
+      "text": " "
     },
     {
       "id": 39261,
-      "logprob": -0.36279297,
+      "logprob": -0.36328125,
       "special": false,
       "text": "Because"
     },
     {
       "id": 609,
-      "logprob": -1.0966797,
+      "logprob": -1.0947266,
       "special": false,
       "text": " it"
     },
     {
       "id": 339,
-      "logprob": -0.8276367,
+      "logprob": -0.8286133,
       "special": false,
       "text": " is"
     },
     {
       "id": 16017,
-      "logprob": -1.6845703,
+      "logprob": -1.6826172,
       "special": false,
       "text": " blue"
     },
     {
       "id": 1,
-      "logprob": -0.72753906,
+      "logprob": -0.7290039,
       "special": true,
       "text": "</s>"
     }
@@ -69,37 +69,37 @@
   "tokens": [
     {
       "id": 259,
-      "logprob": -1.3798828,
+      "logprob": -1.3789062,
       "special": false,
-      "text": ""
+      "text": " "
     },
     {
       "id": 39261,
-      "logprob": -0.36328125,
+      "logprob": -0.36279297,
       "special": false,
       "text": "Because"
     },
     {
       "id": 609,
-      "logprob": -1.0947266,
+      "logprob": -1.0966797,
       "special": false,
       "text": " it"
     },
     {
       "id": 339,
-      "logprob": -0.8286133,
+      "logprob": -0.8276367,
       "special": false,
       "text": " is"
     },
     {
       "id": 16017,
-      "logprob": -1.6826172,
+      "logprob": -1.6845703,
       "special": false,
       "text": " blue"
     },
     {
       "id": 1,
-      "logprob": -0.7290039,
+      "logprob": -0.72753906,
       "special": true,
       "text": "</s>"
     }
@@ -125,7 +125,7 @@
       "id": 259,
       "logprob": -1.3789062,
       "special": false,
-      "text": ""
+      "text": " "
     },
     {
       "id": 39261,
@@ -179,7 +179,7 @@
       "id": 259,
       "logprob": -1.3789062,
       "special": false,
-      "text": ""
+      "text": " "
     },
     {
       "id": 39261,
...
@@ -146,7 +146,7 @@ fn main() -> Result<(), std::io::Error> {
             sha: None,
             pipeline_tag: None,
         },
-        false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or({
+        false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or_else(|| {
             tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
             HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None }
         }),
...
@@ -2,7 +2,7 @@ include Makefile-transformers
 include Makefile-flash-att
 
 unit-tests:
-	python -m pytest tests
+	pytest -s -vv -m "not private" tests
 
 gen-server:
 	# Compile protos
...
+import pytest
+import torch
+
+from transformers import AutoTokenizer
+
+from text_generation_server.models import Model
+
+
+def get_test_model():
+    class TestModel(Model):
+        def batch_type(self):
+            raise NotImplementedError
+
+        def generate_token(self, batch):
+            raise NotImplementedError
+
+    tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")
+
+    model = TestModel(
+        torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu")
+    )
+    return model
+
+
+@pytest.mark.private
+def test_decode_streaming_english_spaces():
+    model = get_test_model()
+    truth = "Hello here, this is a simple test"
+    all_input_ids = [15043, 1244, 29892, 445, 338, 263, 2560, 1243]
+    assert (
+        all_input_ids == model.tokenizer(truth, add_special_tokens=False)["input_ids"]
+    )
+
+    decoded_text = ""
+    offset = 0
+    token_offset = 0
+    for i in range(len(all_input_ids)):
+        text, offset, token_offset = model.decode_token(
+            all_input_ids[: i + 1], offset, token_offset
+        )
+        decoded_text += text
+
+    assert decoded_text == truth
+
+
+@pytest.mark.private
+def test_decode_streaming_chinese_utf8():
+    model = get_test_model()
+    truth = "我很感谢你的热情"
+    all_input_ids = [
+        30672,
+        232,
+        193,
+        139,
+        233,
+        135,
+        162,
+        235,
+        179,
+        165,
+        30919,
+        30210,
+        234,
+        134,
+        176,
+        30993,
+    ]
+
+    decoded_text = ""
+    offset = 0
+    token_offset = 0
+    for i in range(len(all_input_ids)):
+        text, offset, token_offset = model.decode_token(
+            all_input_ids[: i + 1], offset, token_offset
+        )
+        decoded_text += text
+
+    assert decoded_text == truth
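These two new decode tests are marked `@pytest.mark.private`, presumably because `get_test_model` downloads the `huggingface/llama-7b` tokenizer, which requires an authenticated download; the Makefiles above deselect them with `-m "not private"`, while CI exports `HUGGING_FACE_HUB_TOKEN` and runs the full `server/tests` suite. If the `private` marker is not registered elsewhere in the project's pytest configuration, a hook like the hypothetical one below would register it and silence unknown-marker warnings (illustration only; the repository may declare the marker in its own config):

# conftest.py (illustrative): register the custom "private" marker so that
# `pytest -m "not private"` deselects gated-model tests without warnings.
def pytest_configure(config):
    config.addinivalue_line(
        "markers", "private: tests that need access to gated Hugging Face repos"
    )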
@@ -149,7 +149,7 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
     assert all([generation.generated_text is None for generation in generations])
     assert all([len(generation.prefill_tokens) == 1 for generation in generations])
     assert all([generation.token_id.item() == 259 for generation in generations])
-    assert all([generation.token_text == "" for generation in generations])
+    assert all([generation.token_text == " " for generation in generations])
     assert generations[0].request_id == 0
...
@@ -56,7 +56,7 @@ class BLOOM(CausalLM):
         quantize: Optional[str] = None,
     ):
         super(BLOOM, self).__init__(
-            model_id=model_id, revision=revision, quantize=quantize, decode_buffer=1
+            model_id=model_id, revision=revision, quantize=quantize
         )
 
     @property
@@ -104,14 +104,13 @@ class BLOOMSharded(BLOOM):
             rank=rank,
             world_size=world_size,
         )
-        self.model = model.eval()
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
+            model=model,
             tokenizer=tokenizer,
             requires_padding=True,
             dtype=dtype,
             device=device,
-            decode_buffer=1,
             rank=rank,
             world_size=world_size,
         )
...
@@ -35,8 +35,8 @@ class CausalLMBatch(Batch):
     # Lengths of all generations present in the batch
     input_lengths: List[int]
-    offsets: List[Optional[int]]
-    token_offsets: List[Optional[int]]
+    prefix_offsets: List[int]
+    read_offsets: List[int]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -70,8 +70,8 @@ class CausalLMBatch(Batch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
         requests_idx_mapping = {}
 
         # Parse batch
@@ -81,8 +81,6 @@ class CausalLMBatch(Batch):
         for i, r in enumerate(pb.requests):
             requests_idx_mapping[r.id] = i
             inputs.append(r.inputs)
-            offsets.append(None)
-            token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -102,6 +100,10 @@ class CausalLMBatch(Batch):
             truncation=True,
             max_length=max_truncation,
         ).to(device)
+        for _ in pb.requests:
+            input_len = tokenized_inputs["input_ids"].shape[1]
+            prefix_offsets.append(0)
+            read_offsets.append(input_len)
 
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
         max_input_length = input_lengths.max()
@@ -130,8 +132,8 @@ class CausalLMBatch(Batch):
             past_key_values=None,
             all_input_ids=list(all_input_ids),
             input_lengths=input_lengths.tolist(),
-            offsets=offsets,
-            token_offsets=token_offsets,
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             max_input_length=max_input_length.item(),
@@ -151,8 +153,8 @@ class CausalLMBatch(Batch):
         # New values after filtering
         requests_idx_mapping = {}
         input_lengths = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
         all_input_ids = []
         max_input_length = 0
@@ -167,8 +169,8 @@ class CausalLMBatch(Batch):
             requests_idx_mapping[r.id] = i
             keep_indices.append(idx)
 
-            offsets.append(self.offsets[idx])
-            token_offsets.append(self.token_offsets[idx])
+            prefix_offsets.append(self.prefix_offsets[idx])
+            read_offsets.append(self.read_offsets[idx])
             all_input_ids.append(self.all_input_ids[idx])
 
             request_input_length = self.input_lengths[idx]
@@ -225,8 +227,8 @@ class CausalLMBatch(Batch):
         self.position_ids = position_ids
         self.all_input_ids = all_input_ids
         self.input_lengths = input_lengths
-        self.offsets = offsets
-        self.token_offsets = token_offsets
+        self.prefix_offsets = prefix_offsets
+        self.read_offsets = read_offsets
         self.next_token_choosers = next_token_choosers
         self.stopping_criterias = stopping_criterias
         self.max_input_length = max_input_length
@@ -251,8 +253,8 @@ class CausalLMBatch(Batch):
         requests = []
         requests_idx_mapping = {}
         input_lengths = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
         all_input_ids = []
         next_token_choosers = []
         stopping_criterias = []
@@ -270,8 +272,8 @@ class CausalLMBatch(Batch):
         for i, batch in enumerate(batches):
             requests.extend(batch.requests)
             input_lengths.extend(batch.input_lengths)
-            offsets.extend(batch.offsets)
-            token_offsets.extend(batch.token_offsets)
+            prefix_offsets.extend(batch.prefix_offsets)
+            read_offsets.extend(batch.read_offsets)
             all_input_ids.extend(batch.all_input_ids)
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
@@ -428,8 +430,8 @@ class CausalLMBatch(Batch):
             past_key_values=past_key_values,
             all_input_ids=all_input_ids,
             input_lengths=input_lengths,
-            offsets=offsets,
-            token_offsets=token_offsets,
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             max_input_length=max_input_length,
@@ -448,7 +450,6 @@ class CausalLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        decode_buffer: int = 3,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -463,25 +464,25 @@ class CausalLM(Model):
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )
-        self.model = AutoModelForCausalLM.from_pretrained(
+        model = AutoModelForCausalLM.from_pretrained(
             model_id,
             revision=revision,
             torch_dtype=dtype,
             device_map="auto" if torch.cuda.is_available() else None,
             load_in_8bit=quantize == "bitsandbytes",
-        ).eval()
+        )
         tokenizer.pad_token_id = (
-            self.model.config.pad_token_id
-            if self.model.config.pad_token_id is not None
-            else self.model.config.eos_token_id
+            model.config.pad_token_id
+            if model.config.pad_token_id is not None
+            else model.config.eos_token_id
         )
 
         super(CausalLM, self).__init__(
+            model=model,
             tokenizer=tokenizer,
             requires_padding=True,
             dtype=dtype,
             device=device,
-            decode_buffer=decode_buffer,
         )
 
     @property
@@ -528,8 +529,8 @@ class CausalLM(Model):
         iterator = zip(
             batch.requests,
             batch.input_lengths,
-            batch.offsets,
-            batch.token_offsets,
+            batch.prefix_offsets,
+            batch.read_offsets,
             logits,
             batch.next_token_choosers,
             batch.stopping_criterias,
@@ -540,8 +541,8 @@ class CausalLM(Model):
         for i, (
             request,
             input_length,
-            offset,
-            token_offset,
+            prefix_offset,
+            read_offset,
             logits,
             next_token_chooser,
             stopping_criteria,
@@ -559,8 +560,8 @@ class CausalLM(Model):
             # Generated token
             next_token_logprob = logprobs[-1, next_token_id]
             next_token_id_squeezed = next_token_id.squeeze()
-            next_token_text, offset, token_offset = self.decode_token(
-                all_input_ids[:, 0], offset, token_offset
+            next_token_text, prefix_offset, read_offset = self.decode_token(
+                all_input_ids[:, 0], prefix_offset, read_offset
             )
 
             # Evaluate stopping criteria
@@ -628,8 +629,8 @@ class CausalLM(Model):
             batch.input_ids[i, 0] = next_token_id
             batch.all_input_ids[i] = all_input_ids
             batch.input_lengths[i] = new_input_length
-            batch.offsets[i] = offset
-            batch.token_offsets[i] = token_offset
+            batch.prefix_offsets[i] = prefix_offset
+            batch.read_offsets[i] = read_offset
             batch.max_input_length = max(batch.max_input_length, new_input_length)
 
         # We finished all generations in the batch; there is no next batch
...
@@ -52,8 +52,8 @@ class FlashCausalLMBatch(Batch):
     # Lengths of all generations present in the batch
     input_lengths: List[int]
-    offsets: List[Optional[int]]
-    token_offsets: List[Optional[int]]
+    prefix_offsets: List[Optional[int]]
+    read_offsets: List[Optional[int]]
 
     # Generation helpers
     next_token_choosers: List[NextTokenChooser]
@@ -82,8 +82,8 @@ class FlashCausalLMBatch(Batch):
         max_seqlen = 0
 
         input_lengths = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
         all_input_ids = []
         requests_idx_mapping = {}
@@ -108,8 +108,8 @@ class FlashCausalLMBatch(Batch):
             max_seqlen = max(max_seqlen, input_length)
             input_lengths.append(input_length)
-            offsets.append(None)
-            token_offsets.append(None)
+            prefix_offsets.append(0)
+            read_offsets.append(input_length)
 
             all_input_ids.append(tokenized_input)
@@ -151,8 +151,8 @@ class FlashCausalLMBatch(Batch):
             max_seqlen=max_seqlen,
             past_key_values=None,
             input_lengths=input_lengths,
-            offsets=offsets,
-            token_offsets=token_offsets,
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=[],
             next_token_choosers=next_token_choosers,
@@ -190,8 +190,8 @@ class FlashCausalLMBatch(Batch):
         all_input_ids_tensor = []
 
         input_lengths = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
 
         next_token_choosers = []
         stopping_criterias = []
@@ -222,8 +222,8 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
 
             input_lengths.append(request_input_length)
-            offsets.append(self.offsets[idx])
-            token_offsets.append(self.token_offsets[idx])
+            prefix_offsets.append(self.prefix_offsets[idx])
+            read_offsets.append(self.read_offsets[idx])
 
             next_token_choosers.append(self.next_token_choosers[idx])
@@ -269,8 +269,8 @@ class FlashCausalLMBatch(Batch):
             max_seqlen=max_seqlen,
             past_key_values=past_key_values,
             input_lengths=input_lengths,
-            offsets=offsets,
-            token_offsets=token_offsets,
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,
@@ -302,8 +302,8 @@ class FlashCausalLMBatch(Batch):
         all_input_ids_tensor = []
 
         input_lengths = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
 
         next_token_choosers = []
         stopping_criterias = []
@@ -347,8 +347,8 @@ class FlashCausalLMBatch(Batch):
             all_input_ids_tensor.extend(batch.all_input_ids_tensor)
 
             input_lengths.extend(batch.input_lengths)
-            offsets.extend(batch.offsets)
-            token_offsets.extend(batch.token_offsets)
+            prefix_offsets.extend(batch.prefix_offsets)
+            read_offsets.extend(batch.read_offsets)
 
             next_token_choosers.extend(batch.next_token_choosers)
             stopping_criterias.extend(batch.stopping_criterias)
@@ -374,8 +374,8 @@ class FlashCausalLMBatch(Batch):
             max_seqlen=max_seqlen,
             past_key_values=past_key_values,
             input_lengths=input_lengths,
-            offsets=offsets,
-            token_offsets=token_offsets,
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
             all_input_ids=all_input_ids,
             all_input_ids_tensor=all_input_ids_tensor,
             next_token_choosers=next_token_choosers,
@@ -394,7 +394,6 @@ class FlashCausalLM(Model):
         model_id: str,
         revision: Optional[str] = None,
         quantize: Optional[str] = None,
-        decode_buffer: int = 3,
     ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
@@ -405,23 +404,19 @@ class FlashCausalLM(Model):
         tokenizer = AutoTokenizer.from_pretrained(
             model_id, revision=revision, padding_side="left", truncation_side="left"
         )
-        self.model = (
-            model_cls.from_pretrained(
-                model_id,
-                revision=revision,
-                torch_dtype=dtype,
-                load_in_8bit=quantize == "bitsandbytes",
-            )
-            .eval()
-            .to(device)
-        )
+        model = model_cls.from_pretrained(
+            model_id,
+            revision=revision,
+            torch_dtype=dtype,
+            load_in_8bit=quantize == "bitsandbytes",
+        ).to(device)
 
         super(FlashCausalLM, self).__init__(
+            model=model,
             tokenizer=tokenizer,
             requires_padding=False,
             dtype=dtype,
             device=device,
-            decode_buffer=decode_buffer,
         )
 
     @property
@@ -645,8 +640,8 @@ class FlashCausalLM(Model):
         iterator = zip(
             batch.requests,
             batch.input_lengths,
-            batch.offsets,
-            batch.token_offsets,
+            batch.prefix_offsets,
+            batch.read_offsets,
             batch.next_token_choosers,
             batch.stopping_criterias,
             batch.all_input_ids,
@@ -659,8 +654,8 @@ class FlashCausalLM(Model):
         for i, (
             request,
             input_length,
-            offset,
-            token_offset,
+            prefix_offset,
+            read_offset,
             next_token_chooser,
             stopping_criteria,
             all_input_ids,
@@ -675,10 +670,10 @@ class FlashCausalLM(Model):
             all_input_ids.append(next_token_id)
 
             # Generated token
-            next_token_text, offset, token_offset = self.decode_token(
+            next_token_text, prefix_offset, read_offset = self.decode_token(
                 all_input_ids,
-                offset,
-                token_offset,
+                prefix_offset,
+                read_offset,
             )
 
             # Evaluate stopping criteria
@@ -744,8 +739,8 @@ class FlashCausalLM(Model):
             # Update values
             batch.input_lengths[i] = new_input_length
-            batch.offsets[i] = offset
-            batch.token_offsets[i] = token_offset
+            batch.prefix_offsets[i] = prefix_offset
+            batch.read_offsets[i] = read_offset
             batch.all_input_ids[i] = all_input_ids
             batch.max_seqlen = batch.max_seqlen + 1
             cumulative_length += input_length
...
@@ -64,9 +64,9 @@ class FlashLlama(FlashCausalLM):
         model = FlashLlamaForCausalLM(config)
 
         self.load_weights(model, filenames, quantize, device, dtype)
-        self.model = model.eval().to(device)
 
         super(FlashCausalLM, self).__init__(
+            model=model.to(device),
             tokenizer=tokenizer,
             requires_padding=False,
             dtype=dtype,
@@ -189,9 +189,9 @@ class FlashLlamaSharded(FlashLlama):
             rank=rank,
             world_size=world_size,
         )
-        self.model = model.eval().to(device)
 
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
+            model=model.to(device),
             tokenizer=tokenizer,
             requires_padding=False,
             dtype=dtype,
...
@@ -73,9 +73,9 @@ class FlashNeoXSharded(FlashNeoX):
             rank=rank,
             world_size=world_size,
         )
-        self.model = model.eval().to(device)
 
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
+            model=model.to(device),
             tokenizer=tokenizer,
             requires_padding=False,
             dtype=dtype,
...
@@ -67,14 +67,13 @@ class FlashSantacoder(FlashCausalLM):
             dtype,
             config.architectures[0].startswith("GPT2"),
         )
-        self.model = model.eval().to(device)
 
         super(FlashCausalLM, self).__init__(
+            model=model.to(device),
             tokenizer=tokenizer,
             requires_padding=False,
             dtype=dtype,
             device=device,
-            decode_buffer=1,
         )
 
     @staticmethod
@@ -213,16 +212,15 @@ class FlashSantacoderSharded(FlashSantacoder):
             world_size=world_size,
             transpose=config.architectures[0].startswith("GPT2"),
         )
-        self.model = model.eval().to(device)
 
         torch.distributed.barrier(group=self.process_group)
         super(FlashCausalLM, self).__init__(
+            model=model.to(device),
             tokenizer=tokenizer,
             requires_padding=False,
             dtype=dtype,
             device=device,
             rank=rank,
             world_size=world_size,
-            decode_buffer=1,
         )
 
     @staticmethod
...
@@ -94,8 +94,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
         inputs = []
         next_token_choosers = []
         stopping_criterias = []
-        offsets = []
-        token_offsets = []
+        prefix_offsets = []
+        read_offsets = []
         requests_idx_mapping = {}
 
         # Parse batch
@@ -106,8 +106,6 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             requests_idx_mapping[r.id] = i
             # Add escape_custom_split_sequence to the CausalLMBatch logic
             inputs.append(escape_custom_split_sequence(r.inputs))
-            offsets.append(None)
-            token_offsets.append(None)
             next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
@@ -127,6 +125,10 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             truncation=True,
             max_length=max_truncation,
         ).to(device)
+        for _ in pb.requests:
+            input_len = tokenized_inputs["input_ids"].shape[1]
+            prefix_offsets.append(0)
+            read_offsets.append(input_len)
 
         input_lengths = tokenized_inputs["attention_mask"].sum(1)
         max_input_length = input_lengths.max()
@@ -155,8 +157,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
             past_key_values=None,
             all_input_ids=list(all_input_ids),
             input_lengths=input_lengths.tolist(),
-            offsets=offsets,
-            token_offsets=token_offsets,
+            prefix_offsets=prefix_offsets,
+            read_offsets=read_offsets,
             next_token_choosers=next_token_choosers,
             stopping_criterias=stopping_criterias,
             max_input_length=max_input_length.item(),
@@ -231,9 +233,9 @@ class GalacticaSharded(Galactica):
             rank=rank,
             world_size=world_size,
         )
-        self.model = model.eval()
 
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
+            model=model,
             tokenizer=tokenizer,
             requires_padding=True,
             dtype=dtype,
...
@@ -70,9 +70,9 @@ class GPTNeoxSharded(CausalLM):
             rank=rank,
             world_size=world_size,
         )
-        self.model = model.eval()
 
         torch.distributed.barrier(group=self.process_group)
         super(CausalLM, self).__init__(
+            model=model,
             tokenizer=tokenizer,
             requires_padding=True,
             dtype=dtype,
...