Commit 7e1d5e53 authored by zhuwenwen

merge v0.3.1

parents e3378b20 5f08050d
@@ -104,11 +104,13 @@ class CacheEngine:
                 size=(self.num_cpu_blocks, *key_block_shape),
                 dtype=self.dtype,
                 pin_memory=pin_memory,
+                device="cpu",
             )
             value_blocks = torch.empty(
                 size=(self.num_cpu_blocks, *value_block_shape),
                 dtype=self.dtype,
                 pin_memory=pin_memory,
+                device="cpu",
             )
             cpu_cache.append((key_blocks, value_blocks))
         return cpu_cache
...
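The only change in this hunk is the explicit device="cpu" argument for the swap-cache blocks. A minimal, self-contained sketch of the same pinned-CPU allocation pattern (the helper name and block shape below are made up for illustration, not part of the commit):

    import torch

    def allocate_pinned_cpu_blocks(num_blocks: int,
                                   block_shape: tuple,
                                   dtype: torch.dtype = torch.float16) -> torch.Tensor:
        # Pinned (page-locked) host memory enables fast, asynchronous copies to
        # the GPU; device="cpu" makes the placement explicit instead of relying
        # on a default device context.
        return torch.empty(size=(num_blocks, *block_shape),
                           dtype=dtype,
                           pin_memory=torch.cuda.is_available(),
                           device="cpu")

    # Example: 16 swap blocks with a made-up (16, 8, 128) block shape.
    cpu_key_blocks = allocate_pinned_cpu_blocks(16, (16, 8, 128))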
+import contextlib
 import time
 from typing import Dict, List, Optional, Tuple, Set, Union
@@ -5,11 +6,15 @@ import numpy as np
 import torch
 import torch.nn as nn

-from vllm.config import ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig
+from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig,
+                         SchedulerConfig)
 from vllm.logger import init_logger
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
+from vllm.model_executor.parallel_utils import cupy_utils
 from vllm.model_executor.parallel_utils.communication_op import (
     broadcast_tensor_dict)
+from vllm.model_executor.parallel_utils.parallel_state import (
+    with_cupy_nccl_for_all_reduce)
 from vllm.model_executor.parallel_utils import custom_all_reduce
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
@@ -35,6 +40,7 @@ class ModelRunner:
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
         kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
@@ -49,7 +55,10 @@ class ModelRunner:
         # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
         self.sliding_window = (model_config.get_sliding_window()
                                if model_config is not None else None)
-        self.device = torch.device(torch.cuda.current_device())
+        self.device_config = (device_config
+                              if device_config is not None else DeviceConfig())
+        self.device = self.device_config.device
         self.model = None
         self.block_size = None  # Set after initial profiling.
         self.lora_manager = None
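In this hunk ModelRunner starts deriving its device from a DeviceConfig (falling back to a default one) instead of calling torch.cuda.current_device() directly. A rough sketch of that fallback pattern using simplified stand-in classes, not vLLM's real config objects:

    from typing import Optional
    import torch

    class FakeDeviceConfig:
        # Hypothetical stand-in for vllm.config.DeviceConfig.
        def __init__(self, device: str = "cpu"):
            self.device = torch.device(device)

    class FakeRunner:
        def __init__(self, device_config: Optional[FakeDeviceConfig] = None):
            # Fall back to a default config so older call sites that do not
            # pass device_config keep working.
            self.device_config = (device_config
                                  if device_config is not None
                                  else FakeDeviceConfig())
            self.device = self.device_config.device

    print(FakeRunner().device)                          # cpu (default)
    print(FakeRunner(FakeDeviceConfig("cuda")).device)  # cuda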
@@ -72,16 +81,26 @@ class ModelRunner:
         self.kv_cache_dtype = kv_cache_dtype

     def load_model(self) -> None:
-        self.model = get_model(self.model_config, self.lora_config)
+        self.model = get_model(self.model_config, self.device_config,
+                               self.lora_config)
         vocab_size = self.model.config.vocab_size

         if self.lora_config:
+            assert hasattr(
+                self.model, "supported_lora_modules"
+            ) and self.model.supported_lora_modules, "Model does not support LoRA"
+            assert hasattr(
+                self.model,
+                "embedding_modules"), "Model does not have embedding_modules"
+            assert hasattr(self.model, "embedding_padding_modules"
+                           ), "Model does not have embedding_padding_modules"
             self.lora_manager = LRUCacheWorkerLoRAManager(
                 self.scheduler_config.max_num_seqs,
                 self.scheduler_config.max_num_batched_tokens +
                 self.scheduler_config.max_paddings, vocab_size,
-                self.lora_config, self.device)
+                self.lora_config, self.device, self.model.embedding_modules,
+                self.model.embedding_padding_modules)
             self.model = self.lora_manager.create_lora_manager(self.model)

     def set_block_size(self, block_size: int) -> None:
@@ -142,10 +161,10 @@ class ModelRunner:
             if lora_id > 0:
                 lora_requests.add(seq_group_metadata.lora_request)

-            lora_index_mapping.append([lora_id] * prompt_len)
+            lora_index_mapping.append([lora_id] * (prompt_len - prefix_len))
             lora_prompt_mapping.extend(
                 [lora_id] *
-                (prompt_len
+                (prompt_len - prefix_len
                  if seq_group_metadata.sampling_params.prompt_logprobs else 1))

             if seq_group_metadata.block_tables is None:
@@ -182,22 +201,25 @@ class ModelRunner:
         input_tokens = _make_tensor_with_pad(input_tokens,
                                              max_prompt_len,
                                              pad=0,
-                                             dtype=torch.long)
+                                             dtype=torch.long,
+                                             device=self.device)
         input_positions = _make_tensor_with_pad(input_positions,
                                                 max_prompt_len,
                                                 pad=0,
-                                                dtype=torch.long)
+                                                dtype=torch.long,
+                                                device=self.device)
         slot_mapping = _make_tensor_with_pad(slot_mapping,
                                              max_prompt_len,
                                              pad=_PAD_SLOT_ID,
-                                             dtype=torch.long)
+                                             dtype=torch.long,
+                                             device=self.device)
         lora_index_mapping = [
             _pad_to_max(mapping, max_prompt_len, pad=0)
             for mapping in lora_index_mapping
         ]
         context_lens_tensor = torch.tensor(context_lens,
                                            dtype=torch.int,
-                                           device='cuda')
+                                           device=self.device)
         # Prepare prefix block tables
         max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
         block_tables = _make_tensor_with_pad(
@@ -205,15 +227,16 @@ class ModelRunner:
             max_len=max_prompt_block_table_len,
             pad=0,
             dtype=torch.int,
+            device=self.device,
         )
         start_loc_tensor = torch.arange(0,
                                         len(prompt_lens) * max_prompt_len,
                                         max_prompt_len,
                                         dtype=torch.long,
-                                        device='cuda')
+                                        device=self.device)
         prompt_lens_tensor = torch.tensor(prompt_lens,
                                           dtype=torch.long,
-                                          device='cuda')
+                                          device=self.device)

         input_metadata = InputMetadata(
             is_prompt=True,
@@ -305,20 +328,20 @@ class ModelRunner:
                                              max_len=1,
                                              pad=0,
                                              dtype=torch.long,
-                                             device="cuda")
+                                             device=self.device)
         input_positions = _make_tensor_with_pad(input_positions,
                                                 max_len=1,
                                                 pad=0,
                                                 dtype=torch.long,
-                                                device="cuda")
+                                                device=self.device)
         slot_mapping = _make_tensor_with_pad(slot_mapping,
                                              max_len=1,
                                              pad=_PAD_SLOT_ID,
                                              dtype=torch.long,
-                                             device="cuda")
+                                             device=self.device)
         context_lens = torch.tensor(context_lens,
                                     dtype=torch.int,
-                                    device="cuda")
+                                    device=self.device)

         if use_captured_graph:
             # The shape of graph_block_tables is
@@ -327,7 +350,7 @@ class ModelRunner:
             for i, block_table in enumerate(block_tables):
                 if block_table:
                     input_block_tables[i, :len(block_table)] = block_table
-            block_tables = torch.tensor(input_block_tables, device="cuda")
+            block_tables = torch.tensor(input_block_tables, device=self.device)
         else:
             max_block_table_len = max(
                 len(block_table) for block_table in block_tables)
@@ -336,7 +359,7 @@ class ModelRunner:
                 max_len=max_block_table_len,
                 pad=0,
                 dtype=torch.int,
-                device="cuda",
+                device=self.device,
             )

         lora_index_mapping = [
@@ -355,7 +378,8 @@ class ModelRunner:
             use_cuda_graph=use_captured_graph,
             kv_cache_dtype=self.kv_cache_dtype,
         )
-        return input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests
+        return (input_tokens, input_positions, input_metadata,
+                lora_index_mapping, lora_prompt_mapping, lora_requests)

     def _prepare_sample(
         self,
@@ -410,9 +434,13 @@ class ModelRunner:
         selected_token_indices = _async_h2d(selected_token_indices,
                                             dtype=torch.long,
+                                            target_device=self.device,
                                             pin_memory=not self.in_wsl)
         categorized_sample_indices = {
-            t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl)
+            t: _async_h2d(seq_ids,
+                          dtype=torch.int,
+                          target_device=self.device,
+                          pin_memory=not self.in_wsl)
             for t, seq_ids in categorized_sample_indices.items()
         }
@@ -511,7 +539,8 @@ class ModelRunner:
             perform_sampling=False,
         )

-        return input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, lora_mapping
+        return (input_tokens, input_positions, input_metadata,
+                sampling_metadata, lora_requests, lora_mapping)

     @torch.inference_mode()
     def execute_model(
@@ -519,8 +548,9 @@ class ModelRunner:
         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> Optional[SamplerOutput]:
-        input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, lora_mapping = (
-            self.prepare_input_tensors(seq_group_metadata_list))
+        (input_tokens, input_positions, input_metadata, sampling_metadata,
+         lora_requests,
+         lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list)

         if self.lora_config:
             self.set_active_loras(lora_requests, lora_mapping)
@@ -628,6 +658,10 @@ class ModelRunner:
     @torch.inference_mode()
     def capture_model(self, kv_caches: List[KVCache]) -> None:
+        # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never
+        # deleted before the CUDA graphs.
+        self.cupy_nccl_backend = cupy_utils.get_nccl_backend()
+
         assert not self.model_config.enforce_eager
         logger.info("Capturing the model for CUDA graphs. This may lead to "
                     "unexpected consequences if the model is not static. To "
@@ -656,9 +690,15 @@ class ModelRunner:
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]

-        # NOTE: Capturing the largest batch size first may help reduce the
-        # memory usage of CUDA graph.
+        # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
+        # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use
+        # either custom all-reduce kernel or CuPy NCCL. When not using CUDA
+        # graph, we use either custom all-reduce kernel or PyTorch NCCL.
+        # We always prioritize using custom all-reduce kernel but fall back
+        # to PyTorch or CuPy NCCL if it is disabled or not supported.
         with custom_all_reduce.capture():
+            # NOTE: Capturing the largest batch size first may help reduce the
+            # memory usage of CUDA graph.
             for batch_size in reversed(batch_size_capture_list):
                 # Create dummy input_metadata.
                 input_metadata = InputMetadata(
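The new NOTE(woosuk) comment above spells out the all-reduce backend priority: the custom all-reduce kernel when available, otherwise CuPy NCCL under CUDA graph capture and PyTorch NCCL outside it. A plain-Python sketch of that decision (a hypothetical helper written for illustration, not vLLM's API):

    def pick_all_reduce_backend(custom_ar_available: bool,
                                capturing_cuda_graph: bool) -> str:
        """Return the backend this sketch of the selection logic would use.

        custom_ar_available: whether the custom all-reduce kernel can be used
            (it may be disabled or unsupported for the current setup).
        capturing_cuda_graph: whether the call happens under CUDA graph capture.
        """
        if custom_ar_available:
            return "custom all-reduce kernel"
        # PyTorch NCCL is not used under graph capture here, so CuPy NCCL is
        # the fallback inside capture and PyTorch NCCL the fallback outside it.
        return "CuPy NCCL" if capturing_cuda_graph else "PyTorch NCCL"

    assert pick_all_reduce_backend(True, True) == "custom all-reduce kernel"
    assert pick_all_reduce_backend(False, True) == "CuPy NCCL"
    assert pick_all_reduce_backend(False, False) == "PyTorch NCCL"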
@@ -697,6 +737,14 @@ class ModelRunner:
         # This usually takes < 10 seconds.
         logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.")

+    def __del__(self) -> None:
+        # Delete the CUDA graphs before deleting the CuPy NCCL communicator.
+        # NOTE(woosuk): This is necessary because otherwise deadlocks can
+        # happen.
+        # FIXME(woosuk): This is a bit hacky. Find a more robust solution.
+        self.graph_runners.clear()
+        self.cupy_nccl_backend = None
+

 class CUDAGraphRunner:
@@ -718,18 +766,8 @@ class CUDAGraphRunner:
         # Run the model once without capturing the graph.
         # This is to make sure that the captured graph does not include the
         # kernel launches for initial benchmarking (e.g., Triton autotune).
-        self.model(
-            input_ids,
-            positions,
-            kv_caches,
-            input_metadata,
-        )
-        torch.cuda.synchronize()
-
-        # Capture the graph.
-        self.graph = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(self.graph, pool=memory_pool):
-            hidden_states = self.model(
+        with _maybe_cupy_nccl():
+            self.model(
                 input_ids,
                 positions,
                 kv_caches,
@@ -737,6 +775,20 @@
             )
         torch.cuda.synchronize()
+
+        # Capture the graph.
+        # NOTE(woosuk): Python 3.8 does not support multi-line with statements.
+        # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
+        self.graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(self.graph, pool=memory_pool):  # noqa: SIM117
+            with _maybe_cupy_nccl():
+                hidden_states = self.model(
+                    input_ids,
+                    positions,
+                    kv_caches,
+                    input_metadata,
+                )
+        torch.cuda.synchronize()

         # Save the input and output buffers.
         self.input_buffers = {
             "input_ids": input_ids,
@@ -779,6 +831,15 @@ class CUDAGraphRunner:
         return self.forward(*args, **kwargs)


+@contextlib.contextmanager
+def _maybe_cupy_nccl():
+    if cupy_utils.is_initialized() and not custom_all_reduce.is_initialized():
+        with with_cupy_nccl_for_all_reduce():
+            yield
+    else:
+        yield
+
+
 def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
     assert len(x) <= max_len
     return x + [pad] * (max_len - len(x))
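_pad_to_max right-pads a single sequence, and _make_tensor_with_pad (next hunk) stacks the padded rows into one tensor on the requested device. A tiny standalone example of the same batching-by-padding behaviour:

    from typing import List
    import torch

    def pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]:
        # Same right-padding idea as _pad_to_max above.
        assert len(x) <= max_len
        return x + [pad] * (max_len - len(x))

    rows = [[1, 2, 3], [4], [5, 6]]
    max_len = max(len(r) for r in rows)
    padded = torch.tensor([pad_to_max(r, max_len, pad=0) for r in rows],
                          dtype=torch.long, device="cpu")
    print(padded)
    # tensor([[1, 2, 3],
    #         [4, 0, 0],
    #         [5, 6, 0]])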
@@ -789,14 +850,10 @@ def _make_tensor_with_pad(
     max_len: int,
     pad: int,
     dtype: torch.dtype,
-    device: Union[str, torch.device] = "cuda",
-    pin_memory: bool = False,
+    device: Optional[Union[str, torch.device]],
 ) -> torch.Tensor:
     padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
-    return torch.tensor(padded_x,
-                        dtype=dtype,
-                        device=device,
-                        pin_memory=pin_memory and str(device) == "cpu")
+    return torch.tensor(padded_x, dtype=dtype, device=device)


 def _get_graph_batch_size(batch_size: int) -> int:
@@ -808,6 +865,11 @@ def _get_graph_batch_size(batch_size: int) -> int:
     return (batch_size + 7) // 8 * 8


-def _async_h2d(data: list, dtype, pin_memory):
-    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
-    return t.to(device="cuda", non_blocking=True)
+def _async_h2d(
+    data: list,
+    dtype: torch.dtype,
+    target_device: Union[str, torch.device],
+    pin_memory: bool,
+) -> torch.Tensor:
+    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
+    return t.to(device=target_device, non_blocking=True)
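The reworked _async_h2d stages data on the CPU, optionally in pinned memory, and then issues a non_blocking=True copy to the target device. A standalone sketch of that host-to-device pattern; the CUDA-availability guard is added here only so the example also runs on CPU-only machines and is not part of the diff:

    import torch

    def async_host_to_device(data: list, dtype: torch.dtype,
                             target_device: str) -> torch.Tensor:
        # Pinned (page-locked) host memory is what makes non_blocking=True an
        # actual asynchronous copy; without CUDA we skip pinning so the sketch
        # still runs on CPU-only machines.
        pin = torch.cuda.is_available() and target_device.startswith("cuda")
        t = torch.tensor(data, dtype=dtype, pin_memory=pin, device="cpu")
        return t.to(device=target_device, non_blocking=True)

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    indices = async_host_to_device([0, 3, 7, 9], torch.long, device)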
@@ -6,9 +6,10 @@ from typing import Dict, List, Tuple, Set, Optional
 import torch
 import torch.distributed

-from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
-                         SchedulerConfig, LoRAConfig)
+from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
+                         ParallelConfig, SchedulerConfig, LoRAConfig)
 from vllm.model_executor import set_random_seed
+from vllm.model_executor.parallel_utils import cupy_utils
 from vllm.model_executor.parallel_utils.communication_op import (
     broadcast_tensor_dict)
 from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar
@@ -18,6 +19,7 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.model_runner import ModelRunner
 from vllm.lora.request import LoRARequest
+from vllm.utils import is_hip


 class Worker:
@@ -33,6 +35,7 @@ class Worker:
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
         local_rank: int,
         rank: int,
         distributed_init_method: str,
@@ -43,6 +46,7 @@ class Worker:
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
+        self.device_config = device_config
         self.local_rank = local_rank
         self.rank = rank
         self.distributed_init_method = distributed_init_method
@@ -54,6 +58,7 @@ class Worker:
         self.model_runner = ModelRunner(model_config,
                                         parallel_config,
                                         scheduler_config,
+                                        device_config,
                                         lora_config=self.lora_config,
                                         kv_cache_dtype=kv_cache_dtype,
                                         is_driver_worker=is_driver_worker)
@@ -64,25 +69,30 @@ class Worker:
         self.cache_events = None
         self.gpu_cache = None

-    def init_model(self) -> None:
-        # torch.distributed.all_reduce does not free the input tensor until
-        # the synchronization point. This causes the memory usage to grow
-        # as the number of all_reduce calls increases. This env var disables
-        # this behavior.
-        # Related issue:
-        # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
-        os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-
-        # This env var set by Ray causes exceptions with graph building.
-        os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
-        self.device = torch.device(f"cuda:{self.local_rank}")
-        torch.cuda.set_device(self.device)
-
-        _check_if_gpu_supports_dtype(self.model_config.dtype)
+    def init_model(self, cupy_port: Optional[int] = None) -> None:
+        if self.device_config.device.type == "cuda":
+            # torch.distributed.all_reduce does not free the input tensor until
+            # the synchronization point. This causes the memory usage to grow
+            # as the number of all_reduce calls increases. This env var disables
+            # this behavior.
+            # Related issue:
+            # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
+            os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+
+            # This env var set by Ray causes exceptions with graph building.
+            os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
+            self.device = torch.device(f"cuda:{self.local_rank}")
+            torch.cuda.set_device(self.device)
+
+            _check_if_gpu_supports_dtype(self.model_config.dtype)
+            torch.cuda.empty_cache()
+            self.init_gpu_memory = torch.cuda.mem_get_info()[0]
+        else:
+            raise RuntimeError(
+                f"Not support device type: {self.device_config.device}")
         # Initialize the distributed environment.
         init_distributed_environment(self.parallel_config, self.rank,
-                                     self.distributed_init_method)
+                                     cupy_port, self.distributed_init_method)
         if not self.parallel_config.disable_custom_all_reduce:
             init_custom_ar()
         # Initialize the model.
@@ -119,7 +129,9 @@ class Worker:
         # profiled peak memory.
         torch.cuda.synchronize()
         free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
-        peak_memory = total_gpu_memory - free_gpu_memory
+        # NOTE(woosuk): Here we assume that the other processes using the same
+        # GPU did not change their memory usage during the profiling.
+        peak_memory = self.init_gpu_memory - free_gpu_memory

         cache_block_size = CacheEngine.get_cache_block_size(
             block_size, cache_dtype, self.model_config, self.parallel_config)
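With this change, peak memory is measured as the drop in free GPU memory since init_gpu_memory was recorded in init_model(), so memory held by other processes on the same GPU is not counted against this worker. A small worked example with made-up numbers; the final block-count formula is an assumption about how these quantities are typically combined and is not quoted from the diff:

    # Illustrative numbers (bytes); not measured values.
    init_free_gpu_memory = 70_000_000_000   # free memory recorded in init_model()
    free_gpu_memory = 50_000_000_000        # free memory after the profiling run
    total_gpu_memory = 80_000_000_000
    gpu_memory_utilization = 0.90
    cache_block_size = 2_000_000            # size of one KV-cache block

    # Peak usage of *this* worker: how much free memory it consumed since startup.
    peak_memory = init_free_gpu_memory - free_gpu_memory        # 20 GB

    # Remaining budget for the KV cache (assumed combination of the values).
    num_gpu_blocks = int(
        (total_gpu_memory * gpu_memory_utilization - peak_memory)
        // cache_block_size)
    print(num_gpu_blocks)  # 26000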
@@ -227,6 +239,7 @@ class Worker:
 def init_distributed_environment(
     parallel_config: ParallelConfig,
     rank: int,
+    cupy_port: Optional[int],
     distributed_init_method: Optional[str] = None,
 ) -> None:
     """Initialize the distributed environment."""
@@ -249,8 +262,29 @@ def init_distributed_environment(
             init_method=distributed_init_method,
         )

+    if cupy_utils.is_initialized():
+        cupy_world_size = cupy_utils.get_world_size()
+        if cupy_world_size != parallel_config.world_size:
+            raise RuntimeError(
+                "cupy.distributed is already initialized but the cupy world "
+                "size does not match parallel_config.world_size "
+                f"({cupy_world_size} vs. {parallel_config.world_size}).")
+    elif (parallel_config.world_size > 1 and cupy_port is not None
+          and not is_hip()):
+        # NOTE(woosuk): We don't initialize CuPy process group when world size
+        # is 1.
+        # TODO(woosuk): Support multi-node connection.
+        cupy_utils.init_process_group(
+            world_size=parallel_config.world_size,
+            rank=rank,
+            host="localhost",
+            port=cupy_port,
+        )
+
     # A small all_reduce for warmup.
     torch.distributed.all_reduce(torch.zeros(1).cuda())
+    if cupy_utils.is_initialized():
+        cupy_utils.all_reduce(torch.zeros(1).cuda())
     ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
                                       parallel_config.pipeline_parallel_size)
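The new CuPy NCCL group is only created for multi-GPU runs when a port is available and not on ROCm, and an already-initialized group must match the configured world size. A compact sketch of that guard with plain values (a hypothetical helper written for illustration, not vLLM's API):

    from typing import Optional

    def should_init_cupy_group(world_size: int,
                               cupy_port: Optional[int],
                               cupy_already_initialized: bool,
                               cupy_world_size: Optional[int],
                               on_rocm: bool) -> bool:
        """Decide whether to create a CuPy NCCL process group (sketch)."""
        if cupy_already_initialized:
            # An existing group must agree with the configured world size.
            if cupy_world_size != world_size:
                raise RuntimeError(
                    f"cupy world size mismatch: {cupy_world_size} vs. {world_size}")
            return False
        # Only initialize for multi-GPU runs, with a known port, and not on ROCm.
        return world_size > 1 and cupy_port is not None and not on_rocm

    assert should_init_cupy_group(2, 29500, False, None, False) is True
    assert should_init_cupy_group(1, 29500, False, None, False) is False
    assert should_init_cupy_group(2, None, False, None, False) is False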
...