Unverified Commit 72dfa96a authored by fzyzcjy, committed by GitHub

Fix cutlass moe accuracy drop caused by attention UB from DP padding mode (#10414)

parent 05b01ef4
@@ -51,7 +51,12 @@ class DpPaddingMode(IntEnum):
         return self == DpPaddingMode.SUM_LEN
 
     @classmethod
-    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DpPaddingMode:
+    def get_dp_padding_mode(
+        cls, is_extend_in_batch, global_num_tokens: List[int]
+    ) -> DpPaddingMode:
+        if is_extend_in_batch:
+            return DpPaddingMode.SUM_LEN
+
         # we choose the mode that minimizes the communication cost
         max_len = max(global_num_tokens)
         sum_len = sum(global_num_tokens)
@@ -686,7 +686,9 @@ class ForwardBatch:
                     (global_num_tokens[i] - 1) // attn_tp_size + 1
                 ) * attn_tp_size
 
-        dp_padding_mode = DpPaddingMode.get_dp_padding_mode(global_num_tokens)
+        dp_padding_mode = DpPaddingMode.get_dp_padding_mode(
+            self.is_extend_in_batch, global_num_tokens
+        )
         self.dp_padding_mode = dp_padding_mode
 
         if dp_padding_mode.is_max_len():
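For context, the selection logic this commit touches can be sketched as a standalone snippet. This is a minimal, hypothetical reconstruction: the `dp_size` parameter and the 2x weight in the cost comparison are assumptions made to keep the example self-contained (the upstream code reads the DP world size from runtime state and may weigh the two modes differently). Only the enum member names and the new `is_extend_in_batch` early return come from the diff above.

from enum import IntEnum, auto
from typing import List


class DpPaddingMode(IntEnum):
    MAX_LEN = auto()  # pad every DP rank up to the longest rank's token count
    SUM_LEN = auto()  # concatenate the ranks' tokens with no per-rank padding

    @classmethod
    def get_dp_padding_mode(
        cls, is_extend_in_batch: bool, global_num_tokens: List[int], dp_size: int
    ) -> "DpPaddingMode":
        # The fix itself: extend (prefill) batches must use SUM_LEN, because
        # MAX_LEN padding fed garbage tokens into attention and corrupted the
        # cutlass MoE output downstream (the accuracy drop in the title).
        if is_extend_in_batch:
            return cls.SUM_LEN
        # Decode path: pick the cheaper communication pattern. MAX_LEN moves
        # roughly max_len * dp_size tokens; the 2x weight on sum_len is an
        # assumed heuristic, not the verified upstream cost model.
        max_len = max(global_num_tokens)
        sum_len = sum(global_num_tokens)
        return cls.MAX_LEN if sum_len * 2 > max_len * dp_size else cls.SUM_LEN


if __name__ == "__main__":
    # Balanced decode lengths favor MAX_LEN under the assumed cost model...
    print(DpPaddingMode.get_dp_padding_mode(False, [90, 100, 95, 85], dp_size=4).name)
    # ...skewed lengths favor SUM_LEN...
    print(DpPaddingMode.get_dp_padding_mode(False, [3, 100, 5, 7], dp_size=4).name)
    # ...and extend batches now always take SUM_LEN, regardless of cost.
    print(DpPaddingMode.get_dp_padding_mode(True, [90, 100, 95, 85], dp_size=4).name)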