Unverified Commit b87aacb5 authored by Cheng Wan's avatar Cheng Wan Committed by GitHub
Browse files

[DP Attention] Refactor: adding some utility functions (#9136)

parent b3363cc1
...@@ -678,16 +678,12 @@ class TboForwardBatchPreparer: ...@@ -678,16 +678,12 @@ class TboForwardBatchPreparer:
# TODO improve, e.g. unify w/ `init_raw` # TODO improve, e.g. unify w/ `init_raw`
if ( if (
global_server_args_dict["moe_dense_tp_size"] == 1 global_server_args_dict["moe_dense_tp_size"] == 1
and batch.gathered_buffer is not None and batch.global_dp_buffer_len is not None
): ):
sum_len = end_token_index - start_token_index sum_len = end_token_index - start_token_index
gathered_buffer = torch.zeros( global_dp_buffer_len = sum_len
(sum_len, batch.gathered_buffer.shape[1]),
dtype=batch.gathered_buffer.dtype,
device=batch.gathered_buffer.device,
)
else: else:
gathered_buffer = None global_dp_buffer_len = None
output_dict.update( output_dict.update(
dict( dict(
...@@ -706,7 +702,7 @@ class TboForwardBatchPreparer: ...@@ -706,7 +702,7 @@ class TboForwardBatchPreparer:
global_num_tokens_gpu=None, global_num_tokens_gpu=None,
global_num_tokens_cpu=None, global_num_tokens_cpu=None,
dp_padding_mode=None, dp_padding_mode=None,
gathered_buffer=gathered_buffer, global_dp_buffer_len=global_dp_buffer_len,
global_num_tokens_for_logprob_gpu=None, global_num_tokens_for_logprob_gpu=None,
global_num_tokens_for_logprob_cpu=None, global_num_tokens_for_logprob_cpu=None,
sampling_info=None, sampling_info=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment