Unverified Commit 78b7465c authored by kk, committed by GitHub

Fix GPU fault issue when running dsv3 with dp mode and torch-compile enabled (#10361)


Co-authored-by: wunhuang <wunhuang@amd.com>
parent 07bcad7f
@@ -119,6 +119,18 @@ class _DpGatheredBufferWrapper:
     def get_dp_global_num_tokens(cls) -> List[int]:
         return cls._global_num_tokens
 
+    @classmethod
+    def get_dp_hidden_size(cls) -> int:
+        return cls._hidden_size
+
+    @classmethod
+    def get_dp_dtype(cls) -> torch.dtype:
+        return cls._dtype
+
+    @classmethod
+    def get_dp_device(cls) -> torch.device:
+        return cls._device
+
 
 def set_dp_buffer_len(
     global_dp_buffer_len: int,
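The three new classmethods expose buffer metadata (`_hidden_size`, `_dtype`, `_device`) that the wrapper class already caches. As a rough illustration only, not the sglang class itself, here is a minimal sketch of that pattern; `set_metadata` and the example values (7168 is DeepSeek-V3's hidden size, CPU device for demonstration) are assumptions:

import torch


class _BufferMetaSketch:
    """Toy stand-in for _DpGatheredBufferWrapper's metadata caching."""

    _hidden_size: int
    _dtype: torch.dtype
    _device: torch.device

    @classmethod
    def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device):
        # Assumed setter: the real wrapper records these when DP attention is set up.
        cls._hidden_size = hidden_size
        cls._dtype = dtype
        cls._device = device

    @classmethod
    def get_dp_hidden_size(cls) -> int:
        return cls._hidden_size

    @classmethod
    def get_dp_dtype(cls) -> torch.dtype:
        return cls._dtype

    @classmethod
    def get_dp_device(cls) -> torch.device:
        return cls._device


# Example: any later caller can allocate a tensor matching the DP buffer layout.
_BufferMetaSketch.set_metadata(7168, torch.bfloat16, torch.device("cpu"))
buf = torch.empty(
    (16, _BufferMetaSketch.get_dp_hidden_size()),
    dtype=_BufferMetaSketch.get_dp_dtype(),
    device=_BufferMetaSketch.get_dp_device(),
)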
@@ -150,6 +162,18 @@ def get_dp_global_num_tokens() -> List[int]:
     return _DpGatheredBufferWrapper.get_dp_global_num_tokens()
 
+
+def get_dp_hidden_size() -> int:
+    return _DpGatheredBufferWrapper.get_dp_hidden_size()
+
+
+def get_dp_dtype() -> torch.dtype:
+    return _DpGatheredBufferWrapper.get_dp_dtype()
+
+
+def get_dp_device() -> torch.device:
+    return _DpGatheredBufferWrapper.get_dp_device()
+
 
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
...
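These module-level functions simply forward to the classmethods so that call sites can import plain functions from dp_attention instead of reaching into the private wrapper class. A hedged usage sketch (assumes sglang with this commit is importable and DP attention has been initialized; `allocate_gather_buffer` and `num_tokens` are illustrative names, not sglang API):

import torch
from sglang.srt.layers.dp_attention import (
    get_dp_device,
    get_dp_dtype,
    get_dp_hidden_size,
)


def allocate_gather_buffer(num_tokens: int) -> torch.Tensor:
    # Allocate a fresh gather buffer with the same hidden size, dtype, and
    # device as the DP-attention buffer, instead of reusing the global one.
    return torch.empty(
        (num_tokens, get_dp_hidden_size()),
        dtype=get_dp_dtype(),
        device=get_dp_device(),
    )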
@@ -35,6 +35,9 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_rank,
     get_attention_dp_size,
     get_attention_tp_size,
+    get_dp_device,
+    get_dp_dtype,
+    get_dp_hidden_size,
     get_global_dp_buffer,
     get_local_attention_dp_size,
     set_dp_buffer_len,
@@ -187,16 +190,23 @@ class LogitsMetadata:
         self.dp_local_start_pos = dp_local_start_pos
         self.dp_local_num_tokens = dp_local_num_tokens
 
+        hidden_size = get_dp_hidden_size()
+        dtype = get_dp_dtype()
+        device = get_dp_device()
+
         if self.global_num_tokens_for_logprob_cpu is not None:
             # create a smaller buffer to reduce peak memory usage
             self.global_dp_buffer_len = sum(self.global_num_tokens_for_logprob_cpu)
         else:
             self.global_dp_buffer_len = self.global_dp_buffer_len
 
-        set_dp_buffer_len(
-            self.global_dp_buffer_len,
-            self.dp_local_num_tokens,
-            self.global_num_tokens_for_logprob_cpu,
-        )
+        self.gathered_buffer = torch.empty(
+            (
+                self.global_dp_buffer_len,
+                hidden_size,
+            ),
+            dtype=dtype,
+            device=device,
+        )
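This is the substantive change: compute_dp_attention_metadata previously routed through set_dp_buffer_len and the shared global DP buffer; it now allocates a gathered_buffer that the metadata object owns, sized to the (usually smaller) logprob token count and shaped from the accessors added above. The following is a toy stand-in, not the sglang class, illustrating only the allocation logic; the torch.compile explanation in the comment is a hedged reading of the commit title, not confirmed by the diff:

import torch
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class LogitsMetadataSketch:
    """Toy stand-in for LogitsMetadata's buffer handling after this change."""

    global_dp_buffer_len: int
    global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
    gathered_buffer: Optional[torch.Tensor] = None

    def compute_dp_attention_metadata(
        self, hidden_size: int, dtype: torch.dtype, device: torch.device
    ) -> None:
        if self.global_num_tokens_for_logprob_cpu is not None:
            # Create a smaller buffer to reduce peak memory usage: only the
            # tokens that need logprobs are gathered.
            self.global_dp_buffer_len = sum(self.global_num_tokens_for_logprob_cpu)
        # Own the buffer instead of resizing a module-global one, so a
        # torch.compile'd graph never sees storage change underneath it
        # (presumably the source of the reported GPU fault).
        self.gathered_buffer = torch.empty(
            (self.global_dp_buffer_len, hidden_size), dtype=dtype, device=device
        )


# Example on CPU for illustration; in practice these values would come from
# get_dp_hidden_size(), get_dp_dtype(), and get_dp_device().
meta = LogitsMetadataSketch(
    global_dp_buffer_len=64, global_num_tokens_for_logprob_cpu=[4, 4, 4, 4]
)
meta.compute_dp_attention_metadata(7168, torch.bfloat16, torch.device("cpu"))
assert meta.gathered_buffer.shape == (16, 7168)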
@@ -443,7 +453,7 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata()
             hidden_states, local_hidden_states = (
-                get_global_dp_buffer(),
+                logits_metadata.gathered_buffer,
                 hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
...
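At the call site, LogitsProcessor now gathers into logits_metadata.gathered_buffer rather than the buffer returned by get_global_dp_buffer(). A generic sketch of that destination swap follows; a plain copy stands in for the real dp_gather_replicate collective, whose internals are not reproduced here:

import torch


def gather_replicate_sketch(
    local_hidden_states: torch.Tensor,
    gathered_buffer: torch.Tensor,
    local_start: int,
) -> torch.Tensor:
    # The destination is the metadata-owned gathered_buffer, not a shared
    # module-global buffer; each rank's local hidden states land in its slice.
    end = local_start + local_hidden_states.shape[0]
    gathered_buffer[local_start:end].copy_(local_hidden_states)
    return gathered_buffer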