Unverified commit 5a261bd0, authored by Lianmin Zheng, committed by GitHub

Fix the deadlock in multi-node tp (#1122)

parent 6aa8ad14
@@ -64,7 +64,9 @@ def main(args):
     @sgl.function
     def few_shot_gsm8k(s, question):
         s += few_shot_examples + question
-        s += sgl.gen("answer", max_tokens=512, stop=["Question", "Assistant:"])
+        s += sgl.gen(
+            "answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"]
+        )
 
     #####################################
     ########## SGL Program End ##########
......
@@ -67,10 +67,12 @@ class LogitsMetadata:
 
 
 class LogitsProcessor(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config, skip_all_gather: bool = False):
         super().__init__()
         self.config = config
-        self.tp_size = get_tensor_model_parallel_world_size()
+        self.do_tensor_parallel_all_gather = (
+            not skip_all_gather and get_tensor_model_parallel_world_size() > 1
+        )
 
     def _get_normalized_prompt_logprobs(
         self, input_token_logprobs, logits_metadata: LogitsMetadata
@@ -159,7 +161,7 @@ class LogitsProcessor(nn.Module):
             last_hidden = hidden_states[last_index]
 
         last_logits = torch.matmul(last_hidden, weight.T)
-        if self.tp_size > 1:
+        if self.do_tensor_parallel_all_gather:
            last_logits = tensor_model_parallel_all_gather(last_logits)
         last_logits = last_logits[:, : self.config.vocab_size].float()
@@ -204,7 +206,7 @@ class LogitsProcessor(nn.Module):
                 )
             else:
                 all_logits = torch.matmul(hidden_states, weight.T)
-                if self.tp_size > 1:
+                if self.do_tensor_parallel_all_gather:
                     all_logits = tensor_model_parallel_all_gather(all_logits)
                 all_logits = all_logits[:, : self.config.vocab_size].float()
......
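Note on the `skip_all_gather` flag: when the lm_head is vocab-sharded across TP ranks, each rank produces only a slice of the logits and a `tensor_model_parallel_all_gather` is required; when the head is replicated (as in the Grok-1 change further down), every rank already has full-vocab logits and the gather can be skipped. A minimal single-process sketch of the same gating, with `tp_world_size` and `all_gather_fn` as illustrative stand-ins for `get_tensor_model_parallel_world_size()` and `tensor_model_parallel_all_gather`:

```python
import torch


def compute_last_logits(
    last_hidden: torch.Tensor,      # [num_seqs, hidden_size]
    lm_head_weight: torch.Tensor,   # [vocab_shard_or_full_vocab, hidden_size]
    vocab_size: int,
    tp_world_size: int,
    skip_all_gather: bool = False,
    all_gather_fn=None,             # stand-in for tensor_model_parallel_all_gather
) -> torch.Tensor:
    # Same condition as LogitsProcessor.do_tensor_parallel_all_gather.
    do_all_gather = (not skip_all_gather) and tp_world_size > 1
    logits = torch.matmul(last_hidden, lm_head_weight.T)
    if do_all_gather:
        logits = all_gather_fn(logits)
    # Trim padding columns and cast, as in the original code.
    return logits[:, :vocab_size].float()


# Replicated head on a single rank: full-vocab logits, no gather needed.
out = compute_last_logits(
    torch.randn(4, 16), torch.randn(32000, 16),
    vocab_size=32000, tp_world_size=1, skip_all_gather=True,
)
print(out.shape)  # torch.Size([4, 32000])
```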
@@ -21,7 +21,9 @@ from dataclasses import dataclass
 from typing import List, Optional, Union
 
 import torch
+import torch.distributed as dist
 from flashinfer.sampling import top_k_top_p_sampling_from_probs
+from vllm.distributed import get_tensor_model_parallel_group
 
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.global_config import global_config
@@ -724,7 +726,7 @@ class ScheduleBatch:
         )
         self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
 
-    def sample(self, logits: torch.Tensor):
+    def sample(self, logits: torch.Tensor, is_multi_node_tp=False):
         # TODO(lsyin): move this into a part of layer and run with CUDA Graph
         # Post process logits
         logits = logits.contiguous()
@@ -779,6 +781,16 @@ class ScheduleBatch:
         self.penalizer_orchestrator.cumulate_output_tokens(batch_next_token_ids)
 
+        if is_multi_node_tp:
+            # If the tensor parallelism spans across multiple nodes, there is some indeterminism
+            # that can cause the TP workers to generate different tokens, so we need to
+            # sync here
+            torch.distributed.all_reduce(
+                batch_next_token_ids,
+                op=dist.ReduceOp.MIN,
+                group=get_tensor_model_parallel_group().device_group,
+            )
+
         return batch_next_token_ids
......
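The all_reduce above is the token-level safety net: with TP spanning nodes, floating-point non-determinism can make ranks sample different token ids, after which their KV caches and outputs silently diverge (or later collectives hang). Reducing with MIN forces every rank to adopt the same id. A standalone sketch of that sync, runnable with `torchrun --nproc_per_node=2` on the gloo backend; `sync_sampled_tokens` is an illustrative name, and SGLang performs this on the TP device group inside `ScheduleBatch.sample`:

```python
import torch
import torch.distributed as dist


def sync_sampled_tokens(batch_next_token_ids: torch.Tensor, group=None) -> torch.Tensor:
    # Element-wise MIN across ranks: any rank that sampled a different id
    # is overridden, so all ranks continue decoding with identical tokens.
    dist.all_reduce(batch_next_token_ids, op=dist.ReduceOp.MIN, group=group)
    return batch_next_token_ids


if __name__ == "__main__":
    dist.init_process_group("gloo")
    rank = dist.get_rank()
    # Pretend each rank sampled slightly different ids for a 3-sequence batch.
    tokens = torch.tensor([101 + rank, 202 + rank, 303 + rank], dtype=torch.int64)
    tokens = sync_sampled_tokens(tokens)
    print(f"rank {rank}: {tokens.tolist()}")  # same list on every rank
    dist.destroy_process_group()
```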
@@ -85,10 +85,6 @@ class ModelTpServer:
         self.schedule_policy = server_args.schedule_policy
         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
 
-        # Chunked prefill
-        self.chunked_prefill_size = server_args.chunked_prefill_size
-        self.current_inflight_req = None
-
         # Init model and tokenizer
         self.model_config = ModelConfig(
             server_args.model_path,
@@ -175,6 +171,10 @@ class ModelTpServer:
         self.num_generated_tokens = 0
         self.last_stats_tic = time.time()
 
+        # Chunked prefill
+        self.chunked_prefill_size = server_args.chunked_prefill_size
+        self.current_inflight_req = None
+
         # Init the FSM cache for constrained generation
         if not server_args.skip_tokenizer_init:
             self.regex_fsm_cache = FSMCache(
@@ -444,7 +444,9 @@ class ModelTpServer:
         # Forward and sample the next tokens
         if batch.extend_num_tokens != 0:
             output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            next_token_ids = batch.sample(output.next_token_logits)
+            next_token_ids = batch.sample(
+                output.next_token_logits, self.model_runner.is_multi_node_tp
+            )
 
             # Move logprobs to cpu
             if output.next_token_logprobs is not None:
@@ -603,7 +605,9 @@ class ModelTpServer:
         # Forward and sample the next tokens
         output = self.model_runner.forward(batch, ForwardMode.DECODE)
-        next_token_ids = batch.sample(output.next_token_logits)
+        next_token_ids = batch.sample(
+            output.next_token_logits, self.model_runner.is_multi_node_tp
+        )
 
         # Move logprobs to cpu
         if output.next_token_logprobs is not None:
......
@@ -142,7 +142,7 @@ class CudaGraphRunner:
             set_torch_compile_config()
 
     def can_run(self, batch_size):
-        return batch_size < self.max_bs
+        return batch_size <= self.max_bs
 
     def capture(self, batch_size_list):
         self.batch_size_list = batch_size_list
@@ -239,12 +239,23 @@ class CudaGraphRunner:
             return forward(input_ids, input_metadata.positions, input_metadata)
 
         for _ in range(2):
+            torch.cuda.synchronize()
+            self.model_runner.tp_group.barrier()
+
             run_once()
 
+            torch.cuda.synchronize()
+            self.model_runner.tp_group.barrier()
+
         torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
+
         with torch.cuda.graph(graph, pool=self.graph_memory_pool, stream=stream):
             out = run_once()
+
         torch.cuda.synchronize()
+        self.model_runner.tp_group.barrier()
 
         self.graph_memory_pool = graph.pool()
         return graph, None, out, flashinfer_decode_wrapper
@@ -278,7 +289,9 @@ class CudaGraphRunner:
         )
 
         # Replay
+        torch.cuda.synchronize()
         self.graphs[bs].replay()
+        torch.cuda.synchronize()
 
         output = self.output_buffers[bs]
         # Unpad
......
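This is the core of the deadlock fix: CUDA graph capture records the NCCL collectives issued by `run_once()`, so if one TP rank is still in a warm-up iteration while another has already entered `torch.cuda.graph(...)`, the ranks post mismatched collectives and hang. Bracketing each phase with a device synchronize plus a TP-group barrier keeps all ranks in lockstep, and the `<=` change in `can_run` above is a separate off-by-one fix so a batch exactly at `max_bs` still uses the captured graph. A condensed sketch of the pattern, with `dist.barrier(group=...)` standing in for `self.model_runner.tp_group.barrier()`, assuming an initialized process group and a CUDA device:

```python
import torch
import torch.distributed as dist


def lockstep(group=None):
    # Drain this rank's GPU work, then wait for every other TP rank,
    # so no rank starts the next phase (warm-up or capture) early.
    torch.cuda.synchronize()
    dist.barrier(group=group)


def capture_in_lockstep(run_once, group=None, pool=None):
    graph = torch.cuda.CUDAGraph()
    for _ in range(2):               # warm-up passes
        lockstep(group)
        run_once()
        lockstep(group)
    lockstep(group)                  # all ranks done warming up
    with torch.cuda.graph(graph, pool=pool):
        out = run_once()             # collectives are recorded here
    lockstep(group)                  # all ranks finished capturing
    return graph, out
```

Replay gets the same treatment in the hunk above: a `torch.cuda.synchronize()` before and after `self.graphs[bs].replay()` so the replayed collectives line up across ranks.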
@@ -38,6 +38,7 @@ from vllm.distributed import (
     init_distributed_environment,
     initialize_model_parallel,
 )
+from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
@@ -112,10 +113,13 @@ class ModelRunner:
             distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+        self.tp_group = get_tp_group()
         total_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
-        self.tp_group = get_tp_group()
+        self.is_multi_node_tp = not all(
+            in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)
+        )
 
         if self.tp_size > 1:
             total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
......
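vllm's `in_the_same_node_as(cpu_group, source_rank=0)` returns one boolean per rank indicating whether that rank lives on the same host as `source_rank`, so `not all(...)` flags a TP group that spans nodes. For the same signal without that helper, a hedged alternative is to all-gather hostnames over a CPU-backed group (`is_multi_node` is an illustrative name; an initialized process group is assumed):

```python
import socket

import torch.distributed as dist


def is_multi_node(cpu_group=None) -> bool:
    # Gather every rank's hostname; more than one distinct host means the
    # group (and therefore tensor parallelism) spans multiple nodes.
    world_size = dist.get_world_size(group=cpu_group)
    hostnames = [None] * world_size
    dist.all_gather_object(hostnames, socket.gethostname(), group=cpu_group)
    return len(set(hostnames)) > 1
```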
@@ -295,8 +295,9 @@ class Grok1ModelForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = Grok1Model(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.logits_processor = LogitsProcessor(config)
+        # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ReplicatedLinear(config.hidden_size, config.vocab_size)
+        self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
 
         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
......
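Design note on the Grok-1 change: `ReplicatedLinear(config.hidden_size, config.vocab_size)` gives every TP rank the full hidden-to-vocab projection instead of a vocab-sharded `ParallelLMHead`, which is why the model can construct its `LogitsProcessor` with `skip_all_gather=True`; the trade-off is that each rank stores and multiplies the whole head. A tiny sketch of the pairing, with a plain `nn.Linear` standing in for `ReplicatedLinear` and a hypothetical `TinyHead` module:

```python
import torch
import torch.nn as nn


class TinyHead(nn.Module):
    """Replicated LM head: every rank holds the full [vocab_size, hidden_size] weight."""

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Full-vocab logits are produced locally, so no cross-rank
        # tensor_model_parallel_all_gather is needed downstream.
        return self.lm_head(hidden_states).float()


head = TinyHead(hidden_size=16, vocab_size=32000)
print(head(torch.randn(2, 16)).shape)  # torch.Size([2, 32000])
```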