Commit c49740a3 authored by zhuwenwen
Browse files

Merge branch 'v0.8.5-zero_overhead' into 'v0.8.5.post1-dev'

debug and fix tbo error in mtp

See merge request dcutoolkit/deeplearing/vllm!124
parents 385eeae9 bf790acd
......@@ -102,7 +102,7 @@ def split_model_input(model_input, self_device, batch_size_left, batch_size_righ
previous_hidden_states_left = None
previous_hidden_states_right = None
if model_input.previous_hidden_states != None:
split_previous_hidden_states = torch.split(model_input.previous_hidden_states, batch_size_split, dim=0)
split_previous_hidden_states = torch.split(model_input.previous_hidden_states, query_tokens_split, dim=0)
previous_hidden_states_left = split_previous_hidden_states[0]
previous_hidden_states_right = split_previous_hidden_states[1]
......
......@@ -250,12 +250,12 @@ def tbo_model_executable(
model_input_left, model_input_right = split_model_input(model_input, self_device, batch_size_left, batch_size_right)
model_kwargs_left = model_kwargs
model_kwargs_right = model_kwargs
model_kwargs_left = model_kwargs.copy()
model_kwargs_right = model_kwargs.copy()
if "previous_hidden_states" in model_kwargs:
previous_hidden_states = model_kwargs["previous_hidden_states"]
batch_size_split = [batch_size_left, batch_size_right]
split_previous_hidden_states = torch.split(previous_hidden_states, batch_size_split, dim=0)
query_tokens_split = [sum(model_input.query_lens[0:batch_size_left]), sum(model_input.query_lens[batch_size_left:])]
split_previous_hidden_states = torch.split(previous_hidden_states, query_tokens_split, dim=0)
model_kwargs_left["previous_hidden_states"] = split_previous_hidden_states[0]
model_kwargs_right["previous_hidden_states"] = split_previous_hidden_states[1]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment