Unverified Commit b8df43ab authored by Qiaolin Yu's avatar Qiaolin Yu Committed by GitHub
Browse files

Fix gathered_buffer issues in tbo (#7531)

parent a1c1ebe9
......@@ -346,7 +346,10 @@ class TboForwardBatchPreparer:
)
# TODO improve, e.g. unify w/ `init_raw`
if global_server_args_dict["moe_dense_tp_size"] == 1:
if (
global_server_args_dict["moe_dense_tp_size"] == 1
and batch.gathered_buffer is not None
):
sum_len = end_token_index - start_token_index
gathered_buffer = torch.zeros(
(sum_len, batch.gathered_buffer.shape[1]),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment