Unverified Commit 300acb83 authored by Yan Burman's avatar Yan Burman Committed by GitHub
Browse files

[Core][Bugfix] Use correct device to initialize GPU data during CUDA-graph-capture (#11233)


Signed-off-by: default avatarYan Burman <yanburman@users.noreply.github.com>
Signed-off-by: default avatarIdo Asraff <idoa@atero.ai>
parent d91457d5
...@@ -50,7 +50,7 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port): ...@@ -50,7 +50,7 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
for sz in test_sizes: for sz in test_sizes:
for dtype in [torch.float32, torch.float16, torch.bfloat16]: for dtype in [torch.float32, torch.float16, torch.bfloat16]:
with graph_capture() as graph_capture_context: with graph_capture(device=device) as graph_capture_context:
# use integers so result matches NCCL exactly # use integers so result matches NCCL exactly
inp1 = torch.randint(1, inp1 = torch.randint(1,
16, (sz, ), 16, (sz, ),
......
...@@ -107,7 +107,7 @@ def multiple_allreduce_with_vllm_worker_fn(): ...@@ -107,7 +107,7 @@ def multiple_allreduce_with_vllm_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}") device = torch.device(f"cuda:{torch.distributed.get_rank()}")
ensure_model_parallel_initialized(2, 2) ensure_model_parallel_initialized(2, 2)
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device) tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
with graph_capture(): with graph_capture(device=device):
# two tp groups can communicate independently # two tp groups can communicate independently
if torch.distributed.get_rank() in [0, 1]: if torch.distributed.get_rank() in [0, 1]:
tensor = tensor_model_parallel_all_reduce(tensor) tensor = tensor_model_parallel_all_reduce(tensor)
......
...@@ -920,7 +920,7 @@ def get_kv_transfer_group() -> kv_transfer.KVTransferAgent: ...@@ -920,7 +920,7 @@ def get_kv_transfer_group() -> kv_transfer.KVTransferAgent:
@contextmanager @contextmanager
def graph_capture(): def graph_capture(device: torch.device):
""" """
`graph_capture` is a context manager which should surround the code that `graph_capture` is a context manager which should surround the code that
is capturing the CUDA graph. Its main purpose is to ensure that the is capturing the CUDA graph. Its main purpose is to ensure that the
...@@ -934,8 +934,9 @@ def graph_capture(): ...@@ -934,8 +934,9 @@ def graph_capture():
in order to explicitly distinguish the kernels to capture in order to explicitly distinguish the kernels to capture
from other kernels possibly launched on background in the default stream. from other kernels possibly launched on background in the default stream.
""" """
with get_tp_group().graph_capture() as context, get_pp_group( context = GraphCaptureContext(torch.cuda.Stream(device=device))
).graph_capture(context): with get_tp_group().graph_capture(context), get_pp_group().graph_capture(
context):
yield context yield context
......
...@@ -836,7 +836,7 @@ class GPUModelRunner: ...@@ -836,7 +836,7 @@ class GPUModelRunner:
# Trigger CUDA graph capture for specific shapes. # Trigger CUDA graph capture for specific shapes.
# Capture the large shapes first so that the smaller shapes # Capture the large shapes first so that the smaller shapes
# can reuse the memory pool allocated for the large shapes. # can reuse the memory pool allocated for the large shapes.
with graph_capture(): with graph_capture(device=self.device):
for num_tokens in reversed(self.cudagraph_batch_sizes): for num_tokens in reversed(self.cudagraph_batch_sizes):
for _ in range(self.vllm_config.compilation_config. for _ in range(self.vllm_config.compilation_config.
cudagraph_num_of_warmups): cudagraph_num_of_warmups):
......
...@@ -1426,10 +1426,15 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1426,10 +1426,15 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
# Prepare dummy inputs. These will be reused for all batch sizes. # Prepare dummy inputs. These will be reused for all batch sizes.
max_batch_size = self.max_batchsize_to_capture max_batch_size = self.max_batchsize_to_capture
input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() input_tokens = torch.zeros(max_batch_size,
input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() dtype=torch.long,
device=self.device)
input_positions = torch.zeros(max_batch_size,
dtype=torch.long,
device=self.device)
if self.model_config.uses_mrope: if self.model_config.uses_mrope:
input_positions = torch.tile(input_positions, (3, 1)) input_positions = torch.tile(input_positions,
(3, 1)).cuda(device=self.device)
# Prepare dummy previous_hidden_states only if needed by the model. # Prepare dummy previous_hidden_states only if needed by the model.
# This is used by draft models such as EAGLE. # This is used by draft models such as EAGLE.
previous_hidden_states = None previous_hidden_states = None
...@@ -1448,8 +1453,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1448,8 +1453,8 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
dtype=self.model_config.dtype, dtype=self.model_config.dtype,
device=self.device) device=self.device)
with self.attn_state.graph_capture( with self.attn_state.graph_capture(max_batch_size), graph_capture(
max_batch_size), graph_capture() as graph_capture_context: self.device) as graph_capture_context:
# NOTE: Capturing the largest batch size first may help reduce the # NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph. # memory usage of CUDA graph.
for virtual_engine in range( for virtual_engine in range(
...@@ -1549,10 +1554,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]): ...@@ -1549,10 +1554,12 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
""" """
# During the decode phase encoder_input_ids and encoder_positions are # During the decode phase encoder_input_ids and encoder_positions are
# unset. Do the same thing for graph capture. # unset. Do the same thing for graph capture.
capture_inputs["encoder_input_ids"] = torch.tensor( capture_inputs["encoder_input_ids"] = torch.tensor([],
[], dtype=torch.long).cuda() dtype=torch.long,
capture_inputs["encoder_positions"] = torch.tensor( device=self.device)
[], dtype=torch.long).cuda() capture_inputs["encoder_positions"] = torch.tensor([],
dtype=torch.long,
device=self.device)
@property @property
def vocab_size(self) -> int: def vocab_size(self) -> int:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment