Unverified Commit bb3a3b66 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Support faster JSON decoding for llava (#137)

When sending fast-forwarded requests to model_rpc, re-run `pad_input_ids` so the image padding tokens are recomputed for the new input.
parent 45d6592d
...@@ -31,6 +31,7 @@ class Req: ...@@ -31,6 +31,7 @@ class Req:
self.pixel_values = None self.pixel_values = None
self.image_size = None self.image_size = None
self.image_offset = 0 self.image_offset = 0
self.pad_value = None
self.sampling_params = None self.sampling_params = None
self.return_logprob = False self.return_logprob = False
...@@ -58,7 +59,7 @@ class Req: ...@@ -58,7 +59,7 @@ class Req:
def max_new_tokens(self): def max_new_tokens(self):
return self.sampling_params.max_new_tokens return self.sampling_params.max_new_tokens
def tokenize_fast_forward(self, fast_forward_str, next_state): def fast_forward_and_retokenize(self, fast_forward_str, next_state):
old_output_str = self.tokenizer.decode(self.output_ids) old_output_str = self.tokenizer.decode(self.output_ids)
# FIXME: This logic does not really solve the problem of determining whether # FIXME: This logic does not really solve the problem of determining whether
# there should be a leading space. # there should be a leading space.
...@@ -75,9 +76,14 @@ class Req: ...@@ -75,9 +76,14 @@ class Req:
+ fast_forward_str + fast_forward_str
) )
new_input_ids = self.tokenizer.encode(new_input_string) new_input_ids = self.tokenizer.encode(new_input_string)
if self.pixel_values is not None:
# NOTE: This is a hack because the old input_ids contains the image padding
fast_forward_tokens_len = len(self.tokenizer.encode(fast_forward_str))
else:
fast_forward_tokens_len = ( fast_forward_tokens_len = (
len(new_input_ids) - len(self.input_ids) - len(self.output_ids) len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
) )
# print("=" * 100) # print("=" * 100)
# print(f"Catch fast forward:\n{fast_forward_str}") # print(f"Catch fast forward:\n{fast_forward_str}")
# print(self.tokenizer.convert_ids_to_tokens(self.input_ids)) # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
...@@ -351,7 +357,7 @@ class Batch: ...@@ -351,7 +357,7 @@ class Batch:
self.tree_cache.dec_ref_counter(req.last_node) self.tree_cache.dec_ref_counter(req.last_node)
# fast forward # fast forward
req.tokenize_fast_forward(fast_forward_str, next_state) req.fast_forward_and_retokenize(fast_forward_str, next_state)
fast_forward_reqs.append(req) fast_forward_reqs.append(req)
filter_indices.remove(i) filter_indices.remove(i)
......
...@@ -83,7 +83,9 @@ class ModelRpcServer(rpyc.Service): ...@@ -83,7 +83,9 @@ class ModelRpcServer(rpyc.Service):
self.max_num_running_seq = self.max_total_num_token // 2 self.max_num_running_seq = self.max_total_num_token // 2
self.max_prefill_num_token = max( self.max_prefill_num_token = max(
self.model_config.context_len, self.model_config.context_len,
self.max_total_num_token // 6 if server_args.max_prefill_num_token is None else server_args.max_prefill_num_token, self.max_total_num_token // 6
if server_args.max_prefill_num_token is None
else server_args.max_prefill_num_token,
) )
self.int_token_logit_bias = torch.tensor( self.int_token_logit_bias = torch.tensor(
get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size) get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
...@@ -233,7 +235,7 @@ class ModelRpcServer(rpyc.Service): ...@@ -233,7 +235,7 @@ class ModelRpcServer(rpyc.Service):
req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids) req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
req.pixel_values = recv_req.pixel_values req.pixel_values = recv_req.pixel_values
if req.pixel_values is not None: if req.pixel_values is not None:
pad_value = [ req.pad_value = [
(recv_req.image_hash) % self.model_config.vocab_size, (recv_req.image_hash) % self.model_config.vocab_size,
(recv_req.image_hash >> 16) % self.model_config.vocab_size, (recv_req.image_hash >> 16) % self.model_config.vocab_size,
(recv_req.image_hash >> 32) % self.model_config.vocab_size, (recv_req.image_hash >> 32) % self.model_config.vocab_size,
...@@ -241,7 +243,7 @@ class ModelRpcServer(rpyc.Service): ...@@ -241,7 +243,7 @@ class ModelRpcServer(rpyc.Service):
] ]
req.image_size = recv_req.image_size req.image_size = recv_req.image_size
req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids( req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
req.input_ids, pad_value, req.pixel_values.shape, req.image_size req.input_ids, req.pad_value, req.pixel_values.shape, req.image_size
) )
req.sampling_params = recv_req.sampling_params req.sampling_params = recv_req.sampling_params
req.return_logprob = recv_req.return_logprob req.return_logprob = recv_req.return_logprob
...@@ -438,6 +440,20 @@ class ModelRpcServer(rpyc.Service): ...@@ -438,6 +440,20 @@ class ModelRpcServer(rpyc.Service):
if not self.no_regex_fast_forward: if not self.no_regex_fast_forward:
# check for fast forward # check for fast forward
fast_forward_reqs = batch.check_for_fast_forward() fast_forward_reqs = batch.check_for_fast_forward()
# check for image fast forward
for req in fast_forward_reqs:
if req.pixel_values is not None:
(
req.input_ids,
req.image_offset,
) = self.model_runner.model.pad_input_ids(
req.input_ids,
req.pad_value,
req.pixel_values.shape,
req.image_size,
)
self.forward_queue.extend(fast_forward_reqs) self.forward_queue.extend(fast_forward_reqs)
if batch.is_empty(): if batch.is_empty():
return return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment