Unverified Commit 62f91f78 authored by OlivierDehaene, committed by GitHub

feat(server): support vectorized warpers in flash causal lm (#317)


Co-authored-by: Joel Lamy-Poirier <joel.lamy-poirier@servicenow.com>
parent 951930fb
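
In brief, this change replaces the per-request NextTokenChooser objects in FlashCausalLMBatch with a single HeterogeneousNextTokenChooser that warps and samples the logits of the whole batch in one call, backed by a padded all_input_ids_tensor; from_pb gains a dtype argument throughout, presumably so the chooser can build its warper tensors in the model dtype. The toy sketch below only illustrates the vectorization idea with a plain greedy pick; it does not use the TGI classes, and the shapes are assumptions.

import torch

batch_size, vocab_size = 3, 8
logits = torch.randn(batch_size, vocab_size)

# Old shape of the loop: one selection per request, one request at a time.
per_request = [logits[i : i + 1].argmax(dim=-1) for i in range(batch_size)]

# New shape: a single batched selection over all requests at once.
vectorized = logits.argmax(dim=-1)

assert torch.equal(torch.cat(per_request), vectorized)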
......@@ -34,65 +34,65 @@
"tokens": [
{
"id": 408,
"logprob": -1.9267578,
"logprob": -0.07891846,
"special": false,
"text": " que"
},
{
"id": 20288,
"logprob": -2.9257812,
"id": 366,
"logprob": -1.2939453,
"special": false,
"text": " l'on"
"text": " la"
},
{
"id": 22255,
"logprob": -2.8964844,
"id": 8769,
"logprob": -0.3708496,
"special": false,
"text": " trouve"
"text": " personne"
},
{
"id": 1622,
"logprob": -1.1083984,
"id": 1479,
"logprob": -2.2871094,
"special": false,
"text": " une"
"text": " qui"
},
{
"id": 187079,
"logprob": -7.796875,
"id": 2997,
"logprob": -0.8671875,
"special": false,
"text": " posture"
"text": " vous"
},
{
"id": 501,
"logprob": -5.390625,
"id": 35977,
"logprob": -1.5097656,
"special": false,
"text": " par"
"text": " suit"
},
{
"id": 8741,
"logprob": -0.34936523,
"id": 21558,
"logprob": -0.07891846,
"special": false,
"text": " rapport"
"text": " ait"
},
{
"id": 693,
"logprob": 0.0,
"id": 447,
"logprob": -0.12695312,
"special": false,
"text": " à"
"text": " un"
},
{
"id": 366,
"logprob": -2.3378906,
"id": 78606,
"logprob": -2.21875,
"special": false,
"text": " la"
"text": " profil"
},
{
"id": 36503,
"logprob": -3.6640625,
"id": 3899,
"logprob": -1.3535156,
"special": false,
"text": " pratique"
"text": " bien"
}
]
},
"generated_text": "Pour déguster un ortolan, il faut tout d'abord que l'on trouve une posture par rapport à la pratique"
"generated_text": "Pour déguster un ortolan, il faut tout d'abord que la personne qui vous suit ait un profil bien"
}
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"finish_reason": "stop_sequence",
"generated_tokens": 5,
"prefill": [
{
"id": 1,
......@@ -24,65 +24,35 @@
"tokens": [
{
"id": 5229,
"logprob": -3.3085938,
"logprob": -2.5683594,
"special": false,
"text": " failed"
},
{
"id": 363,
"logprob": -3.984375,
"special": false,
"text": " for"
},
{
"id": 5641,
"logprob": -6.53125,
"special": false,
"text": " IP"
},
{
"id": 16428,
"logprob": -3.1835938,
"special": false,
"text": " Address"
},
{
"id": 29901,
"logprob": -1.2324219,
"logprob": -0.45336914,
"special": false,
"text": ":"
},
{
"id": 525,
"logprob": -2.6855469,
"special": false,
"text": " '"
},
{
"id": 8516,
"logprob": -7.1601562,
"special": false,
"text": "None"
},
{
"id": 4286,
"logprob": -2.4433594,
"id": 4829,
"logprob": -1.8408203,
"special": false,
"text": "'."
"text": " Error"
},
{
"id": 13,
"logprob": -0.06530762,
"id": 297,
"logprob": -1.0556641,
"special": false,
"text": "\n"
"text": " in"
},
{
"id": 294,
"logprob": -7.953125,
"id": 1243,
"logprob": 0.0,
"special": false,
"text": "as"
"text": " test"
}
]
},
"generated_text": "Test requestfailed for IP Address: 'None'.\nas"
"generated_text": "Test requestfailed: Error in test"
}
{
"details": {
"best_of_sequences": null,
"finish_reason": "eos_token",
"generated_tokens": 12,
"finish_reason": "length",
"generated_tokens": 60,
"prefill": [
{
"id": 589,
......@@ -29,77 +29,365 @@
"tokens": [
{
"id": 2262,
"logprob": -0.7451172,
"logprob": -0.042999268,
"special": false,
"text": "():"
},
{
"id": 284,
"logprob": -0.21325684,
"logprob": 0.0,
"special": false,
"text": "\n "
},
{
"id": 5741,
"logprob": -5.734375,
"id": 1459,
"logprob": 0.0,
"special": false,
"text": " print"
},
{
"id": 440,
"logprob": 0.0,
"special": false,
"text": "(\""
},
{
"id": 8279,
"logprob": 0.0,
"special": false,
"text": "Hello"
},
{
"id": 10896,
"logprob": -0.3659668,
"special": false,
"text": " World"
},
{
"id": 657,
"logprob": -0.49804688,
"special": false,
"text": "\")"
},
{
"id": 203,
"logprob": -0.11279297,
"special": false,
"text": " logging"
"text": "\n"
},
{
"id": 32,
"id": 203,
"logprob": 0.0,
"special": false,
"text": "."
"text": "\n"
},
{
"id": 1338,
"logprob": -0.3232422,
"id": 589,
"logprob": -0.20141602,
"special": false,
"text": "info"
"text": "def"
},
{
"id": 463,
"logprob": -1.0380859,
"id": 1459,
"logprob": 0.0,
"special": false,
"text": " print"
},
{
"id": 81,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 7656,
"logprob": 0.0,
"special": false,
"text": "hello"
},
{
"id": 81,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 426,
"logprob": -0.051635742,
"special": false,
"text": "name"
},
{
"id": 26,
"logprob": 0.0,
"special": false,
"text": "("
},
{
"id": 426,
"logprob": 0.0,
"special": false,
"text": "('"
"text": "name"
},
{
"id": 711,
"logprob": 0.0,
"special": false,
"text": "):"
},
{
"id": 284,
"logprob": 0.0,
"special": false,
"text": "\n "
},
{
"id": 1459,
"logprob": 0.0,
"special": false,
"text": " print"
},
{
"id": 440,
"logprob": -0.16027832,
"special": false,
"text": "(\""
},
{
"id": 8279,
"logprob": -0.8378906,
"logprob": 0.0,
"special": false,
"text": "Hello"
},
{
"id": 313,
"logprob": 0.0,
"special": false,
"text": " \""
},
{
"id": 474,
"logprob": 0.0,
"special": false,
"text": " +"
},
{
"id": 636,
"logprob": 0.0,
"special": false,
"text": " name"
},
{
"id": 27,
"logprob": 0.0,
"special": false,
"text": ")"
},
{
"id": 203,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 203,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 589,
"logprob": 0.0,
"special": false,
"text": "def"
},
{
"id": 1459,
"logprob": 0.0,
"special": false,
"text": " print"
},
{
"id": 81,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 7656,
"logprob": 0.0,
"special": false,
"text": "hello"
},
{
"id": 81,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 426,
"logprob": 0.0,
"special": false,
"text": "name"
},
{
"id": 81,
"logprob": 0.0,
"special": false,
"text": "_"
},
{
"id": 381,
"logprob": 0.0,
"special": false,
"text": "age"
},
{
"id": 26,
"logprob": 0.0,
"special": false,
"text": "("
},
{
"id": 426,
"logprob": 0.0,
"special": false,
"text": "name"
},
{
"id": 30,
"logprob": -1.9501953,
"logprob": 0.0,
"special": false,
"text": ","
},
{
"id": 10896,
"logprob": -1.3476562,
"id": 11442,
"logprob": 0.0,
"special": false,
"text": " World"
"text": " age"
},
{
"id": 711,
"logprob": 0.0,
"special": false,
"text": "):"
},
{
"id": 284,
"logprob": 0.0,
"special": false,
"text": "\n "
},
{
"id": 1459,
"logprob": 0.0,
"special": false,
"text": " print"
},
{
"id": 440,
"logprob": 0.0,
"special": false,
"text": "(\""
},
{
"id": 683,
"logprob": -1.796875,
"id": 8279,
"logprob": 0.0,
"special": false,
"text": "Hello"
},
{
"id": 313,
"logprob": 0.0,
"special": false,
"text": " \""
},
{
"id": 474,
"logprob": 0.0,
"special": false,
"text": " +"
},
{
"id": 636,
"logprob": 0.0,
"special": false,
"text": " name"
},
{
"id": 474,
"logprob": 0.0,
"special": false,
"text": " +"
},
{
"id": 313,
"logprob": -0.6328125,
"special": false,
"text": " \""
},
{
"id": 313,
"logprob": -1.7011719,
"special": false,
"text": " \""
},
{
"id": 474,
"logprob": 0.0,
"special": false,
"text": " +"
},
{
"id": 596,
"logprob": 0.0,
"special": false,
"text": " str"
},
{
"id": 26,
"logprob": 0.0,
"special": false,
"text": "("
},
{
"id": 381,
"logprob": 0.0,
"special": false,
"text": "age"
},
{
"id": 490,
"logprob": 0.0,
"special": false,
"text": "')"
"text": "))"
},
{
"id": 203,
"logprob": -0.9873047,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 0,
"logprob": -0.7495117,
"special": true,
"text": "<|endoftext|>"
"id": 203,
"logprob": 0.0,
"special": false,
"text": "\n"
},
{
"id": 589,
"logprob": 0.0,
"special": false,
"text": "def"
},
{
"id": 1459,
"logprob": 0.0,
"special": false,
"text": " print"
}
]
},
"generated_text": "():\n logging.info('Hello, World')\n<|endoftext|>"
"generated_text": "():\n print(\"Hello World\")\n\ndef print_hello_name(name):\n print(\"Hello \" + name)\n\ndef print_hello_name_age(name, age):\n print(\"Hello \" + name + \" \" + str(age))\n\ndef print"
}
{
"details": {
"best_of_sequences": null,
"finish_reason": "length",
"generated_tokens": 10,
"finish_reason": "eos_token",
"generated_tokens": 9,
"prefill": [
{
"id": 0,
......@@ -14,65 +14,59 @@
"tokens": [
{
"id": 16017,
"logprob": -1.3505859,
"logprob": -0.30908203,
"special": false,
"text": " blue"
},
{
"id": 20495,
"logprob": -0.50439453,
"logprob": 0.0,
"special": false,
"text": " sky"
},
{
"id": 259,
"logprob": -1.2011719,
"logprob": -0.28271484,
"special": false,
"text": " "
},
{
"id": 15484,
"logprob": -2.8378906,
"logprob": -1.7929688,
"special": false,
"text": "appear"
},
{
"id": 345,
"logprob": -0.87597656,
"logprob": -0.8935547,
"special": false,
"text": "ed"
},
{
"id": 288,
"logprob": -1.8447266,
"id": 281,
"logprob": 0.0,
"special": false,
"text": " to"
"text": " in"
},
{
"id": 35622,
"logprob": -7.1445312,
"id": 287,
"logprob": 0.0,
"special": false,
"text": " cloud"
"text": " the"
},
{
"id": 263,
"logprob": -1.2929688,
"special": false,
"text": "s"
},
{
"id": 14701,
"logprob": -3.0761719,
"id": 20495,
"logprob": -0.32299805,
"special": false,
"text": " above"
"text": " sky"
},
{
"id": 751,
"logprob": -4.4375,
"special": false,
"text": " all"
"id": 1,
"logprob": 0.0,
"special": true,
"text": "</s>"
}
]
},
"generated_text": "Why is the sky blue?blue sky appeared to clouds above all"
"generated_text": "Why is the sky blue?blue sky appeared in the sky"
}
......@@ -40,7 +40,7 @@ async def test_flash_llama_all_params(flash_llama, response_snapshot):
seed=0,
)
assert response.details.generated_tokens == 10
assert response.details.generated_tokens == 5
assert response == response_snapshot
......
......@@ -29,7 +29,7 @@ async def test_flash_starcoder_default_params(flash_starcoder, response_snapshot
"def print_hello", max_new_tokens=60, temperature=0.2, top_p=0.95, seed=0
)
assert response.details.generated_tokens == 12
assert response.details.generated_tokens == 60
assert response == response_snapshot
......
......@@ -43,7 +43,7 @@ async def test_mt0_base_all_params(mt0_base, response_snapshot):
seed=0,
)
assert response.details.generated_tokens == 10
assert response.details.generated_tokens == 9
assert response == response_snapshot
......
......@@ -38,7 +38,7 @@ def default_pb_batch(default_pb_request):
@pytest.fixture
def default_bloom_batch(default_pb_batch, bloom_560m_tokenizer):
return BloomCausalLMBatch.from_pb(
default_pb_batch, bloom_560m_tokenizer, torch.device("cpu")
default_pb_batch, bloom_560m_tokenizer, torch.float32, torch.device("cpu")
)
......@@ -52,7 +52,7 @@ def default_multi_requests_bloom_batch(default_pb_request, bloom_560m_tokenizer)
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
return BloomCausalLMBatch.from_pb(
batch_pb, bloom_560m_tokenizer, torch.device("cpu")
batch_pb, bloom_560m_tokenizer, torch.float32, torch.device("cpu")
)
......@@ -286,7 +286,9 @@ def test_batch_concatenate(
== default_multi_requests_bloom_batch.stopping_criterias[1].max_new_tokens
)
next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
next_batch = next_batch.filter(
[next_batch.requests[0].id, next_batch.requests[1].id]
)
for _ in range(
default_bloom_batch.stopping_criterias[0].max_new_tokens
......
......@@ -38,7 +38,9 @@ def default_pb_batch(default_pb_request):
@pytest.fixture
def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer):
return CausalLMBatch.from_pb(default_pb_batch, gpt2_tokenizer, torch.device("cpu"))
return CausalLMBatch.from_pb(
default_pb_batch, gpt2_tokenizer, torch.float32, torch.device("cpu")
)
@pytest.fixture
......@@ -50,7 +52,9 @@ def default_multi_requests_causal_lm_batch(default_pb_request, gpt2_tokenizer):
req_1.stopping_parameters.max_new_tokens = 5
batch_pb = generate_pb2.Batch(id=1, requests=[req_0, req_1], size=2)
return CausalLMBatch.from_pb(batch_pb, gpt2_tokenizer, torch.device("cpu"))
return CausalLMBatch.from_pb(
batch_pb, gpt2_tokenizer, torch.float32, torch.device("cpu")
)
def test_batch_from_pb(default_pb_batch, default_causal_lm_batch):
......@@ -285,7 +289,9 @@ def test_batch_concatenate(
== default_multi_requests_causal_lm_batch.stopping_criterias[1].max_new_tokens
)
next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
next_batch = next_batch.filter(
[next_batch.requests[0].id, next_batch.requests[1].id]
)
for _ in range(
default_causal_lm_batch.stopping_criterias[0].max_new_tokens
......
......@@ -45,7 +45,10 @@ def default_fim_pb_batch(default_fim_pb_request):
@pytest.mark.skip
def test_santacoder_generate_token_completion(default_santacoder, default_pb_batch):
batch = CausalLMBatch.from_pb(
default_pb_batch, default_santacoder.tokenizer, default_santacoder.device
default_pb_batch,
default_santacoder.tokenizer,
default_santacoder.dtype,
default_santacoder.device,
)
next_batch = batch
......@@ -70,7 +73,10 @@ def test_fim_santacoder_generate_token_completion(
default_santacoder, default_fim_pb_batch
):
batch = CausalLMBatch.from_pb(
default_fim_pb_batch, default_santacoder.tokenizer, default_santacoder.device
default_fim_pb_batch,
default_santacoder.tokenizer,
default_santacoder.dtype,
default_santacoder.device,
)
next_batch = batch
......
......@@ -42,7 +42,7 @@ def default_pb_batch(default_pb_request):
@pytest.fixture
def default_seq2seq_lm_batch(default_pb_batch, mt0_small_tokenizer):
return Seq2SeqLMBatch.from_pb(
default_pb_batch, mt0_small_tokenizer, torch.device("cpu")
default_pb_batch, mt0_small_tokenizer, torch.float32, torch.device("cpu")
)
......@@ -55,7 +55,9 @@ def default_multi_requests_seq2seq_lm_batch(default_pb_request, mt0_small_tokeni
req_1.stopping_parameters.max_new_tokens = 5
batch_pb = generate_pb2.Batch(id=0, requests=[req_0, req_1], size=2)
return Seq2SeqLMBatch.from_pb(batch_pb, mt0_small_tokenizer, torch.device("cpu"))
return Seq2SeqLMBatch.from_pb(
batch_pb, mt0_small_tokenizer, torch.float32, torch.device("cpu")
)
def test_batch_from_pb(default_pb_batch, default_seq2seq_lm_batch):
......@@ -323,7 +325,9 @@ def test_batch_concatenate(
)
assert generations[2].generated_text.generated_tokens == 5
next_batch = next_batch.filter([next_batch.requests[0].id, next_batch.requests[1].id])
next_batch = next_batch.filter(
[next_batch.requests[0].id, next_batch.requests[1].id]
)
generations, next_batch = default_seq2seq_lm.generate_token(next_batch)
assert next_batch is not None
......
......@@ -39,10 +39,11 @@ class BloomCausalLMBatch(CausalLMBatch):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
batch = super(BloomCausalLMBatch, cls).from_pb(
pb=pb, tokenizer=tokenizer, device=device
pb=pb, tokenizer=tokenizer, dtype=dtype, device=device
)
batch.keys_head_dim_last = False
return batch
......
......@@ -66,6 +66,7 @@ class CausalLMBatch(Batch):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
inputs = []
......
......@@ -18,11 +18,7 @@ from text_generation_server.models.types import (
GeneratedText,
)
from text_generation_server.pb import generate_pb2
from text_generation_server.utils import (
NextTokenChooser,
StoppingCriteria,
Sampling,
)
from text_generation_server.utils import StoppingCriteria, HeterogeneousNextTokenChooser
tracer = trace.get_tracer(__name__)
......@@ -48,7 +44,7 @@ class FlashCausalLMBatch(Batch):
# All tokens
all_input_ids: List[List[int]]
all_input_ids_tensor: List[torch.Tensor]
all_input_ids_tensor: torch.Tensor
# Lengths of all generations present in the batch
input_lengths: List[int]
......@@ -56,7 +52,7 @@ class FlashCausalLMBatch(Batch):
read_offsets: List[Optional[int]]
# Generation helpers
next_token_choosers: List[NextTokenChooser]
next_token_chooser: HeterogeneousNextTokenChooser
stopping_criterias: List[StoppingCriteria]
# Maximum number of tokens this batch will grow to
......@@ -75,6 +71,7 @@ class FlashCausalLMBatch(Batch):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "FlashCausalLMBatch":
position_ids = []
......@@ -87,13 +84,14 @@ class FlashCausalLMBatch(Batch):
all_input_ids = []
requests_idx_mapping = {}
next_token_choosers = []
next_token_chooser_parameters = []
stopping_criterias = []
# Cumulative length
cumulative_length = 0
max_tokens = 0
max_length = 0
# Parse batch
for i, r in enumerate(pb.requests):
......@@ -119,7 +117,7 @@ class FlashCausalLMBatch(Batch):
# Add cumulative lengths of all previous inputs
cu_seqlens.append(cumulative_length + input_length)
next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
next_token_chooser_parameters.append(r.parameters)
stopping_criteria = StoppingCriteria.from_pb(
r.stopping_parameters, tokenizer
......@@ -130,11 +128,26 @@ class FlashCausalLMBatch(Batch):
# Update
cumulative_length += input_length
max_tokens += input_length + max_new_tokens
max_length = max(max_length, input_length + max_new_tokens)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype, device
)
# Padded all_input_ids_tensor
all_input_ids_tensor = np.zeros(
(len(all_input_ids), max_length), dtype=np.int64
)
for i, input_ids in enumerate(all_input_ids):
all_input_ids_tensor[i, : len(input_ids)] = input_ids
# Create tensors on device
input_ids = torch.tensor(
np.concatenate(all_input_ids), dtype=torch.int64, device=device
)
all_input_ids_tensor = torch.tensor(
all_input_ids_tensor, dtype=torch.int64, device=device
)
position_ids = torch.tensor(
np.concatenate(position_ids), dtype=torch.int32, device=device
)
......@@ -154,8 +167,8 @@ class FlashCausalLMBatch(Batch):
prefix_offsets=prefix_offsets,
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=[],
next_token_choosers=next_token_choosers,
all_input_ids_tensor=all_input_ids_tensor,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
max_tokens=max_tokens,
)
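
The per-request all_input_ids_tensor list becomes one right-padded matrix sized to the longest prompt plus its generation budget. A standalone sketch of that padding step, mirroring the from_pb hunk above (the prompt ids and max_new_tokens values are made up):

import numpy as np
import torch

all_input_ids = [[5, 6, 7], [8, 9]]   # hypothetical tokenized prompts
max_new_tokens = [2, 4]               # hypothetical generation budgets

# Width of the padded matrix: longest input_length + max_new_tokens.
max_length = max(len(ids) + n for ids, n in zip(all_input_ids, max_new_tokens))

padded = np.zeros((len(all_input_ids), max_length), dtype=np.int64)
for i, input_ids in enumerate(all_input_ids):
    padded[i, : len(input_ids)] = input_ids

all_input_ids_tensor = torch.tensor(padded, dtype=torch.int64)
# tensor([[5, 6, 7, 0, 0, 0],
#         [8, 9, 0, 0, 0, 0]])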
......@@ -176,31 +189,29 @@ class FlashCausalLMBatch(Batch):
# New values after filtering
requests_idx_mapping = {}
input_ids = self.input_ids.new_empty(len(request_ids))
position_ids = self.position_ids.new_empty(len(request_ids))
# Used to index into tensors
indices = []
# Create on CPU to only move to GPU once instead of at every copy
cu_seqlens = torch.zeros(len(request_ids) + 1, dtype=torch.int32)
cu_seqlens_q = torch.arange(
0, len(request_ids) + 1, device=self.cu_seqlens_q.device, dtype=torch.int32
)
cu_seqlens_q = self.cu_seqlens_q[: len(request_ids) + 1]
max_seqlen = 0
past_key_values = []
requests = []
all_input_ids = []
all_input_ids_tensor = []
input_lengths = []
prefix_offsets = []
read_offsets = []
next_token_choosers = []
stopping_criterias = []
max_tokens = 0
for i, request_id in enumerate(request_ids):
idx = self.requests_idx_mapping[request_id]
indices.append(idx)
requests_idx_mapping[request_id] = i
requests.append(self.requests[idx])
......@@ -208,10 +219,6 @@ class FlashCausalLMBatch(Batch):
# Get length
request_input_length = self.input_lengths[idx]
# Copy tensors (GPU)
input_ids[i] = self.input_ids[idx]
position_ids[i] = self.position_ids[idx]
# Copy to tensor (CPU)
cu_seqlens[i + 1] = cumulative_length + request_input_length
max_seqlen = max(max_seqlen, request_input_length)
......@@ -222,14 +229,11 @@ class FlashCausalLMBatch(Batch):
)
all_input_ids.append(self.all_input_ids[idx])
all_input_ids_tensor.append(self.all_input_ids_tensor[idx])
input_lengths.append(request_input_length)
prefix_offsets.append(self.prefix_offsets[idx])
read_offsets.append(self.read_offsets[idx])
next_token_choosers.append(self.next_token_choosers[idx])
stopping_criteria = self.stopping_criterias[idx]
stopping_criterias.append(stopping_criteria)
......@@ -258,6 +262,12 @@ class FlashCausalLMBatch(Batch):
# Cat all past
past_key_values = torch.cat(past_key_values, dim=1)
# Index into tensors
input_ids = self.input_ids[indices]
position_ids = self.position_ids[indices]
all_input_ids_tensor = self.all_input_ids_tensor[indices]
next_token_chooser = self.next_token_chooser.filter(indices)
# Move to GPU now that we have the whole tensor
cu_seqlens = cu_seqlens.to(self.cu_seqlens.device)
......@@ -276,7 +286,7 @@ class FlashCausalLMBatch(Batch):
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_choosers=next_token_choosers,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
max_tokens=max_tokens,
)
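
With batch-level tensors, filter no longer copies request rows one by one: it records the surviving indices, indexes input_ids, position_ids and all_input_ids_tensor once, and narrows the chooser with the same indices via next_token_chooser.filter(indices). A minimal sketch of the indexing (values are illustrative; the chooser call is left as a comment because it needs the TGI class):

import torch

indices = [0, 2]  # requests 0 and 2 of a 3-request batch survive

input_ids = torch.tensor([11, 22, 33])
all_input_ids_tensor = torch.tensor([[11, 0], [22, 0], [33, 0]])

# One advanced-indexing op per tensor replaces the per-request copy loop.
input_ids = input_ids[indices]                        # tensor([11, 33])
all_input_ids_tensor = all_input_ids_tensor[indices]  # keeps rows 0 and 2

# next_token_chooser = next_token_chooser.filter(indices)  # as in the diff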
......@@ -290,6 +300,7 @@ class FlashCausalLMBatch(Batch):
total_batch_size = sum([len(b) for b in batches])
dtype = batches[0].past_key_values.dtype
device = batches[0].input_ids.device
input_ids = batches[0].input_ids.new_empty(total_batch_size)
......@@ -302,19 +313,19 @@ class FlashCausalLMBatch(Batch):
past_key_values = []
all_input_ids = []
all_input_ids_tensor = []
input_lengths = []
prefix_offsets = []
read_offsets = []
next_token_choosers = []
next_token_chooser_parameters = []
stopping_criterias = []
# Cumulative length
cumulative_batch_size = 0
cumulative_length = 0
max_tokens = 0
max_length = 0
for i, batch in enumerate(batches):
requests.extend(batch.requests)
......@@ -347,25 +358,54 @@ class FlashCausalLMBatch(Batch):
)
all_input_ids.extend(batch.all_input_ids)
all_input_ids_tensor.extend(batch.all_input_ids_tensor)
input_lengths.extend(batch.input_lengths)
prefix_offsets.extend(batch.prefix_offsets)
read_offsets.extend(batch.read_offsets)
next_token_choosers.extend(batch.next_token_choosers)
next_token_chooser_parameters.extend([r.parameters for r in batch.requests])
stopping_criterias.extend(batch.stopping_criterias)
# Update
cumulative_length += batch.cu_seqlens[-1]
cumulative_batch_size += len(batch)
max_tokens += batch.max_tokens
max_length = max(
max_length,
max(
input_length
+ stopping_criteria.max_new_tokens
- stopping_criteria.current_tokens
for input_length, stopping_criteria in zip(
batch.input_lengths, batch.stopping_criterias
)
),
)
all_input_ids_tensor = torch.zeros(
(total_batch_size, max_length), dtype=torch.int64, device=device
)
cumulative_batch_size = 0
for i, batch in enumerate(batches):
start_index = cumulative_batch_size
end_index = cumulative_batch_size + len(batch)
all_input_ids_tensor[
start_index:end_index, : batch.all_input_ids_tensor.shape[1]
] = batch.all_input_ids_tensor[:, :max_length]
cumulative_batch_size += len(batch)
# Cat past
past_key_values = torch.cat(past_key_values, dim=1)
# Create final tensor on GPU
cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int32, device=device)
next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
next_token_chooser_parameters, dtype=dtype, device=device
)
return FlashCausalLMBatch(
batch_id=batches[0].batch_id,
requests=requests,
......@@ -381,7 +421,7 @@ class FlashCausalLMBatch(Batch):
read_offsets=read_offsets,
all_input_ids=all_input_ids,
all_input_ids_tensor=all_input_ids_tensor,
next_token_choosers=next_token_choosers,
next_token_chooser=next_token_chooser,
stopping_criterias=stopping_criterias,
max_tokens=max_tokens,
)
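
concatenate recomputes the widest remaining window (input_length + max_new_tokens - current_tokens) across all batches, allocates one zero-padded tensor of that width, and copies each batch's possibly narrower all_input_ids_tensor into its row slice. A small runnable sketch of the copy with hypothetical shapes:

import torch

batch_a = torch.tensor([[1, 2, 3, 0]])          # padded width 4
batch_b = torch.tensor([[4, 5, 0], [6, 0, 0]])  # padded width 3
batches = [batch_a, batch_b]

total_batch_size = sum(b.shape[0] for b in batches)
max_length = 4  # assumed recomputed maximum remaining window

all_input_ids_tensor = torch.zeros((total_batch_size, max_length), dtype=torch.int64)

cumulative_batch_size = 0
for b in batches:
    start = cumulative_batch_size
    end = cumulative_batch_size + b.shape[0]
    # Copy into the left part of each row; the extra columns stay zero.
    all_input_ids_tensor[start:end, : b.shape[1]] = b[:, :max_length]
    cumulative_batch_size += b.shape[0]

# tensor([[1, 2, 3, 0],
#         [4, 5, 0, 0],
#         [6, 0, 0, 0]])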
......@@ -463,6 +503,7 @@ class FlashCausalLM(Model):
self, batch: FlashCausalLMBatch
) -> Tuple[List[Generation], Optional[FlashCausalLMBatch]]:
prefill = batch.past_key_values is None
single_request = len(batch) == 1
if prefill and len(batch) == 1:
# Ask to pre-allocate kv to its max size
......@@ -483,6 +524,17 @@ class FlashCausalLM(Model):
pre_allocate_past_size,
)
if prefill:
next_token_logits = (
out[-1:] if single_request else out[batch.cu_seqlens[1:] - 1]
)
else:
next_token_logits = out
next_input_ids, next_token_logprobs = batch.next_token_chooser(
batch.all_input_ids_tensor[:, : batch.max_seqlen], next_token_logits
)
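
During prefill, out stacks the logits of every prompt token of every request, so the last-token logits are gathered with cu_seqlens[1:] - 1 (or out[-1:] for a single request); in decode mode out already has one row per request. A toy example of that gather, with made-up sizes:

import torch

# Two requests of lengths 2 and 3: cu_seqlens = [0, 2, 5], and `out` holds one
# row of (toy, vocab-size-4) logits per prefill token.
cu_seqlens = torch.tensor([0, 2, 5])
out = torch.arange(20, dtype=torch.float32).reshape(5, 4)

# Rows 1 and 4 are the last prompt tokens of the two requests.
next_token_logits = out[cu_seqlens[1:] - 1]
print(next_token_logits.shape)  # torch.Size([2, 4])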
if prefill:
if len(batch) > 1:
# We create the prefill_tokens_indices tensor that will be used to gather prefill logprobs
......@@ -493,15 +545,11 @@ class FlashCausalLM(Model):
batch.cu_seqlens_q = torch.arange(
0, len(batch) + 1, device=self.device, dtype=torch.int32
)
next_input_ids = batch.input_ids.new_empty(len(batch))
next_position_ids = batch.position_ids.new_empty(len(batch))
else:
prefill_logprobs = None
next_input_ids = batch.input_ids
next_position_ids = batch.position_ids
next_token_logprobs = out.new_empty(len(batch))
# Prepare past for next decode
if len(batch) > 1:
# Used to slice next batch past
......@@ -552,7 +600,6 @@ class FlashCausalLM(Model):
# Zipped iterator
iterator = zip(
batch.input_lengths,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
)
......@@ -564,7 +611,6 @@ class FlashCausalLM(Model):
# For each member of the batch
for i, (
input_length,
next_token_chooser,
stopping_criteria,
all_input_ids,
) in enumerate(iterator):
......@@ -573,21 +619,6 @@ class FlashCausalLM(Model):
end_index = cumulative_length + input_length
if prefill:
# Prefill mode
# out is of shape [cumulative_sequence_lengths, vocab_size]
# only take last token logit
logits = out[end_index - 1 : end_index]
# Create all_input_ids_tensor that will be used by token warpers (for example, RepetitionPenalty)
all_input_ids_tensor = batch.input_ids.new_empty(
input_length + stopping_criteria.max_new_tokens
)
# Copy from batch.input_ids to all_input_ids_tensor
all_input_ids_tensor[:input_length] = batch.input_ids[
start_index:end_index
]
batch.all_input_ids_tensor.append(all_input_ids_tensor)
# Initialize position_ids
# In decode, we do not need this as we can just increment position ids
next_position_ids[i] = batch.position_ids[end_index - 1]
......@@ -603,25 +634,8 @@ class FlashCausalLM(Model):
prefill_tokens_indices = batch.input_ids[
start_index + 1 : end_index
]
else:
# Decode mode
# out is of shape [batch_size, vocab_size]
logits = out[i].view(1, -1)
all_input_ids_tensor = batch.all_input_ids_tensor[i]
# Select next token
next_token_id, logprob = next_token_chooser(
all_input_ids_tensor[None, :input_length], logits
)
# Add to all_input_ids_tensor
next_token_id_squeezed = next_token_id.view(1)
all_input_ids_tensor[input_length] = next_token_id_squeezed
# Set values
next_input_ids[i] = next_token_id_squeezed
next_token_logprobs[i] = logprob[-1, next_token_id].view(1)
batch.all_input_ids_tensor[i, input_length] = next_input_ids[i]
cumulative_length += input_length
......@@ -651,10 +665,11 @@ class FlashCausalLM(Model):
batch.input_lengths,
batch.prefix_offsets,
batch.read_offsets,
batch.next_token_choosers,
batch.stopping_criterias,
batch.all_input_ids,
batch.all_input_ids_tensor,
batch.next_token_chooser.do_sample,
batch.next_token_chooser.seeds,
next_token_ids,
next_token_logprobs,
)
......@@ -665,10 +680,11 @@ class FlashCausalLM(Model):
input_length,
prefix_offset,
read_offset,
next_token_chooser,
stopping_criteria,
all_input_ids,
all_input_ids_tensor,
do_sample,
seed,
next_token_id,
next_token_logprob,
) in enumerate(iterator):
......@@ -702,14 +718,11 @@ class FlashCausalLM(Model):
output_text = self.decode(
all_input_ids[-stopping_criteria.current_tokens :]
)
# Get seed
if isinstance(next_token_chooser.choice, Sampling):
seed = next_token_chooser.choice.seed
else:
seed = None
generated_text = GeneratedText(
output_text, stopping_criteria.current_tokens, reason, seed
output_text,
stopping_criteria.current_tokens,
reason,
seed if do_sample else None,
)
else:
generated_text = None
......@@ -751,8 +764,9 @@ class FlashCausalLM(Model):
batch.prefix_offsets[i] = prefix_offset
batch.read_offsets[i] = read_offset
batch.all_input_ids[i] = all_input_ids
batch.max_seqlen = batch.max_seqlen + 1
cumulative_length += input_length
batch.max_seqlen = batch.max_seqlen + 1
# No need to return a batch if we know that all requests stopped
return generations, batch if not stopped else None
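
Seed bookkeeping also moves to the batch level: instead of inspecting each request's next_token_chooser.choice for a Sampling instance, the loop zips next_token_chooser.do_sample and next_token_chooser.seeds and attaches the seed only when the request sampled. A self-contained sketch of that pattern, using a stand-in for GeneratedText so it runs without the TGI types (values are illustrative):

from typing import List, NamedTuple, Optional

class GeneratedTextStub(NamedTuple):
    # Stand-in for text_generation_server.models.types.GeneratedText.
    text: str
    generated_tokens: int
    finish_reason: str
    seed: Optional[int]

do_sample: List[bool] = [True, False]  # from next_token_chooser.do_sample
seeds: List[int] = [42, 0]             # from next_token_chooser.seeds

finished = [
    GeneratedTextStub("...", 10, "length", seed if sample else None)
    for sample, seed in zip(do_sample, seeds)
]
print([g.seed for g in finished])  # [42, None]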
......@@ -89,6 +89,7 @@ class GalacticaCausalLMBatch(CausalLMBatch):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "GalacticaCausalLMBatch":
inputs = []
......
......@@ -71,6 +71,7 @@ class Seq2SeqLMBatch(Batch):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "Seq2SeqLMBatch":
"""Convert a text_generation_server.v1.Batch protobuf to a Seq2SeqLMBatch"""
......
......@@ -21,6 +21,7 @@ class Batch(ABC):
cls,
pb: generate_pb2.Batch,
tokenizer: PreTrainedTokenizerBase,
dtype: torch.dtype,
device: torch.device,
) -> "Batch":
raise NotImplementedError
......
......@@ -55,7 +55,7 @@ class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
async def Prefill(self, request, context):
batch = self.model.batch_type.from_pb(
request.batch, self.model.tokenizer, self.model.device
request.batch, self.model.tokenizer, self.model.dtype, self.model.device
)
generations, next_batch = self.model.generate_token(batch)
......
......@@ -9,12 +9,13 @@ from text_generation_server.utils.hub import (
RevisionNotFoundError,
)
from text_generation_server.utils.tokens import (
Greedy,
NextTokenChooser,
Sampling,
HeterogeneousNextTokenChooser,
StoppingCriteria,
StopSequenceCriteria,
FinishReason,
Sampling,
Greedy,
)
__all__ = [
......@@ -25,6 +26,7 @@ __all__ = [
"weight_hub_files",
"download_weights",
"EntryNotFoundError",
"HeterogeneousNextTokenChooser",
"LocalEntryNotFoundError",
"RevisionNotFoundError",
"Greedy",
......
import concurrent
import time
import datetime
import torch
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from datetime import timedelta
from loguru import logger
from pathlib import Path
from safetensors.torch import load_file, save_file
from safetensors.torch import save_file
from safetensors import safe_open
from typing import Dict, List
......