Unverified commit 78b7465c authored by kk, committed by GitHub

Fix GPU fault issue when running dsv3 with DP mode and torch-compile enabled (#10361)


Co-authored-by: wunhuang <wunhuang@amd.com>
parent 07bcad7f
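From the hunks below, the fix appears to have two parts: _DpGatheredBufferWrapper in sglang.srt.layers.dp_attention exposes the hidden size, dtype, and device it already tracks through new accessors (get_dp_hidden_size(), get_dp_dtype(), get_dp_device()), and LogitsMetadata.compute_dp_attention_metadata() then uses those accessors to allocate its own gathered_buffer with torch.empty, so that LogitsProcessor gathers into that buffer instead of the shared one returned by get_global_dp_buffer().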
@@ -119,6 +119,18 @@ class _DpGatheredBufferWrapper:
     def get_dp_global_num_tokens(cls) -> List[int]:
         return cls._global_num_tokens
 
+    @classmethod
+    def get_dp_hidden_size(cls) -> int:
+        return cls._hidden_size
+
+    @classmethod
+    def get_dp_dtype(cls) -> torch.dtype:
+        return cls._dtype
+
+    @classmethod
+    def get_dp_device(cls) -> torch.device:
+        return cls._device
+
 
 def set_dp_buffer_len(
     global_dp_buffer_len: int,
@@ -150,6 +162,18 @@ def get_dp_global_num_tokens() -> List[int]:
     return _DpGatheredBufferWrapper.get_dp_global_num_tokens()
 
 
+def get_dp_hidden_size() -> int:
+    return _DpGatheredBufferWrapper.get_dp_hidden_size()
+
+
+def get_dp_dtype() -> torch.dtype:
+    return _DpGatheredBufferWrapper.get_dp_dtype()
+
+
+def get_dp_device() -> torch.device:
+    return _DpGatheredBufferWrapper.get_dp_device()
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
......
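The two hunks above only add read accessors; the underlying _hidden_size, _dtype, and _device attributes are evidently recorded elsewhere in the wrapper. As a rough illustration of the pattern (a class-level registry plus thin module-level wrappers), here is a minimal self-contained sketch; _GatheredBufferMeta, set_metadata, and the example values (7168, bfloat16) are hypothetical and not part of the commit.

# Illustrative sketch only -- a class-level registry that records buffer
# metadata once and exposes it through classmethod getters plus thin
# module-level wrappers, mirroring the accessors added above. The real
# _DpGatheredBufferWrapper also tracks buffer lengths and per-rank token
# counts; every name and value below is hypothetical.
import torch


class _GatheredBufferMeta:
    _hidden_size: int = 0
    _dtype: torch.dtype = torch.float32
    _device: torch.device = torch.device("cpu")

    @classmethod
    def set_metadata(
        cls, hidden_size: int, dtype: torch.dtype, device: torch.device
    ) -> None:
        # Assumed to be called once during model setup, before any forward pass.
        cls._hidden_size = hidden_size
        cls._dtype = dtype
        cls._device = device

    @classmethod
    def get_hidden_size(cls) -> int:
        return cls._hidden_size

    @classmethod
    def get_dtype(cls) -> torch.dtype:
        return cls._dtype

    @classmethod
    def get_device(cls) -> torch.device:
        return cls._device


# Module-level wrappers, analogous to get_dp_hidden_size() / get_dp_dtype() /
# get_dp_device() in dp_attention.py.
def get_hidden_size() -> int:
    return _GatheredBufferMeta.get_hidden_size()


def get_dtype() -> torch.dtype:
    return _GatheredBufferMeta.get_dtype()


def get_device() -> torch.device:
    return _GatheredBufferMeta.get_device()


if __name__ == "__main__":
    _GatheredBufferMeta.set_metadata(7168, torch.bfloat16, torch.device("cpu"))
    # Any caller can now size a compatible buffer without importing the class.
    buf = torch.empty((16, get_hidden_size()), dtype=get_dtype(), device=get_device())
    print(buf.shape, buf.dtype)  # torch.Size([16, 7168]) torch.bfloat16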
@@ -35,6 +35,9 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_rank,
     get_attention_dp_size,
     get_attention_tp_size,
+    get_dp_device,
+    get_dp_dtype,
+    get_dp_hidden_size,
     get_global_dp_buffer,
     get_local_attention_dp_size,
     set_dp_buffer_len,
@@ -187,16 +190,23 @@ class LogitsMetadata:
         self.dp_local_start_pos = dp_local_start_pos
         self.dp_local_num_tokens = dp_local_num_tokens
 
+        hidden_size = get_dp_hidden_size()
+        dtype = get_dp_dtype()
+        device = get_dp_device()
+
         if self.global_num_tokens_for_logprob_cpu is not None:
             # create a smaller buffer to reduce peak memory usage
             self.global_dp_buffer_len = sum(self.global_num_tokens_for_logprob_cpu)
         else:
             self.global_dp_buffer_len = self.global_dp_buffer_len
 
-        set_dp_buffer_len(
-            self.global_dp_buffer_len,
-            self.dp_local_num_tokens,
-            self.global_num_tokens_for_logprob_cpu,
+        self.gathered_buffer = torch.empty(
+            (
+                self.global_dp_buffer_len,
+                hidden_size,
+            ),
+            dtype=dtype,
+            device=device,
         )
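In the hunk above, the gathered buffer is sized from the per-rank logprob token counts when they are available, which keeps peak memory below what the full DP batch would need, and it is allocated with the dtype and device recorded by the wrapper rather than borrowed from the process-wide buffer. A standalone sketch of that sizing logic with made-up values (the per-rank counts, 7168, and bfloat16 are hypothetical; the real code reads them from the accessors above):

# Standalone sketch of the buffer-sizing logic above, with made-up values.
# In the commit, hidden_size/dtype/device come from get_dp_hidden_size(),
# get_dp_dtype(), and get_dp_device(); here they are hard-coded for illustration.
import torch

global_num_tokens_for_logprob_cpu = [4, 2, 3, 1]  # hypothetical per-DP-rank token counts
global_dp_buffer_len = 64                         # hypothetical full-batch gather length
hidden_size, dtype, device = 7168, torch.bfloat16, torch.device("cpu")

if global_num_tokens_for_logprob_cpu is not None:
    # Only the tokens that actually need logprobs are gathered, so the buffer
    # can be much smaller than the full DP batch (10 rows here instead of 64).
    buffer_len = sum(global_num_tokens_for_logprob_cpu)
else:
    buffer_len = global_dp_buffer_len

gathered_buffer = torch.empty((buffer_len, hidden_size), dtype=dtype, device=device)
print(gathered_buffer.shape)  # torch.Size([10, 7168])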
@@ -443,7 +453,7 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata()
             hidden_states, local_hidden_states = (
-                get_global_dp_buffer(),
+                logits_metadata.gathered_buffer,
                 hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
......
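The final hunk redirects the logits all-gather into logits_metadata.gathered_buffer, the tensor allocated in compute_dp_attention_metadata above, instead of the process-wide buffer returned by get_global_dp_buffer(). Judging from the commit title, sharing that global buffer on the logits path appears to be what triggered the GPU fault when DP attention runs with torch-compile, so giving this path its own buffer of the right length, dtype, and device is the substance of the fix; the new accessors in dp_attention.py exist to make that allocation possible.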