Unverified Commit 72dfa96a authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Fix cutlass moe accuracy drop caused by attention UB from DP padding mode (#10414)

parent 05b01ef4
......@@ -51,7 +51,12 @@ class DpPaddingMode(IntEnum):
return self == DpPaddingMode.SUM_LEN
@classmethod
def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DpPaddingMode:
def get_dp_padding_mode(
cls, is_extend_in_batch, global_num_tokens: List[int]
) -> DpPaddingMode:
if is_extend_in_batch:
return DpPaddingMode.SUM_LEN
# we choose the mode that minimizes the communication cost
max_len = max(global_num_tokens)
sum_len = sum(global_num_tokens)
......
......@@ -686,7 +686,9 @@ class ForwardBatch:
(global_num_tokens[i] - 1) // attn_tp_size + 1
) * attn_tp_size
dp_padding_mode = DpPaddingMode.get_dp_padding_mode(global_num_tokens)
dp_padding_mode = DpPaddingMode.get_dp_padding_mode(
self.is_extend_in_batch, global_num_tokens
)
self.dp_padding_mode = dp_padding_mode
if dp_padding_mode.is_max_len():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment