Unverified commit b87aacb5, authored by Cheng Wan and committed by GitHub
Browse files

[DP Attention] Refactor: adding some utility functions (#9136)

parent b3363cc1
......@@ -678,16 +678,12 @@ class TboForwardBatchPreparer:
# TODO improve, e.g. unify w/ `init_raw`
if (
global_server_args_dict["moe_dense_tp_size"] == 1
and batch.gathered_buffer is not None
and batch.global_dp_buffer_len is not None
):
sum_len = end_token_index - start_token_index
gathered_buffer = torch.zeros(
(sum_len, batch.gathered_buffer.shape[1]),
dtype=batch.gathered_buffer.dtype,
device=batch.gathered_buffer.device,
)
global_dp_buffer_len = sum_len
else:
gathered_buffer = None
global_dp_buffer_len = None
output_dict.update(
dict(
......@@ -706,7 +702,7 @@ class TboForwardBatchPreparer:
global_num_tokens_gpu=None,
global_num_tokens_cpu=None,
dp_padding_mode=None,
gathered_buffer=gathered_buffer,
global_dp_buffer_len=global_dp_buffer_len,
global_num_tokens_for_logprob_gpu=None,
global_num_tokens_for_logprob_cpu=None,
sampling_info=None,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment