Unverified commit e236d8fe, authored by Cheng Wan, committed by GitHub
Browse files

Save peak memory in logits processor (#8343)

parent 4fa44d63
...@@ -170,8 +170,6 @@ class LogitsMetadata: ...@@ -170,8 +170,6 @@ class LogitsMetadata:
) )
def compute_dp_attention_metadata(self): def compute_dp_attention_metadata(self):
# TODO(ch-wan): gathered_buffer here is larger than the actual required size in draft extend,
# we may use a smaller buffer in draft extend.
cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0) cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
dp_rank = get_attention_dp_rank() dp_rank = get_attention_dp_rank()
...@@ -186,6 +184,19 @@ class LogitsMetadata: ...@@ -186,6 +184,19 @@ class LogitsMetadata:
self.dp_local_start_pos = dp_local_start_pos self.dp_local_start_pos = dp_local_start_pos
self.dp_local_num_tokens = dp_local_num_tokens self.dp_local_num_tokens = dp_local_num_tokens
if self.global_num_tokens_for_logprob_cpu is not None:
# create a smaller buffer to reduce peak memory usage
self.gathered_buffer = torch.empty(
(
sum(self.global_num_tokens_for_logprob_cpu),
self.gathered_buffer.shape[1],
),
dtype=self.gathered_buffer.dtype,
device=self.gathered_buffer.device,
)
else:
self.gathered_buffer = torch.empty_like(self.gathered_buffer)
class LogitsProcessor(nn.Module): class LogitsProcessor(nn.Module):
def __init__( def __init__(
...@@ -430,7 +441,7 @@ class LogitsProcessor(nn.Module): ...@@ -430,7 +441,7 @@ class LogitsProcessor(nn.Module):
if self.do_tensor_parallel_all_gather_dp_attn: if self.do_tensor_parallel_all_gather_dp_attn:
logits_metadata.compute_dp_attention_metadata() logits_metadata.compute_dp_attention_metadata()
hidden_states, local_hidden_states = ( hidden_states, local_hidden_states = (
torch.empty_like(logits_metadata.gathered_buffer), logits_metadata.gathered_buffer,
hidden_states, hidden_states,
) )
dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata) dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment