"git@developer.sourcefind.cn:change/sglang.git" did not exist on "862dd76c76191b0d19b97cffa8d51c6cc3a466eb"
Unverified commit 78b7465c authored by kk, committed by GitHub

Fix GPU fault issue when running dsv3 with dp mode and torch-compile enabled (#10361)


Co-authored-by: wunhuang <wunhuang@amd.com>
parent 07bcad7f
@@ -119,6 +119,18 @@ class _DpGatheredBufferWrapper:
     def get_dp_global_num_tokens(cls) -> List[int]:
         return cls._global_num_tokens
 
+    @classmethod
+    def get_dp_hidden_size(cls) -> int:
+        return cls._hidden_size
+
+    @classmethod
+    def get_dp_dtype(cls) -> torch.dtype:
+        return cls._dtype
+
+    @classmethod
+    def get_dp_device(cls) -> torch.device:
+        return cls._device
+
 
 def set_dp_buffer_len(
     global_dp_buffer_len: int,
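
For context, the accessors added above read values that the wrapper is assumed to keep as class-level attributes (_hidden_size, _dtype, _device), populated once during model-runner setup. A minimal illustrative sketch of that pattern, not the actual SGLang class; the set_metadata helper is hypothetical:

import torch
from typing import List


class _DpGatheredBufferWrapperSketch:
    """Illustrative stand-in: DP-gather buffer metadata lives in class
    attributes so any module can read it without extra plumbing."""

    _hidden_size: int
    _dtype: torch.dtype
    _device: torch.device
    _global_num_tokens: List[int]

    @classmethod
    def set_metadata(cls, hidden_size: int, dtype: torch.dtype, device: torch.device) -> None:
        # Hypothetical setter; the real class is populated elsewhere in SGLang.
        cls._hidden_size = hidden_size
        cls._dtype = dtype
        cls._device = device

    @classmethod
    def get_dp_hidden_size(cls) -> int:
        return cls._hidden_size

    @classmethod
    def get_dp_dtype(cls) -> torch.dtype:
        return cls._dtype

    @classmethod
    def get_dp_device(cls) -> torch.device:
        return cls._device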
@@ -150,6 +162,18 @@ def get_dp_global_num_tokens() -> List[int]:
     return _DpGatheredBufferWrapper.get_dp_global_num_tokens()
 
 
+def get_dp_hidden_size() -> int:
+    return _DpGatheredBufferWrapper.get_dp_hidden_size()
+
+
+def get_dp_dtype() -> torch.dtype:
+    return _DpGatheredBufferWrapper.get_dp_dtype()
+
+
+def get_dp_device() -> torch.device:
+    return _DpGatheredBufferWrapper.get_dp_device()
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
...
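
The module-level functions are thin delegates to the classmethods above, so callers can size a DP gather buffer without importing the wrapper class directly. A hedged usage sketch, assuming the DP metadata has already been populated during model-runner setup (the token count below is a placeholder):

import torch

from sglang.srt.layers.dp_attention import (
    get_dp_device,
    get_dp_dtype,
    get_dp_hidden_size,
)

# Placeholder number of gathered tokens; in practice this comes from the
# per-rank token counts (see the LogitsMetadata change below).
num_gathered_tokens = 256

buffer = torch.empty(
    (num_gathered_tokens, get_dp_hidden_size()),
    dtype=get_dp_dtype(),
    device=get_dp_device(),
)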
@@ -35,6 +35,9 @@ from sglang.srt.layers.dp_attention import (
     get_attention_dp_rank,
     get_attention_dp_size,
     get_attention_tp_size,
+    get_dp_device,
+    get_dp_dtype,
+    get_dp_hidden_size,
     get_global_dp_buffer,
     get_local_attention_dp_size,
     set_dp_buffer_len,
@@ -187,16 +190,23 @@ class LogitsMetadata:
         self.dp_local_start_pos = dp_local_start_pos
         self.dp_local_num_tokens = dp_local_num_tokens
 
+        hidden_size = get_dp_hidden_size()
+        dtype = get_dp_dtype()
+        device = get_dp_device()
+
         if self.global_num_tokens_for_logprob_cpu is not None:
             # create a smaller buffer to reduce peak memory usage
             self.global_dp_buffer_len = sum(self.global_num_tokens_for_logprob_cpu)
         else:
             self.global_dp_buffer_len = self.global_dp_buffer_len
 
-        set_dp_buffer_len(
-            self.global_dp_buffer_len,
-            self.dp_local_num_tokens,
-            self.global_num_tokens_for_logprob_cpu,
+        self.gathered_buffer = torch.empty(
+            (
+                self.global_dp_buffer_len,
+                hidden_size,
+            ),
+            dtype=dtype,
+            device=device,
         )
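
For intuition on the shape chosen here: (global_dp_buffer_len, hidden_size) is the destination of a gather across DP ranks, where each rank writes its local hidden states into a contiguous row slice starting at its dp_local_start_pos. A self-contained toy sketch of that layout in plain PyTorch (this is not SGLang's dp_gather_replicate; the rank and token counts are made up):

import torch

# Toy setup: 3 DP ranks contribute 4, 2, and 6 tokens of hidden size 8.
global_num_tokens = [4, 2, 6]
hidden_size = 8
global_dp_buffer_len = sum(global_num_tokens)  # 12 rows in total

# The gathered buffer allocated in the hunk above has this shape.
gathered = torch.empty((global_dp_buffer_len, hidden_size))

# Each rank owns a contiguous row slice starting at its local start position.
for rank, num_tokens in enumerate(global_num_tokens):
    dp_local_start_pos = sum(global_num_tokens[:rank])
    local_hidden_states = torch.randn(num_tokens, hidden_size)  # this rank's share
    gathered[dp_local_start_pos : dp_local_start_pos + num_tokens] = local_hidden_states

assert gathered.shape == (12, 8)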
@@ -443,7 +453,7 @@ class LogitsProcessor(nn.Module):
         if self.do_tensor_parallel_all_gather_dp_attn:
             logits_metadata.compute_dp_attention_metadata()
             hidden_states, local_hidden_states = (
-                get_global_dp_buffer(),
+                logits_metadata.gathered_buffer,
                 hidden_states,
             )
             dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
...
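
Putting the last two hunks together, a condensed sketch of how LogitsProcessor now consumes the per-metadata buffer: the freshly allocated gathered_buffer becomes the all-gather destination, while the incoming tensor provides this rank's local rows. Control flow is simplified, the helper name is hypothetical, and dp_gather_replicate's import path is assumed to be the same dp_attention module:

from sglang.srt.layers.dp_attention import dp_gather_replicate  # assumed import path


def _gather_for_logits(self, hidden_states, logits_metadata):
    # Condensed sketch of the DP-attention gather step, not the full
    # LogitsProcessor forward path.
    if self.do_tensor_parallel_all_gather_dp_attn:
        logits_metadata.compute_dp_attention_metadata()
        # Destination is the per-metadata buffer allocated above; the source
        # is this rank's local hidden states.
        hidden_states, local_hidden_states = (
            logits_metadata.gathered_buffer,
            hidden_states,
        )
        dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
    return hidden_states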