"vscode:/vscode.git/clone" did not exist on "196835695ed6fa3ec53b888088d9d5581e8f8e94"
Unverified Commit defede50 authored by fzyzcjy, committed by GitHub
Browse files

Fix DeepSeek DP Attention + torch compile (#5367)


Co-authored-by: ispobock <ispobaoke@163.com>
parent fc728719
@@ -192,8 +192,7 @@ def _dp_gather(
     if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0):
         assert (
-            global_tokens.untyped_storage().data_ptr()
-            != local_tokens.untyped_storage().data_ptr()
+            local_tokens.untyped_storage() is not global_tokens.untyped_storage()
         ), "aliasing between global_tokens and local_tokens not allowed"
         memcpy_triton(
             global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
@@ -243,8 +242,7 @@ def dp_scatter(
     assert global_tokens.is_contiguous()
     if local_tokens.shape[0] > 0:
         assert (
-            local_tokens.untyped_storage().data_ptr()
-            != global_tokens.untyped_storage().data_ptr()
+            local_tokens.untyped_storage() is not global_tokens.untyped_storage()
         ), "aliasing between local_tokens and global_tokens not allowed"
         memcpy_triton(
             local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
......
-import json
-import pandas as pd
 import argparse
+import json
 import os
+import pandas as pd
 from tabulate import tabulate
 # Parse command-line arguments
......
@@ -28,6 +28,9 @@ class TestDPAttentionDP2TP2(CustomTestCase):
             "--enable-dp-attention",
             "--dp",
             "2",
+            "--enable-torch-compile",
+            "--torch-compile-max-bs",
+            "2",
         ],
     )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment