"tests/vscode:/vscode.git/clone" did not exist on "dc86bd421ed98777112c64f61940321631c11806"
Unverified commit 72dfa96a authored by fzyzcjy, committed by GitHub

Fix cutlass moe accuracy drop caused by attention UB from DP padding mode (#10414)

parent 05b01ef4
@@ -51,7 +51,12 @@ class DpPaddingMode(IntEnum):
         return self == DpPaddingMode.SUM_LEN
 
     @classmethod
-    def get_dp_padding_mode(cls, global_num_tokens: List[int]) -> DpPaddingMode:
+    def get_dp_padding_mode(
+        cls, is_extend_in_batch, global_num_tokens: List[int]
+    ) -> DpPaddingMode:
+        if is_extend_in_batch:
+            return DpPaddingMode.SUM_LEN
+
         # we choose the mode that minimizes the communication cost
         max_len = max(global_num_tokens)
         sum_len = sum(global_num_tokens)
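For context, here is a minimal, self-contained sketch of the selection logic this hunk changes. The diff is truncated after `sum_len`, so the cost comparison below is a hypothetical heuristic standing in for the real condition; only the `is_extend_in_batch` early return mirrors the actual change.

```python
from enum import IntEnum, auto
from typing import List


class DpPaddingMode(IntEnum):
    """Illustrative stand-in for the real enum: MAX_LEN pads every DP rank
    to the longest per-rank batch, SUM_LEN gathers the exact token counts."""

    MAX_LEN = auto()
    SUM_LEN = auto()

    @classmethod
    def get_dp_padding_mode(
        cls, is_extend_in_batch: bool, global_num_tokens: List[int]
    ) -> "DpPaddingMode":
        # The fix: extend (prefill) batches always use SUM_LEN, so attention
        # never reads the undefined padding tokens that MAX_LEN introduces.
        if is_extend_in_batch:
            return cls.SUM_LEN

        # we choose the mode that minimizes the communication cost
        max_len = max(global_num_tokens)
        sum_len = sum(global_num_tokens)
        world_size = len(global_num_tokens)
        # Hypothetical heuristic (NOT the exact condition in the file): pad to
        # max_len when the ranks are roughly balanced, otherwise gather the
        # exact lengths.
        if max_len * world_size <= 2 * sum_len:
            return cls.MAX_LEN
        return cls.SUM_LEN
```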
@@ -686,7 +686,9 @@ class ForwardBatch:
                 (global_num_tokens[i] - 1) // attn_tp_size + 1
             ) * attn_tp_size
 
-        dp_padding_mode = DpPaddingMode.get_dp_padding_mode(global_num_tokens)
+        dp_padding_mode = DpPaddingMode.get_dp_padding_mode(
+            self.is_extend_in_batch, global_num_tokens
+        )
         self.dp_padding_mode = dp_padding_mode
 
         if dp_padding_mode.is_max_len():
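Using the sketch above, the behavioral effect at this call site can be exercised directly (the token counts are made up):

```python
# Made-up per-DP-rank token counts for a skewed extend (prefill) batch.
tokens = [8192, 16, 16, 16]

# After the fix, an extend batch always selects SUM_LEN, so no rank's
# attention kernel touches uninitialized MAX_LEN padding.
assert DpPaddingMode.get_dp_padding_mode(True, tokens) is DpPaddingMode.SUM_LEN

# Decode batches still go through the communication-cost heuristic; with the
# hypothetical rule above, a balanced batch picks MAX_LEN.
balanced = [1, 1, 1, 1]
assert DpPaddingMode.get_dp_padding_mode(False, balanced) is DpPaddingMode.MAX_LEN
```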