Unverified commit 5a582261 authored by OlivierDehaene, committed by GitHub

fix(server): fix decode token (#334)



Fixes #333

---------
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
parent dbdc587d
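
For context on the fix itself: streaming detokenization cannot simply decode one token id at a time, because byte-level and SentencePiece tokenizers may need several ids before they produce a complete character or a leading space. The `prefix_offsets`/`read_offsets` fields introduced below support the usual two-pass pattern. A minimal sketch, assuming a Hugging Face tokenizer; this is illustrative and not the repository's exact `decode_token`:

```python
# Illustrative sketch only, assuming a Hugging Face tokenizer; the actual
# decode_token implementation lives in text_generation_server.models.Model.
from typing import List, Tuple
from transformers import PreTrainedTokenizerBase


def decode_token(
    tokenizer: PreTrainedTokenizerBase,
    all_input_ids: List[int],
    prefix_offset: int,
    read_offset: int,
) -> Tuple[str, int, int]:
    # Decode the tail of the sequence twice: once without the newest tokens
    # (the "prefix") and once with them, then emit only the difference.
    prefix_text = tokenizer.decode(
        all_input_ids[prefix_offset:read_offset], skip_special_tokens=False
    )
    new_text = tokenizer.decode(
        all_input_ids[prefix_offset:], skip_special_tokens=False
    )
    if len(new_text) > len(prefix_text) and not new_text.endswith("\ufffd"):
        # The newest token(s) produced complete, valid text: emit it and
        # advance both offsets so the next call starts from here.
        return new_text[len(prefix_text):], read_offset, len(all_input_ids)
    # Otherwise the last token left an incomplete multi-byte character:
    # emit nothing yet and keep the offsets unchanged.
    return "", prefix_offset, read_offset
```

Each generation step passes the running id list plus the request's stored offsets; the new UTF-8 test further down exercises the branch where the newest byte-level token has not yet completed a character.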
......@@ -213,12 +213,13 @@ jobs:
sudo mount /dev/nvme1n1 ${{ env.DOCKER_VOLUME }}
- name: Install
run: |
pip install pytest-xdist
make install-integration-tests
- name: Run tests
run: |
export DOCKER_IMAGE=registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ env.GITHUB_SHA_SHORT }}
export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
pytest -s -vv integration-tests
pytest -s -vv -n 2 --dist loadfile integration-tests
stop-runner:
name: Stop self-hosted EC2 runner
......
......@@ -66,7 +66,8 @@ jobs:
- name: Run server tests
run: |
pip install pytest
make python-server-tests
export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
pytest -s -vv server/tests
- name: Run Rust fmt
run: |
cargo fmt --check
......
......@@ -31,7 +31,7 @@ update-integration-tests: install-integration-tests
pytest -s -vv --snapshot-update integration-tests
python-server-tests:
HF_HUB_ENABLE_HF_TRANSFER=1 pytest server/tests
HF_HUB_ENABLE_HF_TRANSFER=1 pytest -s -vv -m "not private" server/tests
python-client-tests:
pytest clients/python/tests
......
import sys
import subprocess
import contextlib
import pytest
......@@ -7,6 +8,7 @@ import docker
import json
import math
import time
import random
from docker.errors import NotFound
from typing import Optional, List, Dict
......@@ -205,10 +207,12 @@ def launcher(event_loop):
def local_launcher(
model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
):
port = 9999
master_port = 19999
port = random.randint(8000, 10_000)
master_port = random.randint(10_000, 20_000)
shard_uds_path = f"/tmp/{model_id.replace('/', '--')}-server"
shard_uds_path = (
f"/tmp/tgi-tests-{model_id.split('/')[-1]}-{num_shard}-{quantize}-server"
)
args = [
"text-generation-launcher",
......@@ -236,7 +240,7 @@ def launcher(event_loop):
process.wait(60)
launcher_output = process.stdout.read().decode("utf-8")
print(launcher_output)
print(launcher_output, file=sys.stderr)
process.stdout.close()
process.stderr.close()
......@@ -245,7 +249,7 @@ def launcher(event_loop):
def docker_launcher(
model_id: str, num_shard: Optional[int] = None, quantize: Optional[str] = None
):
port = 9999
port = random.randint(8000, 10_000)
args = ["--model-id", model_id, "--env"]
......@@ -298,7 +302,7 @@ def launcher(event_loop):
pass
container_output = container.logs().decode("utf-8")
print(container_output)
print(container_output, file=sys.stderr)
container.remove()
......
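
The launcher fixture above now draws its ports at random so the integration tests can run under `pytest-xdist` (`-n 2 --dist loadfile`) without two workers binding the same port. A hedged alternative sketch, assuming a POSIX host, that asks the OS for a free port instead of relying on `random.randint` to avoid collisions:

```python
# Illustrative alternative only; the fixture above uses random.randint, which
# is simpler but can, in rare cases, pick a port that is already in use.
import socket


def free_port() -> int:
    # Bind to port 0 so the kernel assigns an unused ephemeral port,
    # then release it immediately and return the number.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("127.0.0.1", 0))
        return sock.getsockname()[1]
```

Binding to port 0 and reading back the assigned number removes the small chance of two workers drawing the same value, at the cost of a brief window in which the port could be reclaimed before the launcher starts.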
......@@ -25,25 +25,25 @@
"tokens": [
{
"id": 363,
"logprob": -1.5322266,
"logprob": -1.5380859,
"special": false,
"text": " for"
},
{
"id": 847,
"logprob": -2.5585938,
"logprob": -2.5859375,
"special": false,
"text": " /"
},
{
"id": 2754,
"logprob": -2.265625,
"logprob": -2.2695312,
"special": false,
"text": "api"
},
{
"id": 29914,
"logprob": -0.034088135,
"logprob": -0.03439331,
"special": false,
"text": "/"
},
......@@ -55,31 +55,31 @@
},
{
"id": 29896,
"logprob": -0.36816406,
"logprob": -0.36694336,
"special": false,
"text": "1"
},
{
"id": 29914,
"logprob": -0.013191223,
"logprob": -0.013114929,
"special": false,
"text": "/"
},
{
"id": 16418,
"logprob": -3.15625,
"logprob": -3.1542969,
"special": false,
"text": "projects"
},
{
"id": 29914,
"logprob": -0.43774414,
"logprob": -0.43847656,
"special": false,
"text": "/"
},
{
"id": 29896,
"logprob": -1.9443359,
"logprob": -1.9433594,
"special": false,
"text": "1"
}
......@@ -113,25 +113,25 @@
"tokens": [
{
"id": 363,
"logprob": -1.5380859,
"logprob": -1.5322266,
"special": false,
"text": " for"
},
{
"id": 847,
"logprob": -2.5859375,
"logprob": -2.5585938,
"special": false,
"text": " /"
},
{
"id": 2754,
"logprob": -2.2695312,
"logprob": -2.265625,
"special": false,
"text": "api"
},
{
"id": 29914,
"logprob": -0.03439331,
"logprob": -0.034088135,
"special": false,
"text": "/"
},
......@@ -143,31 +143,31 @@
},
{
"id": 29896,
"logprob": -0.36694336,
"logprob": -0.36816406,
"special": false,
"text": "1"
},
{
"id": 29914,
"logprob": -0.013114929,
"logprob": -0.013191223,
"special": false,
"text": "/"
},
{
"id": 16418,
"logprob": -3.1542969,
"logprob": -3.15625,
"special": false,
"text": "projects"
},
{
"id": 29914,
"logprob": -0.43847656,
"logprob": -0.43774414,
"special": false,
"text": "/"
},
{
"id": 29896,
"logprob": -1.9433594,
"logprob": -1.9443359,
"special": false,
"text": "1"
}
......
......@@ -16,7 +16,7 @@
"id": 926,
"logprob": -4.3554688,
"special": false,
"text": "To"
"text": " To"
},
{
"id": 18295,
......
......@@ -16,7 +16,7 @@
"id": 16017,
"logprob": -1.3505859,
"special": false,
"text": "blue"
"text": " blue"
},
{
"id": 20495,
......
......@@ -15,37 +15,37 @@
"tokens": [
{
"id": 259,
"logprob": -1.3789062,
"logprob": -1.3798828,
"special": false,
"text": ""
"text": " "
},
{
"id": 39261,
"logprob": -0.36279297,
"logprob": -0.36328125,
"special": false,
"text": "Because"
},
{
"id": 609,
"logprob": -1.0966797,
"logprob": -1.0947266,
"special": false,
"text": " it"
},
{
"id": 339,
"logprob": -0.8276367,
"logprob": -0.8286133,
"special": false,
"text": " is"
},
{
"id": 16017,
"logprob": -1.6845703,
"logprob": -1.6826172,
"special": false,
"text": " blue"
},
{
"id": 1,
"logprob": -0.72753906,
"logprob": -0.7290039,
"special": true,
"text": "</s>"
}
......@@ -69,37 +69,37 @@
"tokens": [
{
"id": 259,
"logprob": -1.3798828,
"logprob": -1.3789062,
"special": false,
"text": ""
"text": " "
},
{
"id": 39261,
"logprob": -0.36328125,
"logprob": -0.36279297,
"special": false,
"text": "Because"
},
{
"id": 609,
"logprob": -1.0947266,
"logprob": -1.0966797,
"special": false,
"text": " it"
},
{
"id": 339,
"logprob": -0.8286133,
"logprob": -0.8276367,
"special": false,
"text": " is"
},
{
"id": 16017,
"logprob": -1.6826172,
"logprob": -1.6845703,
"special": false,
"text": " blue"
},
{
"id": 1,
"logprob": -0.7290039,
"logprob": -0.72753906,
"special": true,
"text": "</s>"
}
......@@ -125,7 +125,7 @@
"id": 259,
"logprob": -1.3789062,
"special": false,
"text": ""
"text": " "
},
{
"id": 39261,
......@@ -179,7 +179,7 @@
"id": 259,
"logprob": -1.3789062,
"special": false,
"text": ""
"text": " "
},
{
"id": 39261,
......
......@@ -146,7 +146,7 @@ fn main() -> Result<(), std::io::Error> {
sha: None,
pipeline_tag: None,
},
false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or({
false => get_model_info(&tokenizer_name, &revision, authorization_token).await.unwrap_or_else(|| {
tracing::warn!("Could not retrieve model info from the Hugging Face hub.");
HubModelInfo { model_id: tokenizer_name.to_string(), sha: None, pipeline_tag: None }
}),
......
......@@ -2,7 +2,7 @@ include Makefile-transformers
include Makefile-flash-att
unit-tests:
python -m pytest tests
pytest -s -vv -m "not private" tests
gen-server:
# Compile protos
......
import pytest
import torch
from transformers import AutoTokenizer
from text_generation_server.models import Model
def get_test_model():
class TestModel(Model):
def batch_type(self):
raise NotImplementedError
def generate_token(self, batch):
raise NotImplementedError
tokenizer = AutoTokenizer.from_pretrained("huggingface/llama-7b")
model = TestModel(
torch.nn.Linear(1, 1), tokenizer, False, torch.float32, torch.device("cpu")
)
return model
@pytest.mark.private
def test_decode_streaming_english_spaces():
model = get_test_model()
truth = "Hello here, this is a simple test"
all_input_ids = [15043, 1244, 29892, 445, 338, 263, 2560, 1243]
assert (
all_input_ids == model.tokenizer(truth, add_special_tokens=False)["input_ids"]
)
decoded_text = ""
offset = 0
token_offset = 0
for i in range(len(all_input_ids)):
text, offset, token_offset = model.decode_token(
all_input_ids[: i + 1], offset, token_offset
)
decoded_text += text
assert decoded_text == truth
@pytest.mark.private
def test_decode_streaming_chinese_utf8():
model = get_test_model()
truth = "我很感谢你的热情"
all_input_ids = [
30672,
232,
193,
139,
233,
135,
162,
235,
179,
165,
30919,
30210,
234,
134,
176,
30993,
]
decoded_text = ""
offset = 0
token_offset = 0
for i in range(len(all_input_ids)):
text, offset, token_offset = model.decode_token(
all_input_ids[: i + 1], offset, token_offset
)
decoded_text += text
assert decoded_text == truth
......@@ -149,7 +149,7 @@ def test_seq2seq_lm_generate_token(default_seq2seq_lm, default_seq2seq_lm_batch)
assert all([generation.generated_text is None for generation in generations])
assert all([len(generation.prefill_tokens) == 1 for generation in generations])
assert all([generation.token_id.item() == 259 for generation in generations])
assert all([generation.token_text == "" for generation in generations])
assert all([generation.token_text == " " for generation in generations])
assert generations[0].request_id == 0
......
......@@ -56,7 +56,7 @@ class BLOOM(CausalLM):
quantize: Optional[str] = None,
):
super(BLOOM, self).__init__(
model_id=model_id, revision=revision, quantize=quantize, decode_buffer=1
model_id=model_id, revision=revision, quantize=quantize
)
@property
......@@ -104,14 +104,13 @@ class BLOOMSharded(BLOOM):
rank=rank,
world_size=world_size,
)
self.model = model.eval()
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
decode_buffer=1,
rank=rank,
world_size=world_size,
)
......
......@@ -35,8 +35,8 @@ class CausalLMBatch(Batch):
# Lengths of all generations present in the batch
input_lengths: List[int]
offsets: List[Optional[int]]
token_offsets: List[Optional[int]]
prefix_offsets: List[int]
read_offsets: List[int]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
......@@ -70,8 +70,8 @@ class CausalLMBatch(Batch):
inputs = []
next_token_choosers = []
stopping_criterias = []
offsets = []
token_offsets = []
prefix_offsets = []
read_offsets = []
requests_idx_mapping = {}
# Parse batch
......@@ -81,8 +81,6 @@ class CausalLMBatch(Batch):
for i, r in enumerate(pb.requests):
requests_idx_mapping[r.id] = i
inputs.append(r.inputs)
offsets.append(None)
token_offsets.append(None)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
......@@ -102,6 +100,10 @@ class CausalLMBatch(Batch):
truncation=True,
max_length=max_truncation,
).to(device)
for _ in pb.requests:
input_len = tokenized_inputs["input_ids"].shape[1]
prefix_offsets.append(0)
read_offsets.append(input_len)
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
......@@ -130,8 +132,8 @@ class CausalLMBatch(Batch):
past_key_values=None,
all_input_ids=list(all_input_ids),
input_lengths=input_lengths.tolist(),
offsets=offsets,
token_offsets=token_offsets,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_input_length=max_input_length.item(),
......@@ -151,8 +153,8 @@ class CausalLMBatch(Batch):
# New values after filtering
requests_idx_mapping = {}
input_lengths = []
offsets = []
token_offsets = []
prefix_offsets = []
read_offsets = []
all_input_ids = []
max_input_length = 0
......@@ -167,8 +169,8 @@ class CausalLMBatch(Batch):
requests_idx_mapping[r.id] = i
keep_indices.append(idx)
offsets.append(self.offsets[idx])
token_offsets.append(self.token_offsets[idx])
prefix_offsets.append(self.prefix_offsets[idx])
read_offsets.append(self.read_offsets[idx])
all_input_ids.append(self.all_input_ids[idx])
request_input_length = self.input_lengths[idx]
......@@ -225,8 +227,8 @@ class CausalLMBatch(Batch):
self.position_ids = position_ids
self.all_input_ids = all_input_ids
self.input_lengths = input_lengths
self.offsets = offsets
self.token_offsets = token_offsets
self.prefix_offsets = prefix_offsets
self.read_offsets = read_offsets
self.next_token_choosers = next_token_choosers
self.stopping_criterias = stopping_criterias
self.max_input_length = max_input_length
......@@ -251,8 +253,8 @@ class CausalLMBatch(Batch):
requests = []
requests_idx_mapping = {}
input_lengths = []
offsets = []
token_offsets = []
prefix_offsets = []
read_offsets = []
all_input_ids = []
next_token_choosers = []
stopping_criterias = []
......@@ -270,8 +272,8 @@ class CausalLMBatch(Batch):
for i, batch in enumerate(batches):
requests.extend(batch.requests)
input_lengths.extend(batch.input_lengths)
offsets.extend(batch.offsets)
token_offsets.extend(batch.token_offsets)
prefix_offsets.extend(batch.prefix_offsets)
read_offsets.extend(batch.read_offsets)
all_input_ids.extend(batch.all_input_ids)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
......@@ -428,8 +430,8 @@ class CausalLMBatch(Batch):
past_key_values=past_key_values,
all_input_ids=all_input_ids,
input_lengths=input_lengths,
offsets=offsets,
token_offsets=token_offsets,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_input_length=max_input_length,
......@@ -448,7 +450,6 @@ class CausalLM(Model):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
decode_buffer: int = 3,
):
if torch.cuda.is_available():
device = torch.device("cuda")
......@@ -463,25 +464,25 @@ class CausalLM(Model):
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
)
self.model = AutoModelForCausalLM.from_pretrained(
model = AutoModelForCausalLM.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None,
load_in_8bit=quantize == "bitsandbytes",
).eval()
)
tokenizer.pad_token_id = (
self.model.config.pad_token_id
if self.model.config.pad_token_id is not None
else self.model.config.eos_token_id
model.config.pad_token_id
if model.config.pad_token_id is not None
else model.config.eos_token_id
)
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
device=device,
decode_buffer=decode_buffer,
)
@property
......@@ -528,8 +529,8 @@ class CausalLM(Model):
iterator = zip(
batch.requests,
batch.input_lengths,
batch.offsets,
batch.token_offsets,
batch.prefix_offsets,
batch.read_offsets,
logits,
batch.next_token_choosers,
batch.stopping_criterias,
......@@ -540,8 +541,8 @@ class CausalLM(Model):
for i, (
request,
input_length,
offset,
token_offset,
prefix_offset,
read_offset,
logits,
next_token_chooser,
stopping_criteria,
......@@ -559,8 +560,8 @@ class CausalLM(Model):
# Generated token
next_token_logprob = logprobs[-1, next_token_id]
next_token_id_squeezed = next_token_id.squeeze()
next_token_text, offset, token_offset = self.decode_token(
all_input_ids[:, 0], offset, token_offset
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids[:, 0], prefix_offset, read_offset
)
# Evaluate stopping criteria
......@@ -628,8 +629,8 @@ class CausalLM(Model):
batch.input_ids[i, 0] = next_token_id
batch.all_input_ids[i] = all_input_ids
batch.input_lengths[i] = new_input_length
batch.offsets[i] = offset
batch.token_offsets[i] = token_offset
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.max_input_length = max(batch.max_input_length, new_input_length)
# We finished all generations in the batch; there is no next batch
......
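
The rename from `offsets`/`token_offsets` to `prefix_offsets`/`read_offsets` also changes initialization: each request now starts with `prefix_offset = 0` and `read_offset = input_length`, so the prompt itself is never re-decoded during streaming. A compact, hypothetical bookkeeping sketch; the class and helper names are illustrative, not the repository's API:

```python
# Illustrative bookkeeping sketch; everything except the prefix_offset and
# read_offset naming is hypothetical and not part of the repository's API.
from dataclasses import dataclass
from typing import Callable, List, Tuple

DecodeFn = Callable[[List[int], int, int], Tuple[str, int, int]]


@dataclass
class RequestState:
    all_input_ids: List[int]
    prefix_offset: int
    read_offset: int

    @classmethod
    def from_prompt(cls, prompt_ids: List[int]) -> "RequestState":
        # Mirrors the batch constructors above: start the read offset at the
        # prompt length so the prompt is never re-decoded.
        return cls(list(prompt_ids), prefix_offset=0, read_offset=len(prompt_ids))

    def push(self, next_token_id: int, decode_token: DecodeFn) -> str:
        # decode_token is assumed to behave like the sketch near the top of
        # this page, with the tokenizer already bound.
        self.all_input_ids.append(next_token_id)
        text, self.prefix_offset, self.read_offset = decode_token(
            self.all_input_ids, self.prefix_offset, self.read_offset
        )
        return text
```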
......@@ -52,8 +52,8 @@ class FlashCausalLMBatch(Batch):
# Lengths of all generations present in the batch
input_lengths: List[int]
offsets: List[Optional[int]]
token_offsets: List[Optional[int]]
prefix_offsets: List[Optional[int]]
read_offsets: List[Optional[int]]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
......@@ -82,8 +82,8 @@ class FlashCausalLMBatch(Batch):
max_seqlen = 0
input_lengths = []
offsets = []
token_offsets = []
prefix_offsets = []
read_offsets = []
all_input_ids = []
requests_idx_mapping = {}
......@@ -108,8 +108,8 @@ class FlashCausalLMBatch(Batch):
max_seqlen = max(max_seqlen, input_length)
input_lengths.append(input_length)
offsets.append(None)
token_offsets.append(None)
prefix_offsets.append(0)
read_offsets.append(input_length)
all_input_ids.append(tokenized_input)
......@@ -151,8 +151,8 @@ class FlashCausalLMBatch(Batch):
max_seqlen=max_seqlen,
past_key_values=None,
input_lengths=input_lengths,
offsets=offsets,
token_offsets=token_offsets,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=[],
next_token_choosers=next_token_choosers,
......@@ -190,8 +190,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor = []
input_lengths = []
offsets = []
token_offsets = []
prefix_offsets = []
read_offsets = []
next_token_choosers = []
stopping_criterias = []
......@@ -222,8 +222,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
input_lengths.append(request_input_length)
offsets.append(self.offsets[idx])
token_offsets.append(self.token_offsets[idx])
prefix_offsets.append(self.prefix_offsets[idx])
read_offsets.append(self.read_offsets[idx])
next_token_choosers.append(self.next_token_choosers[idx])
......@@ -269,8 +269,8 @@ class FlashCausalLMBatch(Batch):
max_seqlen=max_seqlen,
past_key_values=past_key_values,
input_lengths=input_lengths,
offsets=offsets,
token_offsets=token_offsets,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_choosers=next_token_choosers,
......@@ -302,8 +302,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor = []
input_lengths = []
offsets = []
token_offsets = []
prefix_offsets = []
read_offsets = []
next_token_choosers = []
stopping_criterias = []
......@@ -347,8 +347,8 @@ class FlashCausalLMBatch(Batch):
all_input_ids_tensor.extend(batch.all_input_ids_tensor)
input_lengths.extend(batch.input_lengths)
offsets.extend(batch.offsets)
token_offsets.extend(batch.token_offsets)
prefix_offsets.extend(batch.prefix_offsets)
read_offsets.extend(batch.read_offsets)
next_token_choosers.extend(batch.next_token_choosers)
stopping_criterias.extend(batch.stopping_criterias)
......@@ -374,8 +374,8 @@ class FlashCausalLMBatch(Batch):
max_seqlen=max_seqlen,
past_key_values=past_key_values,
input_lengths=input_lengths,
offsets=offsets,
token_offsets=token_offsets,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_choosers=next_token_choosers,
......@@ -394,7 +394,6 @@ class FlashCausalLM(Model):
model_id: str,
revision: Optional[str] = None,
quantize: Optional[str] = None,
decode_buffer: int = 3,
):
if torch.cuda.is_available():
device = torch.device("cuda")
......@@ -405,23 +404,19 @@ class FlashCausalLM(Model):
tokenizer = AutoTokenizer.from_pretrained(
model_id, revision=revision, padding_side="left", truncation_side="left"
)
self.model = (
model_cls.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
load_in_8bit=quantize == "bitsandbytes",
)
.eval()
.to(device)
)
model = model_cls.from_pretrained(
model_id,
revision=revision,
torch_dtype=dtype,
load_in_8bit=quantize == "bitsandbytes",
).to(device)
super(FlashCausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
decode_buffer=decode_buffer,
)
@property
......@@ -645,8 +640,8 @@ class FlashCausalLM(Model):
iterator = zip(
batch.requests,
batch.input_lengths,
batch.offsets,
batch.token_offsets,
batch.prefix_offsets,
batch.read_offsets,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
......@@ -659,8 +654,8 @@ class FlashCausalLM(Model):
for i, (
request,
input_length,
offset,
token_offset,
prefix_offset,
read_offset,
next_token_chooser,
stopping_criteria,
all_input_ids,
......@@ -675,10 +670,10 @@ class FlashCausalLM(Model):
all_input_ids.append(next_token_id)
# Generated token
next_token_text, offset, token_offset = self.decode_token(
next_token_text, prefix_offset, read_offset = self.decode_token(
all_input_ids,
offset,
token_offset,
prefix_offset,
read_offset,
)
# Evaluate stopping criteria
......@@ -744,8 +739,8 @@ class FlashCausalLM(Model):
# Update values
batch.input_lengths[i] = new_input_length
batch.offsets[i] = offset
batch.token_offsets[i] = token_offset
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.all_input_ids[i] = all_input_ids
batch.max_seqlen = batch.max_seqlen + 1
cumulative_length += input_length
......
......@@ -64,9 +64,9 @@ class FlashLlama(FlashCausalLM):
model = FlashLlamaForCausalLM(config)
self.load_weights(model, filenames, quantize, device, dtype)
self.model = model.eval().to(device)
super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
......@@ -189,9 +189,9 @@ class FlashLlamaSharded(FlashLlama):
rank=rank,
world_size=world_size,
)
self.model = model.eval().to(device)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
......
......@@ -73,9 +73,9 @@ class FlashNeoXSharded(FlashNeoX):
rank=rank,
world_size=world_size,
)
self.model = model.eval().to(device)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
......
......@@ -67,14 +67,13 @@ class FlashSantacoder(FlashCausalLM):
dtype,
config.architectures[0].startswith("GPT2"),
)
self.model = model.eval().to(device)
super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
decode_buffer=1,
)
@staticmethod
......@@ -213,16 +212,15 @@ class FlashSantacoderSharded(FlashSantacoder):
world_size=world_size,
transpose=config.architectures[0].startswith("GPT2"),
)
self.model = model.eval().to(device)
torch.distributed.barrier(group=self.process_group)
super(FlashCausalLM, self).__init__(
model=model.to(device),
tokenizer=tokenizer,
requires_padding=False,
dtype=dtype,
device=device,
rank=rank,
world_size=world_size,
decode_buffer=1,
)
@staticmethod
......
......@@ -94,8 +94,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
inputs = []
next_token_choosers = []
stopping_criterias = []
offsets = []
token_offsets = []
prefix_offsets = []
read_offsets = []
requests_idx_mapping = {}
# Parse batch
......@@ -106,8 +106,6 @@ class GalacticaCausalLMBatch(CausalLMBatch):
requests_idx_mapping[r.id] = i
# Add escape_custom_split_sequence to the CausalLMBatch logic
inputs.append(escape_custom_split_sequence(r.inputs))
offsets.append(None)
token_offsets.append(None)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
......@@ -127,6 +125,10 @@ class GalacticaCausalLMBatch(CausalLMBatch):
truncation=True,
max_length=max_truncation,
).to(device)
for _ in pb.requests:
input_len = tokenized_inputs["input_ids"].shape[1]
prefix_offsets.append(0)
read_offsets.append(input_len)
input_lengths = tokenized_inputs["attention_mask"].sum(1)
max_input_length = input_lengths.max()
......@@ -155,8 +157,8 @@ class GalacticaCausalLMBatch(CausalLMBatch):
past_key_values=None,
all_input_ids=list(all_input_ids),
input_lengths=input_lengths.tolist(),
offsets=offsets,
token_offsets=token_offsets,
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
next_token_choosers=next_token_choosers,
stopping_criterias=stopping_criterias,
max_input_length=max_input_length.item(),
......@@ -231,9 +233,9 @@ class GalacticaSharded(Galactica):
rank=rank,
world_size=world_size,
)
self.model = model.eval()
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
......
......@@ -70,9 +70,9 @@ class GPTNeoxSharded(CausalLM):
rank=rank,
world_size=world_size,
)
self.model = model.eval()
torch.distributed.barrier(group=self.process_group)
super(CausalLM, self).__init__(
model=model,
tokenizer=tokenizer,
requires_padding=True,
dtype=dtype,
......