Commit 12291212 authored by maxiao1's avatar maxiao1 Committed by lizhigong
Browse files

pd分离_tbo

parent 3daae57c
......@@ -414,10 +414,10 @@ def unified_attention(
output = self.impl.forward(self, query, key, value, kv_cache,
attn_metadata)
if envs.VLLM_ENABLE_TBO:
tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
else:
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
# if envs.VLLM_ENABLE_TBO:
# tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
# else:
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return output
......@@ -462,10 +462,10 @@ def unified_attention_with_output(
attn_metadata,
output=output,
output_scale=output_scale)
if envs.VLLM_ENABLE_TBO:
tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
else:
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
# if envs.VLLM_ENABLE_TBO:
# tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
# else:
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
def unified_attention_with_output_fake(
......
......@@ -75,8 +75,11 @@ class SiluAndMul(CustomOp):
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
if not torch.compiler.is_compiling(): # 非 capture 阶段
return self.forward_cuda(x) # 强制走 fused kernel
else:
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
......
......@@ -165,38 +165,40 @@ class RMSNorm(CustomOp):
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)
hidden_size = x.shape[-1]
if hidden_size != self.hidden_size:
raise ValueError("Expected hidden_size to be "
f"{self.hidden_size}, but found: {hidden_size}")
if self.variance_size_override is None:
x_var = x
else:
if hidden_size < self.variance_size_override:
raise ValueError(
"Expected hidden_size to be at least "
f"{self.variance_size_override}, but found: {hidden_size}")
x_var = x[:, :, :self.variance_size_override]
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype)
if self.has_weight:
x = x * self.weight
if residual is None:
return x
if not torch.compiler.is_compiling(): # 非 capture 阶段
return self.forward_cuda(x, residual) # 强制走 fused kernel
else:
return x, residual
# 否则fallback到原始实现
orig_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
x = x + residual.to(torch.float32)
residual = x.to(orig_dtype)
hidden_size = x.shape[-1]
if hidden_size != self.hidden_size:
raise ValueError("Expected hidden_size to be "
f"{self.hidden_size}, but found: {hidden_size}")
if self.variance_size_override is None:
x_var = x
else:
if hidden_size < self.variance_size_override:
raise ValueError(
"Expected hidden_size to be at least "
f"{self.variance_size_override}, but found: {hidden_size}")
x_var = x[:, :, :self.variance_size_override]
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + self.variance_epsilon)
x = x.to(orig_dtype)
if self.has_weight:
x = x * self.weight
if residual is None:
return x
else:
return x, residual
def forward_cuda(
self,
......
......@@ -17,10 +17,12 @@ logger = init_logger(__name__)
tbo_step_stream = None
all_reduce_stream = None
class TwoBatchOverlap():
PERSIST_THREADS = os.getenv('VLLM_TBO_PERSIST_THREADS', '1') not in ('0','false','False','no','NO','')
STOP = object()
class TwoBatchOverlap:
def __init__(self):
global tbo_step_stream
global all_reduce_stream
global tbo_step_stream, all_reduce_stream
self.model_input_left_queue = queue.Queue()
self.model_input_right_queue = queue.Queue()
self.states_left_queue = queue.Queue()
......@@ -29,12 +31,14 @@ class TwoBatchOverlap():
self.right_thread = None
self.left_tid = 0
self.right_tid = 0
self._stop_evt = threading.Event()
self._threads_started = False
self.sem_left = threading.Semaphore(0)
self.sem_right = threading.Semaphore(0)
self.left_first = False
self.tbo_running = False
self.tbo_in_capture = False
if tbo_step_stream == None:
if tbo_step_stream is None:
tbo_step_stream = torch.cuda.Stream()
all_reduce_stream = torch.cuda.Stream()
self.step_event = torch.cuda.Event(enable_timing=False)
......@@ -44,60 +48,85 @@ class TwoBatchOverlap():
self.event_right_t2c = torch.cuda.Event(enable_timing=False)
def init_tbo_thread(self):
self.model_input_left_queue.empty()
self.model_input_right_queue.empty()
self.left_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_left_queue,))
self.left_thread.start()
self.right_thread = threading.Thread(target=self.thread_two_batch_overlap, args=(self.model_input_right_queue,))
self.right_thread.start()
if get_tp_group().rank == 0:
logger.info('tbo:two batch overlap start')
def finish_thread(self):
self.left_thread.join()
self.left_thread = None
self.right_thread.join()
self.right_thread = None
if self._threads_started and PERSIST_THREADS:
return
if self.left_thread is None or not self.left_thread.is_alive():
self.left_thread = threading.Thread(target=self.thread_two_batch_overlap,
args=(self.model_input_left_queue,), daemon=True)
self.left_thread.start()
if self.right_thread is None or not self.right_thread.is_alive():
self.right_thread = threading.Thread(target=self.thread_two_batch_overlap,
args=(self.model_input_right_queue,), daemon=True)
self.right_thread.start()
self._threads_started = True
def shutdown(self, timeout=5.0):
self._stop_evt.set()
try:
self.model_input_left_queue.put(STOP)
self.model_input_right_queue.put(STOP)
except Exception:
pass
if self.left_thread is not None:
self.left_thread.join(timeout=timeout)
self.left_thread = None
if self.right_thread is not None:
self.right_thread.join(timeout=timeout)
self.right_thread = None
@torch.inference_mode()
def thread_two_batch_overlap(self, queue):
def thread_two_batch_overlap(self, q):
is_left_thread = False
tid = threading.get_ident()
if queue == self.model_input_left_queue:
if q is self.model_input_left_queue:
self.left_tid = tid
is_left_thread = True
init_tbo_forward_context(True, self.left_tid)
else:
self.right_tid = tid
init_tbo_forward_context(False, self.right_tid)
with torch.cuda.stream(tbo_step_stream):
queue.get()
self.tbo_thread_synchronize(tid)
if is_left_thread:
attn_metadata = self.attn_metadata_left
num_input_tokens = self.num_input_tokens_left
input_ids = self.input_ids_left
positions = self.positions_left
else:
attn_metadata = self.attn_metadata_right
num_input_tokens = self.num_input_tokens_right
input_ids = self.input_ids_right
positions = self.positions_right
model_output = None
# Run the decoder.
# Use persistent buffers for CUDA graphs.
with set_forward_context(attn_metadata,
self.model_runner.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=self.num_tokens_across_dp,
skip_cuda_graphs=True):
model_output = self.model_runner.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=self.intermediate_tensors,
inputs_embeds=self.inputs_embeds,
)
while not self._stop_evt.is_set():
item = q.get()
if item is STOP:
break
with torch.cuda.stream(tbo_step_stream):
self.tbo_thread_synchronize(tid)
if is_left_thread:
attn_metadata = self.attn_metadata_left
num_input_tokens = self.num_input_tokens_left
input_ids = self.input_ids_left
positions = self.positions_left
else:
attn_metadata = self.attn_metadata_right
num_input_tokens = self.num_input_tokens_right
input_ids = self.input_ids_right
positions = self.positions_right
# Select per-thread tensors (left/right) with backward-compatible fallback
if is_left_thread:
intermediate_tensors = getattr(self, 'intermediate_tensors_left', None)
else:
intermediate_tensors = getattr(self, 'intermediate_tensors_right', None)
if intermediate_tensors is None:
intermediate_tensors = getattr(self, 'intermediate_tensors_left', None)
with set_forward_context(attn_metadata,
self.model_runner.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=self.num_tokens_across_dp,
skip_cuda_graphs=True,
):
model_output = self.model_runner.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=self.inputs_embeds,
)
if is_left_thread:
self.sem_right.release()
self.states_left_queue.put(model_output)
......@@ -117,18 +146,19 @@ class TwoBatchOverlap():
return self.event_right_c2t, self.event_right_t2c
def set_model_input(self,
model_runner,
attn_metadata_left,
attn_metadata_right,
model_runner,
attn_metadata_left,
attn_metadata_right,
num_input_tokens_left,
num_input_tokens_right,
input_ids_left,
input_ids_right,
positions_left,
input_ids_left,
input_ids_right,
positions_left,
positions_right,
num_tokens_across_dp,
intermediate_tensors,
inputs_embeds):
inputs_embeds,
):
self.model_runner = model_runner
self.attn_metadata_left = attn_metadata_left
self.attn_metadata_right = attn_metadata_right
......@@ -139,26 +169,34 @@ class TwoBatchOverlap():
self.positions_left = positions_left
self.positions_right = positions_right
self.num_tokens_across_dp = num_tokens_across_dp
self.intermediate_tensors = intermediate_tensors
self.inputs_embeds = inputs_embeds
if isinstance(intermediate_tensors, tuple):
self.intermediate_tensors_left, self.intermediate_tensors_right = intermediate_tensors
else:
self.intermediate_tensors_left = intermediate_tensors
self.intermediate_tensors_right = None
self.model_input_left_queue.put(None)
self.model_input_right_queue.put(None)
def get_model_output(self):
states_left = self.states_left_queue.get()
states_right = self.states_right_queue.get()
return states_left, states_right
tbo_obj_v1 = None
def is_enable_tbo_v1():
global tbo_obj_v1
return tbo_obj_v1 != None
return tbo_obj_v1 is not None
def init_two_batch_overlap():
global tbo_obj_v1
if tbo_obj_v1 == None:
if tbo_obj_v1 is None:
tbo_obj_v1 = TwoBatchOverlap()
tbo_obj_v1.init_tbo_thread()
......@@ -171,7 +209,7 @@ def tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache):
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
def tbo_all_reduce_v1(obj):
if envs.VLLM_ENABLE_TBO and tbo_obj_v1 != None and tbo_obj_v1.tbo_running:
if envs.VLLM_ENABLE_TBO and tbo_obj_v1 is not None and tbo_obj_v1.tbo_running:
tid = threading.get_ident()
if tid == tbo_obj_v1.left_tid:
event_c2t, event_t2c = tbo_obj_v1.event_left_c2t, tbo_obj_v1.event_left_t2c
......@@ -185,7 +223,7 @@ def tbo_all_reduce_v1(obj):
tbo_obj_v1.tbo_thread_synchronize(tid)
tbo_step_stream.wait_event(event_t2c)
return output
return tensor_model_parallel_all_reduce(obj)
return tensor_model_parallel_all_reduce(obj)
def merge_model_output(states_left, states_right):
if isinstance(states_left, IntermediateTensors):
......@@ -199,45 +237,53 @@ def merge_model_output(states_left, states_right):
def tbo_model_executable_v1(
model_runner,
attn_metadata_left,
attn_metadata_right,
attn_metadata_left,
attn_metadata_right,
num_input_tokens_left,
num_input_tokens_right,
num_tokens_across_dp,
input_ids,
positions,
intermediate_tensors,
inputs_embeds
inputs_embeds,
):
init_two_batch_overlap()
tbo_obj_v1.tbo_running = True
tbo_obj_v1.left_first = True
tbo_obj_v1.step_event.record()
current_stream = torch.cuda.current_stream()
num_total_tokens = num_input_tokens_left + num_input_tokens_right
with torch.cuda.stream(tbo_step_stream):
tbo_step_stream.wait_event(tbo_obj_v1.step_event)
tokens_split = [num_input_tokens_left, num_input_tokens_right]
input_ids_left, input_ids_right = torch.split(input_ids, tokens_split, dim=0)
positions_left, positions_right = torch.split(positions, tokens_split, dim=0)
tbo_obj_v1.set_model_input(model_runner,
attn_metadata_left,
attn_metadata_right,
num_input_tokens_left,
num_input_tokens_right,
input_ids_left,
input_ids_right,
positions_left,
positions_right,
num_tokens_across_dp,
intermediate_tensors,
inputs_embeds)
input_ids_left, input_ids_right = torch.split(input_ids[:num_total_tokens], tokens_split, dim=0)
positions_left, positions_right = torch.split(positions[:num_total_tokens], tokens_split, dim=0)
tbo_obj_v1.set_model_input(model_runner,
attn_metadata_left,
attn_metadata_right,
num_input_tokens_left,
num_input_tokens_right,
input_ids_left,
input_ids_right,
positions_left,
positions_right,
num_tokens_across_dp,
intermediate_tensors,
inputs_embeds,
)
model_output_left, model_output_right = tbo_obj_v1.get_model_output()
hidden_or_intermediate_states = merge_model_output(model_output_left, model_output_right)
tbo_obj_v1.tbo_running = False
tbo_obj_v1.step_event.record()
tbo_obj_v1.finish_thread()
current_stream.wait_event(tbo_obj_v1.step_event)
return hidden_or_intermediate_states
\ No newline at end of file
return hidden_or_intermediate_states
def finalize_two_batch_overlap():
global tbo_obj_v1
if tbo_obj_v1 is not None:
try:
tbo_obj_v1.shutdown()
finally:
tbo_obj_v1 = None
......@@ -1374,8 +1374,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# If attention doesn't support CUDA Graphs for this batch, but we
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs = self.full_cuda_graph and not attention_cuda_graphs
if envs.VLLM_ENABLE_TBO and (not self.use_cuda_graph or skip_cuda_graphs):
if envs.VLLM_ENABLE_TBO and scheduler_output.total_num_scheduled_tokens >= envs.VLLM_TBO_MIN_TOKENS:
model_output, finished_sending, finished_recving = \
tbo_split_and_execute_model(self, attn_metadata, num_input_tokens,
num_tokens_across_dp, input_ids, positions,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment