Unverified commit 96b6f475, authored by Kunshang Ji, committed by GitHub

Remove hardcoded `device="cuda"` to support more devices (#2503)


Co-authored-by: Jiang Li <jiang1.li@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
parent c410f5d0
......@@ -200,7 +200,7 @@ def _make_alibi_bias(
seq_len: int,
dtype: torch.dtype,
) -> LowerTriangularMaskWithTensorBias:
bias = torch.arange(seq_len, dtype=dtype, device="cuda")
bias = torch.arange(seq_len, dtype=dtype)
# NOTE(zhuohan): HF uses
# `bias = bias[None, :].repeat(prompt_len, 1)`
# here. We find that both biases give the same results, but
......
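The changed line drops the hardcoded CUDA device, so the ALiBi bias tensor now follows whatever default device is active when the attention bias is built. A minimal sketch of that pattern (the `make_bias` helper is illustrative, not vLLM's API):

```python
import torch

# Without an explicit `device=`, the tensor lands on the current default
# device: CPU normally, or the device of an enclosing torch.device() context.
def make_bias(seq_len: int, dtype: torch.dtype) -> torch.Tensor:
    return torch.arange(seq_len, dtype=dtype)

cpu_bias = make_bias(8, torch.float32)           # CPU unless a context is set
if torch.cuda.is_available():
    with torch.device("cuda"):                   # PyTorch 2.0+ device context
        gpu_bias = make_bias(8, torch.float32)   # lands on the current CUDA device
```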
......@@ -54,7 +54,6 @@ class UnquantizedLinearMethod(LinearMethodBase):
params_dtype: torch.dtype) -> Dict[str, Any]:
weight = Parameter(torch.empty(output_size_per_partition,
input_size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype),
requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
......@@ -113,9 +112,7 @@ class ReplicatedLinear(torch.nn.Module):
self.register_parameter(name, weight)
if bias:
self.bias = Parameter(
torch.empty(self.output_size,
device=torch.cuda.current_device(),
dtype=self.params_dtype))
torch.empty(self.output_size, dtype=self.params_dtype))
set_weight_attrs(self.bias, {"output_dim": 0})
else:
self.register_parameter("bias", None)
......@@ -183,7 +180,6 @@ class ColumnParallelLinear(torch.nn.Module):
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
......@@ -509,9 +505,7 @@ class RowParallelLinear(torch.nn.Module):
if bias:
self.bias = Parameter(
torch.empty(self.output_size,
device=torch.cuda.current_device(),
dtype=params_dtype))
torch.empty(self.output_size, dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
......
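Across these linear layers, weights and biases are now created without `device=torch.cuda.current_device()`, deferring placement to whatever device context is active during model construction (see the `torch.device(...)` block in `get_model` further down). A hedged, self-contained sketch of the effect, using a made-up `TinyLinear` module:

```python
import torch
from torch.nn import Parameter

class TinyLinear(torch.nn.Module):
    # Illustrative module, not vLLM's: parameters carry no explicit device.
    def __init__(self, in_features: int, out_features: int,
                 params_dtype: torch.dtype = torch.float16):
        super().__init__()
        self.weight = Parameter(torch.empty(out_features, in_features,
                                            dtype=params_dtype),
                                requires_grad=False)

device = "cuda" if torch.cuda.is_available() else "cpu"
with torch.device(device):
    layer = TinyLinear(16, 32)
print(layer.weight.device)  # matches the device active at construction time
```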
......@@ -96,7 +96,6 @@ class AWQLinearMethod(LinearMethodBase):
torch.empty(
input_size_per_partition,
output_size_per_partition // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
......@@ -112,7 +111,6 @@ class AWQLinearMethod(LinearMethodBase):
torch.empty(
input_size_per_partition // self.quant_config.group_size,
output_size_per_partition // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
......@@ -128,7 +126,6 @@ class AWQLinearMethod(LinearMethodBase):
torch.empty(
input_size_per_partition // self.quant_config.group_size,
output_size_per_partition,
device="cuda",
dtype=params_dtype,
),
requires_grad=False,
......
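The AWQ tensors above keep their packed shapes and only lose the explicit device. A worked shape example under assumed settings (4-bit weights, group size 128; the real values come from the quantization config):

```python
import torch

weight_bits = 4
pack_factor = 32 // weight_bits      # 8 quantized values per int32 word
group_size = 128
input_size_per_partition = 4096      # assumed partition sizes
output_size_per_partition = 4096

qweight = torch.empty(input_size_per_partition,
                      output_size_per_partition // pack_factor,
                      dtype=torch.int32)          # 4096 x 512
qzeros = torch.empty(input_size_per_partition // group_size,
                     output_size_per_partition // pack_factor,
                     dtype=torch.int32)           # 32 x 512
scales = torch.empty(input_size_per_partition // group_size,
                     output_size_per_partition,
                     dtype=torch.float16)         # 32 x 4096
```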
......@@ -127,7 +127,6 @@ class GPTQLinearMethod(LinearMethodBase):
torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
......@@ -145,7 +144,6 @@ class GPTQLinearMethod(LinearMethodBase):
i // self.quant_config.group_size
for i in range(input_size_per_partition)
],
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
......@@ -156,7 +154,6 @@ class GPTQLinearMethod(LinearMethodBase):
torch.empty(
scale_and_zero_size,
output_size_per_partition // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
......@@ -172,7 +169,6 @@ class GPTQLinearMethod(LinearMethodBase):
torch.empty(
scale_and_zero_size,
output_size_per_partition,
device="cuda",
dtype=params_dtype,
),
requires_grad=False,
......
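The `g_idx` tensor built in the GPTQ hunk above maps each input feature to its quantization group. A small worked example with assumed sizes:

```python
import torch

group_size = 4                   # assumed for readability; real configs often use 128
input_size_per_partition = 12
g_idx = torch.tensor([i // group_size for i in range(input_size_per_partition)],
                     dtype=torch.int32)
print(g_idx)  # tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2], dtype=torch.int32)
```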
......@@ -80,7 +80,6 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
......@@ -96,7 +95,6 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
torch.empty(
output_size,
self.quant_config.weight_bits**2,
device="cuda",
dtype=params_dtype,
),
requires_grad=False,
......@@ -118,12 +116,12 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
out_shape = x.shape[:-1] + (qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
if is_hip():
out_f = torch.zeros(out_shape, device="cuda", dtype=torch.float)
out_f = torch.zeros(out_shape, dtype=torch.float)
ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table)
out = out_f.to(dtype=torch.float16)
else:
# NOTE: The output tensor should be zero-initialized.
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
out = torch.zeros(out_shape, dtype=torch.float16)
ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
if bias is not None:
......
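The NOTE about zero-initializing the output suggests the SqueezeLLM GEMM accumulates into the buffer, so any leftover garbage would leak into the result; removing `device="cuda"` only changes where that buffer is allocated. A toy illustration of the accumulate-into-output idea (a plain matmul stands in for the quantized kernel):

```python
import torch

x = torch.randn(4, 8)
w = torch.randn(8, 16)
out = torch.zeros(4, 16, dtype=x.dtype)  # must start at zero
out += x @ w                             # kernel-style accumulation into `out`
```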
......@@ -77,16 +77,13 @@ class RotaryEmbedding(nn.Module):
# create the cache on GPU for faster initialization. This may cause
# a slight numerical difference between the HF implementation and ours.
inv_freq = 1.0 / (base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
self.rotary_dim))
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings,
dtype=torch.float,
device="cuda")
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
......@@ -174,7 +171,7 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
# Thus, the maximum length after applying the rope scaling is
# self.max_position_embeddings * self.scaling_factor.
max_len = self.max_position_embeddings * self.scaling_factor
t = torch.arange(max_len, dtype=torch.float, device="cuda")
t = torch.arange(max_len, dtype=torch.float)
t = t / self.scaling_factor
freqs = torch.einsum("i,j -> ij", t, inv_freq)
......@@ -214,7 +211,7 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
(self.scaling_factor - 1))**(self.rotary_dim /
(self.rotary_dim - 2))
inv_freq = self._compute_inv_freq(base)
t = torch.arange(max_len, dtype=torch.float, device="cuda")
t = torch.arange(max_len, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
......@@ -297,9 +294,9 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
is_neox_style)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
pos_freqs = self.base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
self.rotary_dim)
pos_freqs = self.base**(
torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
self.rotary_dim)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
......@@ -308,8 +305,8 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
self.max_position_embeddings)
# Get n-d rotational scaling corrected for extrapolation
inv_freq_mask = (1 - _yarn_linear_ramp_mask(
low, high, self.rotary_dim // 2, dtype=torch.float,
device="cuda")) * self.extrapolation_factor
low, high, self.rotary_dim // 2,
dtype=torch.float)) * self.extrapolation_factor
inv_freq = inv_freq_interpolation * (
1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
return inv_freq
......@@ -317,7 +314,6 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
def _compute_cos_sin_cache(self) -> torch.Tensor:
inv_freq = self._compute_inv_freq(self.scaling_factor)
t = torch.arange(self.max_position_embeddings * self.scaling_factor,
device="cuda",
dtype=torch.float32)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = (freqs.cos() * self.mscale)
......
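All of the rotary-embedding caches above are now computed without a pinned device, so they are built wherever the module is constructed. A self-contained sketch of the base cos/sin cache computation mirroring the code shown (simplified, without the scaling variants):

```python
import torch

def compute_cos_sin_cache(rotary_dim: int, max_positions: int,
                          base: float = 10000.0):
    # Inverse frequencies for every second dimension of the rotary space.
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float)
                               / rotary_dim))
    t = torch.arange(max_positions, dtype=torch.float)
    freqs = torch.einsum("i,j -> ij", t, inv_freq)
    return freqs.cos(), freqs.sin()

cos, sin = compute_cos_sin_cache(rotary_dim=64, max_positions=2048)
print(cos.shape)  # torch.Size([2048, 32])
```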
......@@ -77,7 +77,6 @@ class VocabParallelEmbedding(torch.nn.Module):
self.weight = Parameter(
torch.empty(self.num_embeddings_per_partition,
self.embedding_dim,
device=torch.cuda.current_device(),
dtype=params_dtype))
set_weight_attrs(self.weight, {
"parallel_dim": 0,
......@@ -139,7 +138,6 @@ class ParallelLMHead(VocabParallelEmbedding):
if bias:
self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition,
device=torch.cuda.current_device(),
dtype=params_dtype))
set_weight_attrs(self.bias, {
"parallel_dim": 0,
......
......@@ -5,7 +5,7 @@ from typing import Optional, Type
import torch
import torch.nn as nn
from vllm.config import ModelConfig, LoRAConfig
from vllm.config import DeviceConfig, ModelConfig, LoRAConfig
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.weight_utils import (get_quant_config,
initialize_dummy_weights)
......@@ -38,6 +38,7 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
def get_model(model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig] = None) -> nn.Module:
model_class = _get_model_architecture(model_config)
......@@ -64,7 +65,7 @@ def get_model(model_config: ModelConfig,
with _set_default_torch_dtype(model_config.dtype):
# Create a model instance.
# The weights will be initialized as empty tensors.
with torch.device("cuda"):
with torch.device(device_config.device):
if getattr(model_class, "supports_lora", False):
model = model_class(model_config.hf_config, linear_method,
lora_config)
......
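This is the half of the change that makes the device-free parameter allocations above work: the whole model is instantiated inside a `torch.device(...)` context taken from `device_config`. A minimal sketch of the same pattern with a stock `nn.Linear` (`device_str` stands in for `device_config.device`):

```python
import torch
import torch.nn as nn

device_str = "cuda" if torch.cuda.is_available() else "cpu"
with torch.device(device_str):
    model = nn.Linear(16, 16)          # weights allocated on `device_str`
print(next(model.parameters()).device)
```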
......@@ -228,7 +228,8 @@ def create_kv_caches_with_random(
device: Optional[str] = "cuda",
) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
torch.random.manual_seed(seed)
torch.cuda.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
if isinstance(cache_dtype, str):
if cache_dtype == "auto":
......
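Seeding the CUDA RNG is now guarded so the utility also runs on machines without a GPU. The same guard, as a tiny standalone helper (the name is illustrative):

```python
import torch

def seed_everything(seed: int) -> None:
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)  # only touch CUDA RNGs when CUDA exists

seed_everything(0)
```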
......@@ -104,11 +104,13 @@ class CacheEngine:
size=(self.num_cpu_blocks, *key_block_shape),
dtype=self.dtype,
pin_memory=pin_memory,
device="cpu",
)
value_blocks = torch.empty(
size=(self.num_cpu_blocks, *value_block_shape),
dtype=self.dtype,
pin_memory=pin_memory,
device="cpu",
)
cpu_cache.append((key_blocks, value_blocks))
return cpu_cache
......
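The CPU-side KV blocks now carry an explicit `device="cpu"` alongside `pin_memory`, since pinned host memory speeds up later host-to-device swaps. A sketch with assumed block shapes:

```python
import torch

num_cpu_blocks, block_shape = 4, (16, 8, 128)    # assumed shapes
pin = torch.cuda.is_available()                  # pinning requires a CUDA runtime
key_blocks = torch.empty((num_cpu_blocks, *block_shape), dtype=torch.float16,
                         pin_memory=pin, device="cpu")
value_blocks = torch.empty((num_cpu_blocks, *block_shape), dtype=torch.float16,
                           pin_memory=pin, device="cpu")
```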
......@@ -5,7 +5,7 @@ import numpy as np
import torch
import torch.nn as nn
from vllm.config import ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig
from vllm.config import DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig
from vllm.logger import init_logger
from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
from vllm.model_executor.parallel_utils.communication_op import (
......@@ -35,6 +35,7 @@ class ModelRunner:
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
kv_cache_dtype: Optional[str] = "auto",
is_driver_worker: bool = False,
......@@ -49,7 +50,10 @@ class ModelRunner:
# FIXME(woosuk): This is a hack to make the tests work. Refactor this.
self.sliding_window = (model_config.get_sliding_window()
if model_config is not None else None)
self.device = torch.device(torch.cuda.current_device())
self.device_config = (device_config
if device_config is not None else DeviceConfig())
self.device = self.device_config.device
self.model = None
self.block_size = None # Set after initial profiling.
self.lora_manager = None
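ModelRunner now keeps a `device_config` (falling back to a default `DeviceConfig()`) and takes `self.device` from it instead of `torch.cuda.current_device()`. A hedged stand-in showing the only interface the diff relies on, a `.device` attribute holding a `torch.device`:

```python
import torch

class DeviceConfigSketch:
    # Stand-in for vllm.config.DeviceConfig; the real class may differ.
    def __init__(self, device: str = "cuda"):
        self.device = torch.device(device)

cfg = DeviceConfigSketch("cuda" if torch.cuda.is_available() else "cpu")
print(cfg.device, cfg.device.type)   # e.g. cuda / "cuda", or cpu / "cpu"
```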
......@@ -72,7 +76,8 @@ class ModelRunner:
self.kv_cache_dtype = kv_cache_dtype
def load_model(self) -> None:
self.model = get_model(self.model_config, self.lora_config)
self.model = get_model(self.model_config, self.device_config,
self.lora_config)
vocab_size = self.model.config.vocab_size
......@@ -182,22 +187,25 @@ class ModelRunner:
input_tokens = _make_tensor_with_pad(input_tokens,
max_prompt_len,
pad=0,
dtype=torch.long)
dtype=torch.long,
device=self.device)
input_positions = _make_tensor_with_pad(input_positions,
max_prompt_len,
pad=0,
dtype=torch.long)
dtype=torch.long,
device=self.device)
slot_mapping = _make_tensor_with_pad(slot_mapping,
max_prompt_len,
pad=_PAD_SLOT_ID,
dtype=torch.long)
dtype=torch.long,
device=self.device)
lora_index_mapping = [
_pad_to_max(mapping, max_prompt_len, pad=0)
for mapping in lora_index_mapping
]
context_lens_tensor = torch.tensor(context_lens,
dtype=torch.int,
device='cuda')
device=self.device)
# Prepare prefix block tables
max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
block_tables = _make_tensor_with_pad(
......@@ -205,15 +213,16 @@ class ModelRunner:
max_len=max_prompt_block_table_len,
pad=0,
dtype=torch.int,
device=self.device,
)
start_loc_tensor = torch.arange(0,
len(prompt_lens) * max_prompt_len,
max_prompt_len,
dtype=torch.long,
device='cuda')
device=self.device)
prompt_lens_tensor = torch.tensor(prompt_lens,
dtype=torch.long,
device='cuda')
device=self.device)
input_metadata = InputMetadata(
is_prompt=True,
......@@ -305,20 +314,20 @@ class ModelRunner:
max_len=1,
pad=0,
dtype=torch.long,
device="cuda")
device=self.device)
input_positions = _make_tensor_with_pad(input_positions,
max_len=1,
pad=0,
dtype=torch.long,
device="cuda")
device=self.device)
slot_mapping = _make_tensor_with_pad(slot_mapping,
max_len=1,
pad=_PAD_SLOT_ID,
dtype=torch.long,
device="cuda")
device=self.device)
context_lens = torch.tensor(context_lens,
dtype=torch.int,
device="cuda")
device=self.device)
if use_captured_graph:
# The shape of graph_block_tables is
......@@ -327,7 +336,7 @@ class ModelRunner:
for i, block_table in enumerate(block_tables):
if block_table:
input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device="cuda")
block_tables = torch.tensor(input_block_tables, device=self.device)
else:
max_block_table_len = max(
len(block_table) for block_table in block_tables)
......@@ -336,7 +345,7 @@ class ModelRunner:
max_len=max_block_table_len,
pad=0,
dtype=torch.int,
device="cuda",
device=self.device,
)
lora_index_mapping = [
......@@ -355,7 +364,8 @@ class ModelRunner:
use_cuda_graph=use_captured_graph,
kv_cache_dtype=self.kv_cache_dtype,
)
return input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests
return (input_tokens, input_positions, input_metadata,
lora_index_mapping, lora_prompt_mapping, lora_requests)
def _prepare_sample(
self,
......@@ -410,9 +420,13 @@ class ModelRunner:
selected_token_indices = _async_h2d(selected_token_indices,
dtype=torch.long,
target_device=self.device,
pin_memory=not self.in_wsl)
categorized_sample_indices = {
t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl)
t: _async_h2d(seq_ids,
dtype=torch.int,
target_device=self.device,
pin_memory=not self.in_wsl)
for t, seq_ids in categorized_sample_indices.items()
}
......@@ -511,7 +525,8 @@ class ModelRunner:
perform_sampling=False,
)
return input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, lora_mapping
return (input_tokens, input_positions, input_metadata,
sampling_metadata, lora_requests, lora_mapping)
@torch.inference_mode()
def execute_model(
......@@ -519,8 +534,9 @@ class ModelRunner:
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
) -> Optional[SamplerOutput]:
input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, lora_mapping = (
self.prepare_input_tensors(seq_group_metadata_list))
(input_tokens, input_positions, input_metadata, sampling_metadata,
lora_requests,
lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list)
if self.lora_config:
self.set_active_loras(lora_requests, lora_mapping)
......@@ -789,14 +805,10 @@ def _make_tensor_with_pad(
max_len: int,
pad: int,
dtype: torch.dtype,
device: Union[str, torch.device] = "cuda",
pin_memory: bool = False,
device: Optional[Union[str, torch.device]],
) -> torch.Tensor:
padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
return torch.tensor(padded_x,
dtype=dtype,
device=device,
pin_memory=pin_memory and str(device) == "cpu")
return torch.tensor(padded_x, dtype=dtype, device=device)
def _get_graph_batch_size(batch_size: int) -> int:
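The padding helper now requires an explicit device instead of defaulting to `"cuda"`, and drops the pin-memory special case. A self-contained sketch of the pad-then-tensorize idea (a simplified reimplementation, not the vLLM source):

```python
import torch
from typing import List, Union

def pad_to_max(row: List[int], max_len: int, pad: int) -> List[int]:
    # Right-pad a single row up to max_len.
    return row + [pad] * (max_len - len(row))

def make_tensor_with_pad(x: List[List[int]], max_len: int, pad: int,
                         dtype: torch.dtype,
                         device: Union[str, torch.device]) -> torch.Tensor:
    padded = [pad_to_max(row, max_len, pad) for row in x]
    return torch.tensor(padded, dtype=dtype, device=device)

t = make_tensor_with_pad([[1, 2], [3]], max_len=4, pad=0,
                         dtype=torch.long, device="cpu")
print(t)  # tensor([[1, 2, 0, 0], [3, 0, 0, 0]])
```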
......@@ -808,6 +820,11 @@ def _get_graph_batch_size(batch_size: int) -> int:
return (batch_size + 7) // 8 * 8
def _async_h2d(data: list, dtype, pin_memory):
t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
return t.to(device="cuda", non_blocking=True)
def _async_h2d(
data: list,
dtype: torch.dtype,
target_device: Union[str, torch.device],
pin_memory: bool,
) -> torch.Tensor:
t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
return t.to(device=target_device, non_blocking=True)
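`_async_h2d` now takes an explicit target device: the tensor is first materialized in (optionally pinned) host memory, then copied with `non_blocking=True` so the transfer can overlap with other work. The same pattern in isolation:

```python
import torch

data = [1, 2, 3, 4]
target = "cuda" if torch.cuda.is_available() else "cpu"
pin = torch.cuda.is_available()                  # pinning only works with CUDA
host_t = torch.tensor(data, dtype=torch.long, pin_memory=pin, device="cpu")
dev_t = host_t.to(device=target, non_blocking=True)  # async when target is CUDA
```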
......@@ -6,8 +6,8 @@ from typing import Dict, List, Tuple, Set, Optional
import torch
import torch.distributed
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig, LoRAConfig)
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
ParallelConfig, SchedulerConfig, LoRAConfig)
from vllm.model_executor import set_random_seed
from vllm.model_executor.parallel_utils.communication_op import (
broadcast_tensor_dict)
......@@ -33,6 +33,7 @@ class Worker:
model_config: ModelConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
local_rank: int,
rank: int,
distributed_init_method: str,
......@@ -43,6 +44,7 @@ class Worker:
self.model_config = model_config
self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.device_config = device_config
self.local_rank = local_rank
self.rank = rank
self.distributed_init_method = distributed_init_method
......@@ -54,6 +56,7 @@ class Worker:
self.model_runner = ModelRunner(model_config,
parallel_config,
scheduler_config,
device_config,
lora_config=self.lora_config,
kv_cache_dtype=kv_cache_dtype,
is_driver_worker=is_driver_worker)
......@@ -65,21 +68,24 @@ class Worker:
self.gpu_cache = None
def init_model(self) -> None:
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
# this behavior.
# Related issue:
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
self.device = torch.device(f"cuda:{self.local_rank}")
torch.cuda.set_device(self.device)
_check_if_gpu_supports_dtype(self.model_config.dtype)
if self.device_config.device.type == "cuda":
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
# this behavior.
# Related issue:
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
self.device = torch.device(f"cuda:{self.local_rank}")
torch.cuda.set_device(self.device)
_check_if_gpu_supports_dtype(self.model_config.dtype)
else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
init_distributed_environment(self.parallel_config, self.rank,
self.distributed_init_method)
......