Commit fcffb7c8 authored by zhuwenwen

Merge branch 'vllm-v0.2.7-dtk23.10'

parents eb181638 4095d0db
@@ -14,6 +14,7 @@ BLOCK_SIZES = [8, 16, 32]
 NUM_BLOCKS = [1024, 36000]  # Arbitrary values for testing
 NUM_MAPPINGS = [256]  # Arbitrary values for testing
 SEEDS = [0]
+DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]

 @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
@@ -24,6 +25,7 @@ SEEDS = [0]
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_copy_blocks(
     kv_cache_factory,
@@ -35,11 +37,12 @@ def test_copy_blocks(
     num_blocks: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
+    gpu_id = f"cuda:{device}"
     # Generate random block mappings where each source block is mapped to two
     # destination blocks.
     assert 2 * num_mappings <= num_blocks
@@ -56,7 +59,7 @@ def test_copy_blocks(
     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(num_blocks, block_size,
                                                 num_layers, num_heads,
-                                                head_size, dtype, seed)
+                                                head_size, dtype, seed, gpu_id)
     # Clone the KV caches.
     cloned_key_caches = [key_cache.clone() for key_cache in key_caches]
@@ -88,6 +91,7 @@ def test_copy_blocks(
 @pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_reshape_and_cache(
     kv_cache_factory,
@@ -98,28 +102,29 @@ def test_reshape_and_cache(
     num_blocks: int,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     random.seed(seed)
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
+    gpu_id = f"cuda:{device}"
     # Create a random slot mapping.
     num_slots = block_size * num_blocks
     slot_mapping = random.sample(range(num_slots), num_tokens)
-    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device="cuda")
+    slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=gpu_id)
     qkv = torch.randn(num_tokens,
                       3,
                       num_heads,
                       head_size,
                       dtype=dtype,
-                      device="cuda")
+                      device=gpu_id)
     _, key, value = qkv.unbind(dim=1)
     # Create the KV caches.
     key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1,
                                                 num_heads, head_size, dtype,
-                                                seed)
+                                                seed, gpu_id)
     key_cache, value_cache = key_caches[0], value_caches[0]
     # Clone the KV caches.
...
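The kernel tests above gain a device axis, so each case runs against an explicit cuda:{device} target instead of the default device. A minimal, hypothetical pytest sketch of the same pattern (not part of this commit; the layernorm and rotary-embedding tests below repeat the identical change):

import pytest
import torch

# Mirrors the diff: one device on single-GPU machines, two otherwise.
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]


@pytest.mark.parametrize("device", DEVICES)
@torch.inference_mode()
def test_tensor_on_selected_gpu(device: int) -> None:
    # Build the input directly on the selected GPU, as the kernel tests
    # now do via gpu_id = f"cuda:{device}".
    gpu_id = f"cuda:{device}"
    x = torch.randn(16, 32, device=gpu_id)
    assert x.device == torch.device(gpu_id)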
@@ -8,6 +8,7 @@ NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
 HIDDEN_SIZES = [768, 5120, 8192]  # Arbitrary values for testing
 ADD_RESIDUAL = [False, True]
 SEEDS = [0]
+DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]

 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
@@ -15,6 +16,7 @@ SEEDS = [0]
 @pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_rms_norm(
     num_tokens: int,
@@ -22,14 +24,15 @@ def test_rms_norm(
     add_residual: bool,
     dtype: torch.dtype,
     seed: int,
+    device: int,
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
+    gpu_id = f"cuda:{device}"
-    layer = RMSNorm(hidden_size).to(dtype).cuda()
+    layer = RMSNorm(hidden_size).to(dtype=dtype, device=gpu_id)
     layer.weight.data.normal_(mean=1.0, std=0.1)
     scale = 1 / (2 * hidden_size)
-    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device="cuda")
+    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=gpu_id)
     x *= scale
     residual = torch.randn_like(x) * scale if add_residual else None
...
@@ -13,6 +13,7 @@ NUM_HEADS = [7, 17]  # Arbitrary values for testing
 BATCH_SIZES = [1, 5]  # Arbitrary values for testing
 SEQ_LENS = [11, 8192]  # Arbitrary values for testing
 SEEDS = [0]
+DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]

 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
@@ -23,6 +24,7 @@ SEEDS = [0]
 @pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", DEVICES)
 @torch.inference_mode()
 def test_rotary_embedding(
     is_neox_style: bool,
@@ -33,6 +35,7 @@ def test_rotary_embedding(
     rotary_dim: Optional[int],
     dtype: torch.dtype,
     seed: int,
+    device: int,
     max_position: int = 8192,
     base: int = 10000,
 ) -> None:
@@ -40,20 +43,20 @@ def test_rotary_embedding(
         rotary_dim = head_size
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
+    gpu_id = f"cuda:{device}"
     if rotary_dim is None:
         rotary_dim = head_size
     rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
-    rope = rope.to(dtype).cuda()
+    rope = rope.to(dtype=dtype, device=gpu_id)
     positions = torch.randint(0,
                               max_position, (batch_size, seq_len),
-                              device="cuda")
+                              device=gpu_id)
     query = torch.randn(batch_size,
                         seq_len,
                         num_heads * head_size,
                         dtype=dtype,
-                        device="cuda")
+                        device=gpu_id)
     key = torch.randn_like(query)
     # NOTE(woosuk): The reference implementation should be executed first
...
@@ -33,8 +33,9 @@ def test_prepare_prompt():
         expected_selected_token_indices.append(selected_token_start_idx +
                                                prompt_len - 1)
         selected_token_start_idx += max_seq_len
-    input_tokens, input_positions, _ = model_runner._prepare_prompt(
-        seq_group_metadata_list)
+    input_tokens, input_positions, _, return_prompt_lens = (
+        model_runner._prepare_prompt(seq_group_metadata_list))
+    assert return_prompt_lens == prompt_lens
     sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
                                                      prompt_lens)
     assert input_tokens.shape == (batch_size, max_seq_len)
...
@@ -9,7 +9,7 @@ from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.version import __dcu_version__

-__version__ = "0.2.6"
+__version__ = "0.2.7"

 __all__ = [
     "LLM",
...
@@ -181,12 +181,6 @@ class ModelConfig:
             self.max_context_len_to_capture = self.max_model_len
         self.max_context_len_to_capture = min(self.max_context_len_to_capture,
                                               self.max_model_len)
-        if (self.quantization in ["gptq", "squeezellm"]
-                and not self.enforce_eager):
-            # Related issue: https://github.com/vllm-project/vllm/issues/2147
-            logger.warning(f"{self.quantization} does not support CUDA graph "
-                           "yet. Disabling CUDA graph.")
-            self.enforce_eager = True

     def verify_with_parallel_config(
         self,
...
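With the guard removed, GPTQ and SqueezeLLM models are no longer forced into eager mode; CUDA graphs can still be disabled explicitly by the caller. A hedged usage sketch (the model name is a placeholder, assuming the standard vllm.LLM entry point):

from vllm import LLM, SamplingParams

# enforce_eager=True opts out of CUDA graph capture by hand, which this
# commit no longer does automatically for gptq/squeezellm checkpoints.
llm = LLM(model="TheBloke/some-model-GPTQ", quantization="gptq",
          enforce_eager=True)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=16))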
@@ -183,50 +183,54 @@ class _AsyncLLMEngine(LLMEngine):
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
-        if scheduler_outputs.is_empty():
-            return ignored
-
-        # Execute the model.
-        output = await self._run_workers_async(
-            "execute_model",
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
-            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
-            blocks_to_copy=scheduler_outputs.blocks_to_copy,
-        )
+        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
+
+        if not scheduler_outputs.is_empty():
+            # Execute the model.
+            all_outputs = await self._run_workers_async(
+                "execute_model",
+                driver_kwargs={
+                    "seq_group_metadata_list": seq_group_metadata_list,
+                    "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
+                    "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
+                    "blocks_to_copy": scheduler_outputs.blocks_to_copy,
+                })
+
+            # Only the driver worker returns the sampling results.
+            output = all_outputs[0]
+        else:
+            output = []

-        return self._process_model_outputs(output, scheduler_outputs) + ignored
+        return self._process_model_outputs(output, scheduler_outputs)

     async def _run_workers_async(
         self,
         method: str,
         *args,
-        get_all_outputs: bool = False,
+        driver_args: Optional[List[Any]] = None,
+        driver_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> Any:
         """Runs the given method on all workers."""
         coros = []
-        for worker in self.workers:
-            if self.parallel_config.worker_use_ray:
-                coros.append(
-                    worker.execute_method.remote(method, *args, **kwargs))
-            else:
-                executor = getattr(worker, method)
-                coros.append(asyncio.get_event_loop().run_in_executor(
-                    None, partial(executor, *args, **kwargs)))
-
-        all_outputs = await asyncio.gather(*coros)
-
-        if get_all_outputs:
-            return all_outputs
-
-        # Make sure all workers have the same results.
-        output = all_outputs[0]
-        for other_output in all_outputs[1:]:
-            assert output == other_output
-        return output
+
+        if driver_args is None:
+            driver_args = args
+        if driver_kwargs is None:
+            driver_kwargs = kwargs
+
+        # Run the driver worker asynchronously.
+        driver_executor = getattr(self.driver_worker, method)
+        coros.append(asyncio.get_event_loop().run_in_executor(
+            None, partial(driver_executor, *driver_args, **driver_kwargs)))
+
+        # Run the ray workers asynchronously.
+        for worker in self.workers:
+            coros.append(worker.execute_method.remote(method, *args, **kwargs))
+
+        all_outputs = await asyncio.gather(*coros)
+        return all_outputs

 class AsyncLLMEngine:
     """An asynchronous wrapper for LLMEngine.
@@ -490,13 +494,12 @@ class AsyncLLMEngine:
         engine_configs = engine_args.create_engine_configs()
         parallel_config = engine_configs[2]
         # Initialize the cluster.
-        distributed_init_method, placement_group = initialize_cluster(
-            parallel_config, engine_args.engine_use_ray)
+        placement_group = initialize_cluster(parallel_config,
+                                             engine_args.engine_use_ray)
         # Create the async LLM engine.
         engine = cls(parallel_config.worker_use_ray,
                      engine_args.engine_use_ray,
                      *engine_configs,
-                     distributed_init_method,
                      placement_group,
                      log_requests=not engine_args.disable_log_requests,
                      log_stats=not engine_args.disable_log_stats,
...
 import copy
+from collections import defaultdict
 import os
 import time
-from functools import partial
-from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
+from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
+                    Union)

 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
@@ -14,14 +15,12 @@ from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup,
-                           SequenceGroupMetadata, SequenceGroupOutput,
-                           SequenceOutput, SequenceStatus)
+                           SequenceGroupOutput, SequenceOutput, SequenceStatus)
 from vllm.transformers_utils.tokenizer import (detokenize_incrementally,
                                                get_tokenizer)
-from vllm.utils import Counter
+from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port

 if ray:
-    from ray.air.util.torch_dist import init_torch_dist_process_group
     from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

 if TYPE_CHECKING:
@@ -54,8 +53,6 @@ class LLMEngine:
             management.
         parallel_config: The configuration related to distributed execution.
         scheduler_config: The configuration related to the request scheduler.
-        distributed_init_method: The initialization method for distributed
-            execution. See `torch.distributed.init_process_group` for details.
         placement_group: Ray placement group for distributed execution.
             Required for distributed execution.
         log_stats: Whether to log statistics.
@@ -67,7 +64,6 @@ class LLMEngine:
         cache_config: CacheConfig,
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
-        distributed_init_method: str,
         placement_group: Optional["PlacementGroup"],
         log_stats: bool,
     ) -> None:
@@ -112,7 +108,7 @@ class LLMEngine:
             os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
             self._init_workers_ray(placement_group)
         else:
-            self._init_workers(distributed_init_method)
+            self._init_workers()

         # Profile the memory usage and initialize the cache.
         self._init_cache()
@@ -127,7 +123,7 @@ class LLMEngine:
         # List of (timestamp, num_tokens)
         self.num_generation_tokens: List[Tuple[float, int]] = []

-    def _init_workers(self, distributed_init_method: str):
+    def _init_workers(self):
         # Lazy import the Worker to avoid importing torch.cuda/xformers
         # before CUDA_VISIBLE_DEVICES is set in the Worker
         from vllm.worker.worker import Worker
@@ -136,70 +132,122 @@ class LLMEngine:
             "Ray is required if parallel_config.world_size > 1.")

         self.workers: List[Worker] = []
-        worker = Worker(
+        distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}"
+        self.driver_worker = Worker(
             self.model_config,
             self.parallel_config,
             self.scheduler_config,
-            0,
-            distributed_init_method,
-        )
-        self.workers.append(worker)
-        self._run_workers(
-            "init_model",
-            get_all_outputs=True,
-        )
-        self._run_workers(
-            "load_model",
-            get_all_outputs=True,
-            max_concurrent_workers=self.parallel_config.
-            max_parallel_loading_workers,
+            local_rank=0,
+            rank=0,
+            distributed_init_method=distributed_init_method,
+            is_driver_worker=True,
         )
+        self._run_workers("init_model")
+        self._run_workers("load_model")

     def _init_workers_ray(self, placement_group: "PlacementGroup",
                           **ray_remote_kwargs):
-        # Lazy import the Worker to avoid importing torch.cuda/xformers
-        # before CUDA_VISIBLE_DEVICES is set in the Worker
-        from vllm.worker.worker import Worker
-
-        self.workers: List[Worker] = []
-        for bundle in placement_group.bundle_specs:
-            if not bundle.get("GPU", 0):
-                continue
-            if self.parallel_config.tensor_parallel_size == 1:
-                num_gpus = self.cache_config.gpu_memory_utilization
-            else:
-                num_gpus = 1
+        if self.parallel_config.tensor_parallel_size == 1:
+            num_gpus = self.cache_config.gpu_memory_utilization
+        else:
+            num_gpus = 1
+
+        self.driver_dummy_worker: RayWorkerVllm = None
+        self.workers: List[RayWorkerVllm] = []
+
+        driver_ip = get_ip()
+        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
+            if not bundle.get("GPU", 0):
+                continue
+            scheduling_strategy = PlacementGroupSchedulingStrategy(
+                placement_group=placement_group,
+                placement_group_capture_child_tasks=True,
+                placement_group_bundle_index=bundle_id,
+            )
             worker = ray.remote(
                 num_cpus=0,
                 num_gpus=num_gpus,
-                scheduling_strategy=PlacementGroupSchedulingStrategy(
-                    placement_group=placement_group,
-                    placement_group_capture_child_tasks=True),
+                scheduling_strategy=scheduling_strategy,
                 **ray_remote_kwargs,
             )(RayWorkerVllm).remote(self.model_config.trust_remote_code)
-            self.workers.append(worker)
+
+            worker_ip = ray.get(worker.get_node_ip.remote())
+            if worker_ip == driver_ip and self.driver_dummy_worker is None:
+                # If the worker is on the same node as the driver, we use it
+                # as the resource holder for the driver process.
+                self.driver_dummy_worker = worker
+            else:
+                self.workers.append(worker)
+
+        if self.driver_dummy_worker is None:
+            raise ValueError(
+                "Ray does not allocate any GPUs on the driver node. Consider "
+                "adjusting the Ray placement group or running the driver on a "
+                "GPU node.")
+
+        driver_node_id, driver_gpu_ids = ray.get(
+            self.driver_dummy_worker.get_node_and_gpu_ids.remote())
+        worker_node_and_gpu_ids = ray.get(
+            [worker.get_node_and_gpu_ids.remote() for worker in self.workers])
+
+        node_workers = defaultdict(list)
+        node_gpus = defaultdict(list)
+
+        node_workers[driver_node_id].append(0)
+        node_gpus[driver_node_id].extend(driver_gpu_ids)
+        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids,
+                                               start=1):
+            node_workers[node_id].append(i)
+            node_gpus[node_id].extend(gpu_ids)
+        for node_id, gpu_ids in node_gpus.items():
+            node_gpus[node_id] = sorted(gpu_ids)
+
+        # Set CUDA_VISIBLE_DEVICES for the driver.
+        set_cuda_visible_devices(node_gpus[driver_node_id])
+        for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids):
+            worker.set_cuda_visible_devices.remote(node_gpus[node_id])
+
+        distributed_init_method = f"tcp://{driver_ip}:{get_open_port()}"
+
+        # Lazy import the Worker to avoid importing torch.cuda/xformers
+        # before CUDA_VISIBLE_DEVICES is set in the Worker
+        from vllm.worker.worker import Worker

         # Initialize torch distributed process group for the workers.
-        init_torch_dist_process_group(self.workers, backend="nccl")
         model_config = copy.deepcopy(self.model_config)
         parallel_config = copy.deepcopy(self.parallel_config)
         scheduler_config = copy.deepcopy(self.scheduler_config)
-        self._run_workers("init_worker",
-                          get_all_outputs=True,
-                          worker_init_fn=lambda: Worker(
-                              model_config,
-                              parallel_config,
-                              scheduler_config,
-                              None,
-                              None,
-                          ))
-        self._run_workers(
-            "init_model",
-            get_all_outputs=True,
+
+        for rank, (worker, (node_id,
+                            _)) in enumerate(zip(self.workers,
+                                                 worker_node_and_gpu_ids),
+                                             start=1):
+            local_rank = node_workers[node_id].index(rank)
+            worker.init_worker.remote(
+                lambda rank=rank, local_rank=local_rank: Worker(
+                    model_config,
+                    parallel_config,
+                    scheduler_config,
+                    local_rank,
+                    rank,
+                    distributed_init_method,
+                ))
+
+        driver_rank = 0
+        driver_local_rank = node_workers[driver_node_id].index(driver_rank)
+        self.driver_worker = Worker(
+            model_config,
+            parallel_config,
+            scheduler_config,
+            driver_local_rank,
+            driver_rank,
+            distributed_init_method,
+            is_driver_worker=True,
         )
+
+        self._run_workers("init_model")
         self._run_workers(
             "load_model",
-            get_all_outputs=True,
             max_concurrent_workers=self.parallel_config.
             max_parallel_loading_workers,
         )
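The new Ray initialization path collects (node_id, gpu_ids) from the driver's dummy worker and every Ray worker, then derives per-node GPU lists (for CUDA_VISIBLE_DEVICES) and each rank's local_rank as its position within its node. A small self-contained sketch of that bookkeeping with made-up node data:

from collections import defaultdict
from typing import Dict, List, Tuple

# Hypothetical placement: driver is rank 0; workers are ranks 1..3.
driver_node, driver_gpu_ids = "node-a", [1]
worker_node_and_gpu_ids: List[Tuple[str, List[int]]] = [
    ("node-a", [0]),
    ("node-b", [0]),
    ("node-b", [1]),
]

node_workers: Dict[str, List[int]] = defaultdict(list)  # node -> ranks
node_gpus: Dict[str, List[int]] = defaultdict(list)     # node -> GPU ids

node_workers[driver_node].append(0)
node_gpus[driver_node].extend(driver_gpu_ids)
for rank, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids, start=1):
    node_workers[node_id].append(rank)
    node_gpus[node_id].extend(gpu_ids)
for node_id, gpu_ids in node_gpus.items():
    node_gpus[node_id] = sorted(gpu_ids)

# local_rank is a rank's index within its node; node_gpus feeds
# set_cuda_visible_devices for every process on that node.
for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids, start=1):
    print(rank, node_id, node_workers[node_id].index(rank), node_gpus[node_id])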
@@ -213,7 +261,6 @@ class LLMEngine:
         # Get the maximum number of blocks that can be allocated on GPU and CPU.
         num_blocks = self._run_workers(
             "profile_num_available_blocks",
-            get_all_outputs=True,
             block_size=self.cache_config.block_size,
             gpu_memory_utilization=self.cache_config.gpu_memory_utilization,
             cpu_swap_space=self.cache_config.swap_space_bytes,
@@ -257,11 +304,9 @@ class LLMEngine:
         engine_configs = engine_args.create_engine_configs()
         parallel_config = engine_configs[2]
         # Initialize the cluster.
-        distributed_init_method, placement_group = initialize_cluster(
-            parallel_config)
+        placement_group = initialize_cluster(parallel_config)
         # Create the LLM engine.
         engine = cls(*engine_configs,
-                     distributed_init_method,
                      placement_group,
                      log_stats=not engine_args.disable_log_stats)
         return engine
@@ -328,16 +373,6 @@ class LLMEngine:
         """Returns True if there are unfinished requests."""
         return self.scheduler.has_unfinished_seqs()

-    def _schedule(
-        self
-    ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs,
-               List[RequestOutput]]:
-        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
-        return seq_group_metadata_list, scheduler_outputs, [
-            RequestOutput.from_seq_group(seq_group)
-            for seq_group in scheduler_outputs.ignored_seq_groups
-        ]
-
     def _check_beam_search_early_stopping(
         self,
         early_stopping: Union[bool, str],
@@ -586,18 +621,23 @@ class LLMEngine:
         and updates the scheduler with the model outputs. Finally, it decodes
         the sequences and returns the newly generated results.
         """
-        seq_group_metadata_list, scheduler_outputs, ignored = self._schedule()
-        if scheduler_outputs.is_empty():
-            return ignored
-
-        # Execute the model.
-        output = self._run_workers(
-            "execute_model",
-            seq_group_metadata_list=seq_group_metadata_list,
-            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
-            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
-            blocks_to_copy=scheduler_outputs.blocks_to_copy,
-        )
+        seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
+
+        if not scheduler_outputs.is_empty():
+            # Execute the model.
+            all_outputs = self._run_workers(
+                "execute_model",
+                driver_kwargs={
+                    "seq_group_metadata_list": seq_group_metadata_list,
+                    "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
+                    "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
+                    "blocks_to_copy": scheduler_outputs.blocks_to_copy,
+                })
+
+            # Only the driver worker returns the sampling results.
+            output = all_outputs[0]
+        else:
+            output = []

         return self._process_model_outputs(output, scheduler_outputs)
@@ -725,53 +765,38 @@ class LLMEngine:
                 seq.status = SequenceStatus.FINISHED_STOPPED
                 return

-    def _run_workers_in_batch(
-        self,
-        workers,
-        method: str,
-        *args,
-        **kwargs,
-    ):
-        all_outputs = []
-        for worker in workers:
-            if self.parallel_config.worker_use_ray:
-                executor = partial(worker.execute_method.remote, method)
-            else:
-                executor = getattr(worker, method)
-
-            output = executor(*args, **kwargs)
-            all_outputs.append(output)
-        if self.parallel_config.worker_use_ray:
-            all_outputs = ray.get(all_outputs)
-        return all_outputs
-
     def _run_workers(
         self,
         method: str,
         *args,
-        get_all_outputs: bool = False,
+        driver_args: Optional[List[Any]] = None,
+        driver_kwargs: Optional[Dict[str, Any]] = None,
         max_concurrent_workers: Optional[int] = None,
         **kwargs,
     ) -> Any:
         """Runs the given method on all workers."""
-        all_outputs = []
         if max_concurrent_workers:
-            work_groups = [
-                self.workers[i:i + max_concurrent_workers]
-                for i in range(0, len(self.workers), max_concurrent_workers)
-            ]
-        else:
-            work_groups = [self.workers]
-
-        for workers in work_groups:
-            all_outputs.extend(
-                self._run_workers_in_batch(workers, method, *args, **kwargs))
-
-        if get_all_outputs:
-            return all_outputs
-
-        # Make sure all workers have the same results.
-        output = all_outputs[0]
-        for other_output in all_outputs[1:]:
-            assert output == other_output
-        return output
+            raise NotImplementedError(
+                "max_concurrent_workers is not supported yet.")
+
+        # Start the ray workers first.
+        ray_worker_outputs = [
+            worker.execute_method.remote(method, *args, **kwargs)
+            for worker in self.workers
+        ]
+
+        if driver_args is None:
+            driver_args = args
+        if driver_kwargs is None:
+            driver_kwargs = kwargs
+
+        # Start the driver worker after all the ray workers.
+        driver_worker_output = getattr(self.driver_worker,
+                                       method)(*driver_args, **driver_kwargs)
+
+        # Get the results of the ray workers.
+        if self.workers:
+            ray_worker_outputs = ray.get(ray_worker_outputs)
+
+        return [driver_worker_output] + ray_worker_outputs
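The synchronous _run_workers follows the same driver-first contract: kick off the Ray RPCs, run the driver worker in-process, then ray.get the remote results and return them with the driver output at index 0. A hedged Ray sketch with toy actors standing in for RayWorkerVllm (requires a local ray install; the class and method bodies here are invented):

import ray

ray.init(ignore_reinit_error=True)


@ray.remote
class ToyWorker:
    """Stand-in for RayWorkerVllm: dispatches a named method remotely."""

    def execute_method(self, method: str, *args):
        return getattr(self, method)(*args)

    def execute_model(self, step: int) -> str:
        return f"remote worker, step {step}"


class ToyDriverWorker:
    def execute_model(self, step: int) -> str:
        return f"driver worker, step {step}"


workers = [ToyWorker.remote() for _ in range(2)]
driver_worker = ToyDriverWorker()

# Start the ray workers first, run the driver in-process, then collect.
ray_refs = [w.execute_method.remote("execute_model", 0) for w in workers]
driver_output = getattr(driver_worker, "execute_model")(0)
all_outputs = [driver_output] + ray.get(ray_refs)
print(all_outputs)  # the driver result is always all_outputs[0]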
-from typing import Optional, Tuple, TYPE_CHECKING
+from typing import Optional, List, Tuple, TYPE_CHECKING

 from vllm.config import ParallelConfig
 from vllm.logger import init_logger
-from vllm.utils import get_open_port, is_hip
+from vllm.utils import is_hip, set_cuda_visible_devices, get_ip

 logger = init_logger(__name__)

 try:
     import ray
-    from ray.air.util.torch_dist import TorchDistributedWorker

-    class RayWorkerVllm(TorchDistributedWorker):
+    class RayWorkerVllm:
         """Ray wrapper for vllm.worker.Worker, allowing Worker to be
         lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
@@ -30,12 +29,22 @@ try:
             executor = getattr(self, method)
             return executor(*args, **kwargs)

+        def get_node_ip(self) -> str:
+            return get_ip()
+
+        def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
+            node_id = ray.get_runtime_context().get_node_id()
+            gpu_ids = ray.get_gpu_ids()
+            return node_id, gpu_ids
+
+        def set_cuda_visible_devices(self, device_ids) -> None:
+            set_cuda_visible_devices(device_ids)
+
 except ImportError as e:
     logger.warning(f"Failed to import Ray with {e!r}. "
                    "For distributed inference, please install Ray with "
                    "`pip install ray pandas pyarrow`.")
     ray = None
-    TorchDistributedWorker = None
     RayWorkerVllm = None

 if TYPE_CHECKING:
@@ -75,13 +84,11 @@ def initialize_cluster(
         ray.init(address=ray_address, ignore_reinit_error=True)

     if not parallel_config.worker_use_ray:
-        # Initialize cluster locally.
-        port = get_open_port()
-        # We need to setup the distributed init method to make sure
-        # the distributed megatron code (e.g., get world size) works correctly.
-        distributed_init_method = f"tcp://localhost:{port}"
-        return distributed_init_method, None
+        assert parallel_config.world_size == 1, (
+            "Ray is required if parallel_config.world_size > 1.")
+        return None

+    # Create placement group for worker processes
     current_placement_group = ray.util.get_current_placement_group()
     if current_placement_group:
         # We are in a placement group
@@ -106,12 +113,12 @@
             "The number of required GPUs exceeds the total number of "
             "available GPUs in the cluster.")
-        # Create a new placement group
-        current_placement_group = ray.util.placement_group([{
-            "GPU": 1
-        }] * parallel_config.world_size)
+        placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size)
+        current_placement_group = ray.util.placement_group(
+            placement_group_specs)
         # Wait until PG is ready - this will block until all
         # requested resources are available, and will timeout
         # if they cannot be provisioned.
         ray.get(current_placement_group.ready(), timeout=1800)

-    return None, current_placement_group
+    return current_placement_group
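initialize_cluster now returns only the placement group (or None for single-GPU runs) instead of a (distributed_init_method, placement_group) pair; the TCP init method is built later in the engine. A minimal sketch of the one-GPU-per-bundle placement group it creates (assumes a running Ray cluster with enough GPUs; world_size is a placeholder):

import ray

ray.init(ignore_reinit_error=True)

world_size = 2  # placeholder tensor-parallel world size
placement_group_specs = [{"GPU": 1}] * world_size
current_placement_group = ray.util.placement_group(placement_group_specs)
# Blocks until every requested GPU bundle is reserved, or times out.
ray.get(current_placement_group.ready(), timeout=1800)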
@@ -12,7 +12,6 @@ from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid

 TIMEOUT_KEEP_ALIVE = 5  # seconds.
-TIMEOUT_TO_PREVENT_DEADLOCK = 1  # seconds.
 app = FastAPI()
 engine = None
...
-from typing import List, Optional
+from typing import Optional

 import torch
@@ -16,28 +16,27 @@ class InputMetadata:
     def __init__(
         self,
-        prompt_lens: List[int],
+        is_prompt: bool,
         slot_mapping: torch.Tensor,
         max_context_len: Optional[int],
         context_lens: Optional[torch.Tensor],
         block_tables: Optional[torch.Tensor],
         use_cuda_graph: bool,
     ) -> None:
-        self.prompt_lens = prompt_lens
+        self.is_prompt = is_prompt
         self.max_context_len = max_context_len
         self.slot_mapping = slot_mapping
         self.context_lens = context_lens
         self.block_tables = block_tables
         self.use_cuda_graph = use_cuda_graph
-        self.is_prompt = len(prompt_lens) > 0

         # Set during the execution of the first attention op.
         # FIXME(woosuk): This is a hack.
         self.attn_bias = None

     def __repr__(self) -> str:
         return ("InputMetadata("
-                f"prompt_lens={self.prompt_lens}, "
+                f"is_prompt={self.is_prompt}, "
                 f"max_context_len={self.max_context_len}, "
                 f"slot_mapping={self.slot_mapping}, "
                 f"context_lens={self.context_lens}, "
...
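InputMetadata now takes an explicit is_prompt flag instead of deriving it from a prompt_lens list. A hedged construction example for a decode-phase batch (toy tensor values, CPU tensors just for illustration):

import torch

from vllm.model_executor.input_metadata import InputMetadata

# Decode phase: is_prompt=False is passed explicitly rather than being
# inferred from len(prompt_lens) > 0 as before.
metadata = InputMetadata(
    is_prompt=False,
    slot_mapping=torch.zeros(4, dtype=torch.long),
    max_context_len=128,
    context_lens=torch.tensor([7, 12, 3, 9], dtype=torch.int),
    block_tables=torch.zeros((4, 8), dtype=torch.int),
    use_cuda_graph=False,
)
print(metadata)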
@@ -5,7 +5,7 @@ import torch
 import torch.nn as nn

 from vllm.model_executor.parallel_utils.communication_op import (
-    tensor_model_parallel_all_gather)
+    tensor_model_parallel_gather)
 from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
@@ -37,7 +37,7 @@ class Sampler(nn.Module):
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         # Get the hidden states that we use for sampling.
         hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
@@ -45,6 +45,14 @@ class Sampler(nn.Module):
         logits = _get_logits(hidden_states, embedding, embedding_bias,
                              self.vocab_size)

+        # Only perform sampling in the driver worker.
+        # Note: `_get_logits` is still distributed across TP workers because
+        # the `embedding` weight is distributed across TP workers.
+        # TODO(zhuohan): Change the get_logits part to a separate stage.
+        if not sampling_metadata.perform_sampling:
+            return None
+
+        assert logits is not None
         _, vocab_size = logits.shape

         # Apply logits processors (if any).
@@ -92,13 +100,14 @@
 def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
                 embedding_bias: Optional[torch.Tensor],
-                vocab_size: int) -> torch.Tensor:
+                vocab_size: int) -> Optional[torch.Tensor]:
     # Get the logits for the next tokens.
     logits = torch.matmul(hidden_states, embedding.t())
     if embedding_bias is not None:
         logits += embedding_bias
-    logits = tensor_model_parallel_all_gather(logits)
+    logits = tensor_model_parallel_gather(logits)
     # Remove paddings in vocab (if any).
-    logits = logits[:, :vocab_size]
+    if logits is not None:
+        logits = logits[:, :vocab_size]
     return logits
@@ -112,27 +121,6 @@ def _prune_hidden_states(
         sampling_metadata.selected_token_indices)

-def _get_prompt_and_output_tokens(
-    sampling_metadata: SamplingMetadata,
-) -> Tuple[List[List[int]], List[List[int]]]:
-    prompt_tokens: List[List[int]] = []
-    output_tokens: List[List[int]] = []
-    for i, seq_group in enumerate(sampling_metadata.seq_groups):
-        seq_ids, sampling_params = seq_group
-        if (i < sampling_metadata.num_prompts
-                and sampling_params.prompt_logprobs is not None):
-            # NOTE: prompt token positions do not need output tokens to
-            # compute penalties.
-            prompt_len = sampling_metadata.prompt_lens[i]
-            prompt_tokens.extend([] for _ in range(prompt_len - 1))
-            output_tokens.extend([] for _ in range(prompt_len - 1))
-        for seq_id in seq_ids:
-            seq_data = sampling_metadata.seq_data[seq_id]
-            prompt_tokens.append(seq_data.prompt_token_ids)
-            output_tokens.append(seq_data.output_token_ids)
-    return prompt_tokens, output_tokens
-
 def _get_bin_counts_and_mask(
     tokens: torch.Tensor,
     vocab_size: int,
...
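Switching from tensor_model_parallel_all_gather to tensor_model_parallel_gather means only the driver rank ends up with the full logits; other tensor-parallel ranks get None and Sampler.forward returns early for them. A toy torch.distributed illustration of gather-to-rank-0 (hypothetical standalone script, gloo backend, fixed localhost port):

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    dist.init_process_group("gloo",
                            init_method="tcp://127.0.0.1:29511",
                            rank=rank,
                            world_size=world_size)
    # Each rank holds its shard of the (toy) logits along the vocab dim.
    shard = torch.full((1, 4), float(rank))
    gather_list = ([torch.empty_like(shard) for _ in range(world_size)]
                   if rank == 0 else None)
    dist.gather(shard, gather_list, dst=0)
    # Only rank 0 has the concatenated logits; the others carry None,
    # mirroring the driver-only sampling in the new Sampler.forward.
    logits = torch.cat(gather_list, dim=-1) if rank == 0 else None
    print(rank, logits)
    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(run, args=(2, ), nprocs=2)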
@@ -298,7 +298,7 @@ class AquilaForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
...
@@ -313,7 +313,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
...
@@ -290,7 +290,7 @@ class BloomForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
...
@@ -349,7 +349,7 @@ class ChatGLMForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
...
@@ -394,7 +394,7 @@ class FalconForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
...
@@ -235,7 +235,7 @@ class GPT2LMHeadModel(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
...
@@ -254,7 +254,7 @@ class GPTBigCodeForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
                                    sampling_metadata)
         return next_tokens
...
@@ -240,7 +240,7 @@ class GPTJForCausalLM(nn.Module):
         self,
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> SamplerOutput:
+    ) -> Optional[SamplerOutput]:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                    sampling_metadata, self.lm_head.bias)
         return next_tokens
...
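Every model file in this commit changes the same way: sample is now annotated -> Optional[SamplerOutput] because the shared Sampler returns None on non-driver tensor-parallel ranks. A hedged sketch of the common shape (hypothetical toy model, not any one file verbatim):

from typing import Optional

import torch
import torch.nn as nn

from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput


class ToyCausalLM(nn.Module):
    """Illustrates the updated sample() signature shared by all models."""

    def __init__(self, vocab_size: int, hidden_size: int) -> None:
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.sampler = Sampler(vocab_size)

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        # Returns None on tensor-parallel workers that are not the driver.
        return self.sampler(self.lm_head.weight, hidden_states,
                            sampling_metadata)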