Unverified commit 96b6f475, authored by Kunshang Ji and committed by GitHub

Remove hardcoded `device="cuda"` to support more devices (#2503)


Co-authored-by: Jiang Li <jiang1.li@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
parent c410f5d0
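The commit replaces hard-coded `device="cuda"` arguments with device handling driven by a `DeviceConfig` that the diff imports from `vllm.config`. The config's body is not part of this commit, so the sketch below is only a hypothetical illustration of the idea: resolve the target device once and pass it down, instead of assuming CUDA at every tensor-creation site.

```python
import torch

# Hypothetical sketch only: the real DeviceConfig lives in vllm.config and its
# definition is not shown in this diff. The point is that the device is chosen
# once and handed down to workers, rather than hard-coded as "cuda" everywhere.
class DeviceConfigSketch:

    def __init__(self, device: str = "cuda") -> None:
        # Fall back to CPU when CUDA is unavailable so the same code path
        # runs on non-NVIDIA hosts.
        if device == "cuda" and not torch.cuda.is_available():
            device = "cpu"
        self.device = torch.device(device)


cfg = DeviceConfigSketch()
print(cfg.device)  # cuda on a GPU host, cpu otherwise
```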
@@ -200,7 +200,7 @@ def _make_alibi_bias(
     seq_len: int,
     dtype: torch.dtype,
 ) -> LowerTriangularMaskWithTensorBias:
-    bias = torch.arange(seq_len, dtype=dtype, device="cuda")
+    bias = torch.arange(seq_len, dtype=dtype)
     # NOTE(zhuohan): HF uses
     #     `bias = bias[None, :].repeat(prompt_len, 1)`
     # here. We find that both biases give the same results, but
...
@@ -54,7 +54,6 @@ class UnquantizedLinearMethod(LinearMethodBase):
                       params_dtype: torch.dtype) -> Dict[str, Any]:
         weight = Parameter(torch.empty(output_size_per_partition,
                                        input_size_per_partition,
-                                       device=torch.cuda.current_device(),
                                        dtype=params_dtype),
                            requires_grad=False)
         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
@@ -113,9 +112,7 @@ class ReplicatedLinear(torch.nn.Module):
         self.register_parameter(name, weight)
         if bias:
             self.bias = Parameter(
-                torch.empty(self.output_size,
-                            device=torch.cuda.current_device(),
-                            dtype=self.params_dtype))
+                torch.empty(self.output_size, dtype=self.params_dtype))
             set_weight_attrs(self.bias, {"output_dim": 0})
         else:
             self.register_parameter("bias", None)
@@ -183,7 +180,6 @@ class ColumnParallelLinear(torch.nn.Module):
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size_per_partition,
-                            device=torch.cuda.current_device(),
                             dtype=params_dtype))
             set_weight_attrs(self.bias, {
                 "output_dim": 0,
@@ -509,9 +505,7 @@ class RowParallelLinear(torch.nn.Module):
         if bias:
             self.bias = Parameter(
-                torch.empty(self.output_size,
-                            device=torch.cuda.current_device(),
-                            dtype=params_dtype))
+                torch.empty(self.output_size, dtype=params_dtype))
             set_weight_attrs(self.bias, {
                 "output_dim": 0,
                 "weight_loader": self.weight_loader,
...
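Throughout the linear layers (and the quantized variants below) the pattern is the same: drop the explicit device from parameter allocation so the tensor is created on whatever the ambient device is, and let the loader decide where weights actually live. A generic illustration of why that is enough (not vLLM code):

```python
import torch
from torch.nn.parameter import Parameter

# Generic illustration, not vLLM code: a parameter allocated without a device
# argument lands on the current default device (CPU here) and can be moved
# later, instead of being pinned to CUDA at construction time.
weight = Parameter(torch.empty(8, 4, dtype=torch.float16), requires_grad=False)
print(weight.device)  # cpu unless a different default device is active

target = "cuda" if torch.cuda.is_available() else "cpu"
weight.data = weight.data.to(target)  # the runtime moves (or keeps) it later
print(weight.device)
```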
@@ -96,7 +96,6 @@ class AWQLinearMethod(LinearMethodBase):
             torch.empty(
                 input_size_per_partition,
                 output_size_per_partition // self.quant_config.pack_factor,
-                device="cuda",
                 dtype=torch.int32,
             ),
             requires_grad=False,
@@ -112,7 +111,6 @@ class AWQLinearMethod(LinearMethodBase):
             torch.empty(
                 input_size_per_partition // self.quant_config.group_size,
                 output_size_per_partition // self.quant_config.pack_factor,
-                device="cuda",
                 dtype=torch.int32,
             ),
             requires_grad=False,
@@ -128,7 +126,6 @@ class AWQLinearMethod(LinearMethodBase):
             torch.empty(
                 input_size_per_partition // self.quant_config.group_size,
                 output_size_per_partition,
-                device="cuda",
                 dtype=params_dtype,
             ),
             requires_grad=False,
...
@@ -127,7 +127,6 @@ class GPTQLinearMethod(LinearMethodBase):
             torch.empty(
                 input_size_per_partition // self.quant_config.pack_factor,
                 output_size_per_partition,
-                device="cuda",
                 dtype=torch.int32,
             ),
             requires_grad=False,
@@ -145,7 +144,6 @@ class GPTQLinearMethod(LinearMethodBase):
                     i // self.quant_config.group_size
                     for i in range(input_size_per_partition)
                 ],
-                device="cuda",
                 dtype=torch.int32,
             ),
             requires_grad=False,
@@ -156,7 +154,6 @@ class GPTQLinearMethod(LinearMethodBase):
             torch.empty(
                 scale_and_zero_size,
                 output_size_per_partition // self.quant_config.pack_factor,
-                device="cuda",
                 dtype=torch.int32,
             ),
             requires_grad=False,
@@ -172,7 +169,6 @@ class GPTQLinearMethod(LinearMethodBase):
             torch.empty(
                 scale_and_zero_size,
                 output_size_per_partition,
-                device="cuda",
                 dtype=params_dtype,
             ),
             requires_grad=False,
...
@@ -80,7 +80,6 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
             torch.empty(
                 input_size_per_partition // self.quant_config.pack_factor,
                 output_size_per_partition,
-                device="cuda",
                 dtype=torch.int32,
             ),
             requires_grad=False,
@@ -96,7 +95,6 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
             torch.empty(
                 output_size,
                 self.quant_config.weight_bits**2,
-                device="cuda",
                 dtype=params_dtype,
             ),
             requires_grad=False,
@@ -118,12 +116,12 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
         out_shape = x.shape[:-1] + (qweight.shape[-1], )
         reshaped_x = x.reshape(-1, x.shape[-1])
         if is_hip():
-            out_f = torch.zeros(out_shape, device="cuda", dtype=torch.float)
+            out_f = torch.zeros(out_shape, dtype=torch.float)
             ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table)
             out = out_f.to(dtype=torch.float16)
         else:
             # NOTE: The output tensor should be zero-initialized.
-            out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
+            out = torch.zeros(out_shape, dtype=torch.float16)
             ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
         if bias is not None:
...
@@ -77,16 +77,13 @@ class RotaryEmbedding(nn.Module):
         # create the cache on GPU for faster initialization. This may cause
         # a slight numerical difference between the HF implementation and ours.
         inv_freq = 1.0 / (base**(torch.arange(
-            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
-                                 self.rotary_dim))
+            0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
         return inv_freq

     def _compute_cos_sin_cache(self) -> torch.Tensor:
         """Compute the cos and sin cache."""
         inv_freq = self._compute_inv_freq(self.base)
-        t = torch.arange(self.max_position_embeddings,
-                         dtype=torch.float,
-                         device="cuda")
+        t = torch.arange(self.max_position_embeddings, dtype=torch.float)

         freqs = torch.einsum("i,j -> ij", t, inv_freq)
         cos = freqs.cos()
@@ -174,7 +171,7 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
         # Thus, the maximum length after applying the rope scaling is
         # self.max_position_embeddings * self.scaling_factor.
         max_len = self.max_position_embeddings * self.scaling_factor
-        t = torch.arange(max_len, dtype=torch.float, device="cuda")
+        t = torch.arange(max_len, dtype=torch.float)
         t = t / self.scaling_factor

         freqs = torch.einsum("i,j -> ij", t, inv_freq)
@@ -214,7 +211,7 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
                 (self.scaling_factor - 1))**(self.rotary_dim /
                                              (self.rotary_dim - 2))
         inv_freq = self._compute_inv_freq(base)
-        t = torch.arange(max_len, dtype=torch.float, device="cuda")
+        t = torch.arange(max_len, dtype=torch.float)

         freqs = torch.einsum("i,j -> ij", t, inv_freq)
         cos = freqs.cos()
@@ -297,8 +294,8 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
                          is_neox_style)

     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
-        pos_freqs = self.base**(torch.arange(
-            0, self.rotary_dim, 2, dtype=torch.float, device="cuda") /
-                                self.rotary_dim)
+        pos_freqs = self.base**(
+            torch.arange(0, self.rotary_dim, 2, dtype=torch.float) /
+            self.rotary_dim)
         inv_freq_extrapolation = 1.0 / pos_freqs
         inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
@@ -308,8 +305,8 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
                                                 self.max_position_embeddings)
         # Get n-d rotational scaling corrected for extrapolation
         inv_freq_mask = (1 - _yarn_linear_ramp_mask(
-            low, high, self.rotary_dim // 2, dtype=torch.float,
-            device="cuda")) * self.extrapolation_factor
+            low, high, self.rotary_dim // 2,
+            dtype=torch.float)) * self.extrapolation_factor
         inv_freq = inv_freq_interpolation * (
             1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask
         return inv_freq
@@ -317,7 +314,6 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
     def _compute_cos_sin_cache(self) -> torch.Tensor:
         inv_freq = self._compute_inv_freq(self.scaling_factor)
         t = torch.arange(self.max_position_embeddings * self.scaling_factor,
-                         device="cuda",
                          dtype=torch.float32)
         freqs = torch.einsum("i,j -> ij", t, inv_freq)
         cos = (freqs.cos() * self.mscale)
...
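With the explicit device gone, the rotary caches (`inv_freq`, the cos/sin tables) are built on the default device and travel with the module when it is moved. A minimal, generic sketch of that behaviour, assuming the cache is kept as a registered buffer (this toy class is not vLLM's RotaryEmbedding):

```python
import torch
import torch.nn as nn


class TinyRope(nn.Module):
    """Toy stand-in for a rotary embedding, for illustration only."""

    def __init__(self, rotary_dim: int, max_pos: int, base: float = 10000.0):
        super().__init__()
        # No device argument: built on the default device, moved later.
        inv_freq = 1.0 / (base**(
            torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
        t = torch.arange(max_pos, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)
        self.register_buffer("cos_sin_cache", cache, persistent=False)


rope = TinyRope(rotary_dim=64, max_pos=2048)
print(rope.cos_sin_cache.device)  # default device (cpu here)
# rope.to("cuda")  # on a GPU host the buffer would follow the module
```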
@@ -77,7 +77,6 @@ class VocabParallelEmbedding(torch.nn.Module):
         self.weight = Parameter(
             torch.empty(self.num_embeddings_per_partition,
                         self.embedding_dim,
-                        device=torch.cuda.current_device(),
                         dtype=params_dtype))
         set_weight_attrs(self.weight, {
             "parallel_dim": 0,
@@ -139,7 +138,6 @@ class ParallelLMHead(VocabParallelEmbedding):
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition,
-                            device=torch.cuda.current_device(),
                             dtype=params_dtype))
             set_weight_attrs(self.bias, {
                 "parallel_dim": 0,
...
@@ -5,7 +5,7 @@ from typing import Optional, Type
 import torch
 import torch.nn as nn

-from vllm.config import ModelConfig, LoRAConfig
+from vllm.config import DeviceConfig, ModelConfig, LoRAConfig
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.weight_utils import (get_quant_config,
                                               initialize_dummy_weights)
@@ -38,6 +38,7 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:

 def get_model(model_config: ModelConfig,
+              device_config: DeviceConfig,
               lora_config: Optional[LoRAConfig] = None) -> nn.Module:
     model_class = _get_model_architecture(model_config)
@@ -64,7 +65,7 @@ def get_model(model_config: ModelConfig,
     with _set_default_torch_dtype(model_config.dtype):
         # Create a model instance.
         # The weights will be initialized as empty tensors.
-        with torch.device("cuda"):
+        with torch.device(device_config.device):
             if getattr(model_class, "supports_lora", False):
                 model = model_class(model_config.hf_config, linear_method,
                                     lora_config)
...
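`get_model` now receives the `DeviceConfig` and instantiates the model under `torch.device(device_config.device)`. In PyTorch 2.x a `torch.device` used as a context manager sets the default device for factory calls inside the block, so the device-free `torch.empty(...)` calls in the layers above all land on the configured device. A minimal, runnable sketch of that mechanism, using a stand-in module rather than vLLM's loader:

```python
import torch
import torch.nn as nn

# Minimal sketch of the mechanism get_model relies on: any tensor created
# without an explicit device inside the block picks up the ambient default.
# "cpu" is used here only so the example runs anywhere; a DeviceConfig would
# supply "cuda", "cpu", or another back end.
target_device = torch.device("cpu")

with torch.device(target_device):
    model = nn.Linear(16, 16)  # stand-in for model_class(config, ...)

print(next(model.parameters()).device)  # -> cpu
```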
@@ -228,6 +228,7 @@ def create_kv_caches_with_random(
     device: Optional[str] = "cuda",
 ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
     torch.random.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)

     if isinstance(cache_dtype, str):
...
@@ -104,11 +104,13 @@ class CacheEngine:
                 size=(self.num_cpu_blocks, *key_block_shape),
                 dtype=self.dtype,
                 pin_memory=pin_memory,
+                device="cpu",
             )
             value_blocks = torch.empty(
                 size=(self.num_cpu_blocks, *value_block_shape),
                 dtype=self.dtype,
                 pin_memory=pin_memory,
+                device="cpu",
             )
             cpu_cache.append((key_blocks, value_blocks))
         return cpu_cache
...
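The CPU swap cache now names its device explicitly instead of relying on the default, while still honouring the `pin_memory` flag. A standalone sketch of that allocation pattern; the shape and the pinning policy here are illustrative, not vLLM's:

```python
import torch

# Illustrative only: host-side cache blocks are allocated on the CPU and,
# when CUDA is present, pinned so later copies to the GPU can be asynchronous.
pin = torch.cuda.is_available()
key_blocks = torch.empty(
    size=(16, 8, 64),        # made-up shape, not vLLM's real block layout
    dtype=torch.float16,
    pin_memory=pin,
    device="cpu",
)
print(key_blocks.is_pinned() if pin else "not pinned (no CUDA available)")
```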
@@ -5,7 +5,7 @@ import numpy as np
 import torch
 import torch.nn as nn

-from vllm.config import ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig
+from vllm.config import DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig
 from vllm.logger import init_logger
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
 from vllm.model_executor.parallel_utils.communication_op import (
@@ -35,6 +35,7 @@ class ModelRunner:
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
         kv_cache_dtype: Optional[str] = "auto",
         is_driver_worker: bool = False,
@@ -49,7 +50,10 @@ class ModelRunner:
         # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
         self.sliding_window = (model_config.get_sliding_window()
                                if model_config is not None else None)
-        self.device = torch.device(torch.cuda.current_device())
+        self.device_config = (device_config
+                              if device_config is not None else DeviceConfig())
+        self.device = self.device_config.device
+
         self.model = None
         self.block_size = None  # Set after initial profiling.
         self.lora_manager = None
@@ -72,7 +76,8 @@ class ModelRunner:
         self.kv_cache_dtype = kv_cache_dtype

     def load_model(self) -> None:
-        self.model = get_model(self.model_config, self.lora_config)
+        self.model = get_model(self.model_config, self.device_config,
+                               self.lora_config)
         vocab_size = self.model.config.vocab_size
@@ -182,22 +187,25 @@ class ModelRunner:
         input_tokens = _make_tensor_with_pad(input_tokens,
                                              max_prompt_len,
                                              pad=0,
-                                             dtype=torch.long)
+                                             dtype=torch.long,
+                                             device=self.device)
         input_positions = _make_tensor_with_pad(input_positions,
                                                 max_prompt_len,
                                                 pad=0,
-                                                dtype=torch.long)
+                                                dtype=torch.long,
+                                                device=self.device)
         slot_mapping = _make_tensor_with_pad(slot_mapping,
                                              max_prompt_len,
                                              pad=_PAD_SLOT_ID,
-                                             dtype=torch.long)
+                                             dtype=torch.long,
+                                             device=self.device)
         lora_index_mapping = [
             _pad_to_max(mapping, max_prompt_len, pad=0)
             for mapping in lora_index_mapping
         ]
         context_lens_tensor = torch.tensor(context_lens,
                                            dtype=torch.int,
-                                           device='cuda')
+                                           device=self.device)
         # Prepare prefix block tables
         max_prompt_block_table_len = max(len(t) for t in prefix_block_tables)
         block_tables = _make_tensor_with_pad(
@@ -205,15 +213,16 @@ class ModelRunner:
             max_len=max_prompt_block_table_len,
             pad=0,
             dtype=torch.int,
+            device=self.device,
         )
         start_loc_tensor = torch.arange(0,
                                         len(prompt_lens) * max_prompt_len,
                                         max_prompt_len,
                                         dtype=torch.long,
-                                        device='cuda')
+                                        device=self.device)
         prompt_lens_tensor = torch.tensor(prompt_lens,
                                           dtype=torch.long,
-                                          device='cuda')
+                                          device=self.device)

         input_metadata = InputMetadata(
             is_prompt=True,
@@ -305,20 +314,20 @@ class ModelRunner:
                                              max_len=1,
                                              pad=0,
                                              dtype=torch.long,
-                                             device="cuda")
+                                             device=self.device)
         input_positions = _make_tensor_with_pad(input_positions,
                                                 max_len=1,
                                                 pad=0,
                                                 dtype=torch.long,
-                                                device="cuda")
+                                                device=self.device)
         slot_mapping = _make_tensor_with_pad(slot_mapping,
                                              max_len=1,
                                              pad=_PAD_SLOT_ID,
                                              dtype=torch.long,
-                                             device="cuda")
+                                             device=self.device)
         context_lens = torch.tensor(context_lens,
                                     dtype=torch.int,
-                                    device="cuda")
+                                    device=self.device)

         if use_captured_graph:
             # The shape of graph_block_tables is
@@ -327,7 +336,7 @@ class ModelRunner:
             for i, block_table in enumerate(block_tables):
                 if block_table:
                     input_block_tables[i, :len(block_table)] = block_table
-            block_tables = torch.tensor(input_block_tables, device="cuda")
+            block_tables = torch.tensor(input_block_tables, device=self.device)
         else:
             max_block_table_len = max(
                 len(block_table) for block_table in block_tables)
@@ -336,7 +345,7 @@ class ModelRunner:
                 max_len=max_block_table_len,
                 pad=0,
                 dtype=torch.int,
-                device="cuda",
+                device=self.device,
             )

         lora_index_mapping = [
@@ -355,7 +364,8 @@ class ModelRunner:
             use_cuda_graph=use_captured_graph,
             kv_cache_dtype=self.kv_cache_dtype,
         )
-        return input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests
+        return (input_tokens, input_positions, input_metadata,
+                lora_index_mapping, lora_prompt_mapping, lora_requests)

     def _prepare_sample(
         self,
@@ -410,9 +420,13 @@ class ModelRunner:
         selected_token_indices = _async_h2d(selected_token_indices,
                                             dtype=torch.long,
+                                            target_device=self.device,
                                             pin_memory=not self.in_wsl)
         categorized_sample_indices = {
-            t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl)
+            t: _async_h2d(seq_ids,
+                          dtype=torch.int,
+                          target_device=self.device,
+                          pin_memory=not self.in_wsl)
             for t, seq_ids in categorized_sample_indices.items()
         }
@@ -511,7 +525,8 @@ class ModelRunner:
             perform_sampling=False,
         )

-        return input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, lora_mapping
+        return (input_tokens, input_positions, input_metadata,
+                sampling_metadata, lora_requests, lora_mapping)

     @torch.inference_mode()
     def execute_model(
@@ -519,8 +534,9 @@ class ModelRunner:
         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> Optional[SamplerOutput]:
-        input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, lora_mapping = (
-            self.prepare_input_tensors(seq_group_metadata_list))
+        (input_tokens, input_positions, input_metadata, sampling_metadata,
+         lora_requests,
+         lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list)
         if self.lora_config:
             self.set_active_loras(lora_requests, lora_mapping)
@@ -789,14 +805,10 @@ def _make_tensor_with_pad(
     max_len: int,
     pad: int,
     dtype: torch.dtype,
-    device: Union[str, torch.device] = "cuda",
-    pin_memory: bool = False,
+    device: Optional[Union[str, torch.device]],
 ) -> torch.Tensor:
     padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
-    return torch.tensor(padded_x,
-                        dtype=dtype,
-                        device=device,
-                        pin_memory=pin_memory and str(device) == "cpu")
+    return torch.tensor(padded_x, dtype=dtype, device=device)


 def _get_graph_batch_size(batch_size: int) -> int:
@@ -808,6 +820,11 @@ def _get_graph_batch_size(batch_size: int) -> int:
     return (batch_size + 7) // 8 * 8


-def _async_h2d(data: list, dtype, pin_memory):
-    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
-    return t.to(device="cuda", non_blocking=True)
+def _async_h2d(
+    data: list,
+    dtype: torch.dtype,
+    target_device: Union[str, torch.device],
+    pin_memory: bool,
+) -> torch.Tensor:
+    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
+    return t.to(device=target_device, non_blocking=True)
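The host-to-device helper now takes an explicit `target_device` and stages data on the CPU (optionally pinned) before copying. A self-contained sketch of the same idea, not the vLLM function itself:

```python
import torch
from typing import Union


def async_h2d_sketch(data: list, dtype: torch.dtype,
                     target_device: Union[str, torch.device],
                     pin_memory: bool) -> torch.Tensor:
    # Stage on the CPU first; non_blocking copies only overlap with compute
    # when the source is pinned and the target is a CUDA device.
    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
    return t.to(device=target_device, non_blocking=True)


dev = "cuda" if torch.cuda.is_available() else "cpu"
x = async_h2d_sketch([1, 2, 3], torch.long, dev, pin_memory=(dev == "cuda"))
print(x.device)
```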
@@ -6,8 +6,8 @@ from typing import Dict, List, Tuple, Set, Optional
 import torch
 import torch.distributed

-from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
-                         SchedulerConfig, LoRAConfig)
+from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
+                         ParallelConfig, SchedulerConfig, LoRAConfig)
 from vllm.model_executor import set_random_seed
 from vllm.model_executor.parallel_utils.communication_op import (
     broadcast_tensor_dict)
@@ -33,6 +33,7 @@ class Worker:
         model_config: ModelConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
         local_rank: int,
         rank: int,
         distributed_init_method: str,
@@ -43,6 +44,7 @@ class Worker:
         self.model_config = model_config
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
+        self.device_config = device_config
         self.local_rank = local_rank
         self.rank = rank
         self.distributed_init_method = distributed_init_method
@@ -54,6 +56,7 @@ class Worker:
         self.model_runner = ModelRunner(model_config,
                                         parallel_config,
                                         scheduler_config,
+                                        device_config,
                                         lora_config=self.lora_config,
                                         kv_cache_dtype=kv_cache_dtype,
                                         is_driver_worker=is_driver_worker)
@@ -65,6 +68,7 @@ class Worker:
         self.gpu_cache = None

     def init_model(self) -> None:
+        if self.device_config.device.type == "cuda":
             # torch.distributed.all_reduce does not free the input tensor until
             # the synchronization point. This causes the memory usage to grow
             # as the number of all_reduce calls increases. This env var disables
@@ -79,7 +83,9 @@ class Worker:
             torch.cuda.set_device(self.device)

             _check_if_gpu_supports_dtype(self.model_config.dtype)
+        else:
+            raise RuntimeError(
+                f"Not support device type: {self.device_config.device}")

         # Initialize the distributed environment.
         init_distributed_environment(self.parallel_config, self.rank,
                                      self.distributed_init_method)
...
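`Worker.init_model` now gates the CUDA-specific setup on the configured device type and rejects anything else. A compressed illustration of that branch; the error text comes from the diff, while the device resolution below is a stand-in for `DeviceConfig`:

```python
import torch

# Compressed illustration of the new branch in Worker.init_model: run the
# CUDA-only initialization when the configured device is a GPU, otherwise
# refuse early.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

if device.type == "cuda":
    torch.cuda.set_device(device)  # CUDA-specific setup happens here
else:
    # The real Worker raises here:
    #     RuntimeError(f"Not support device type: {device_config.device}")
    print(f"Skipping CUDA-only setup for device type: {device.type}")
```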