Commit ead94d93 authored by zhuwenwen

Merge remote-tracking branch 'mirror/main'

parents fcffb7c8 f780504d
@@ -6,12 +6,12 @@ import torch
from vllm._C import cache_ops

DTYPES = [torch.half, torch.bfloat16, torch.float]
-NUM_TOKENS = [83]  # Arbitrary values for testing
+NUM_TOKENS = [42]  # Arbitrary values for testing
NUM_LAYERS = [1]  # Arbitrary values for testing
NUM_HEADS = [8]  # Arbitrary values for testing
HEAD_SIZES = [64, 80, 96, 112, 128, 256]
BLOCK_SIZES = [8, 16, 32]
-NUM_BLOCKS = [1024, 36000]  # Arbitrary values for testing
+NUM_BLOCKS = [1024, 3600]  # Arbitrary values for testing
NUM_MAPPINGS = [256]  # Arbitrary values for testing
SEEDS = [0]
DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)]
...
@@ -30,6 +30,7 @@ def test_get_prompt_logprobs(
                                         temperature=0.0)
    vllm_results = vllm_model.model.generate(
        example_prompts, sampling_params=vllm_sampling_params)
+    del vllm_model
    # Test whether logprobs are included in the results.
    for result in vllm_results:
...
"""Tests for rejection sampling."""
import pytest
from typing import List, Tuple
import torch
import torch.nn.functional as F
from vllm.model_executor.utils import set_random_seed
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
def mock_causal_accepted_tensor(
k: int, last_accepted_indices: torch.Tensor) -> torch.Tensor:
"""Generate an "accepted" tensor which should yield causally-accepted tokens
up to last accepted indices.
Tokens after last_accepted_indices+1 may also be accepted, although they
will not be causally accepted.
"""
batch_size = last_accepted_indices.shape[0]
accepted = (torch.arange(k).expand(batch_size, k) <=
last_accepted_indices.unsqueeze(-1).broadcast_to(
batch_size, k)).to(device="cuda")
# Sprinkle accepted values after the contiguous initial accepted values.
# This replicates the behavior of rejection sampling, which may "accept"
# a token that cannot be accepted because of causality.
sprinkle_candidates = (
torch.arange(k).expand(batch_size, k) >
last_accepted_indices.unsqueeze(-1).broadcast_to(batch_size, k) + 1)
sprinkle = torch.rand(batch_size, k, device="cuda") > 0.5
accepted[sprinkle_candidates] = sprinkle[sprinkle_candidates]
return accepted
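# For example, with k=5 and last_accepted_indices=torch.tensor([1]), the
# deterministic prefix is [True, True, False, False, False]; positions 3 and 4
# (indices greater than last_accepted_indices + 1) may then be randomly flipped
# to True by the "sprinkle" step, while position 2 always stays False.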
@pytest.mark.parametrize("seed", list(range(10)))
@pytest.mark.parametrize(
"which_tokens_accepted",
["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"])
@torch.inference_mode()
def test_correct_output_format(which_tokens_accepted: str, seed: int):
"""Verify the output has correct format given predetermined accepted matrix.
"""
set_random_seed(seed)
batch_size = 10
k = 5
vocab_size = 3000
if which_tokens_accepted == "all_tokens_accepted":
accepted = mock_causal_accepted_tensor(
k, -1 + k * torch.ones((batch_size, ), dtype=torch.long))
elif which_tokens_accepted == "no_tokens_accepted":
accepted = mock_causal_accepted_tensor(
k, -torch.ones((batch_size, ), dtype=torch.long))
elif which_tokens_accepted == "some_tokens_accepted":
last_accepted_indices = torch.randint(low=-1,
high=k,
size=(batch_size, ))
accepted = mock_causal_accepted_tensor(k, last_accepted_indices)
else:
raise AssertionError()
recovered_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device="cuda")
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device="cuda")
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64,
device="cuda")
rejection_sampler = RejectionSampler()
rejection_sampler.init_gpu_tensors(rank=0)
output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access
accepted,
recovered_token_ids,
draft_token_ids,
bonus_token_ids,
)
if which_tokens_accepted == "all_tokens_accepted":
# Expect all tokens to be equal to draft tokens.
assert torch.equal(output_token_ids[:, :-1], draft_token_ids)
# Expect all bonus tokens to be included.
assert torch.equal(output_token_ids[:, -1:], bonus_token_ids)
elif which_tokens_accepted == "no_tokens_accepted":
# Expect first token to be equal to recovered tokens.
assert torch.equal(output_token_ids[:, 0], recovered_token_ids[:, 0])
# Expect everything else to be -1.
assert torch.equal(output_token_ids[:, 1:],
torch.ones_like(output_token_ids[:, 1:]) * -1)
elif which_tokens_accepted == "some_tokens_accepted":
recovered_plus_bonus = torch.cat(
(recovered_token_ids, bonus_token_ids), dim=-1)
# Assert first rejected token is a recovered token or bonus token.
assert torch.equal(
recovered_plus_bonus[torch.arange(0, batch_size),
last_accepted_indices + 1],
output_token_ids[torch.arange(0, batch_size),
last_accepted_indices + 1])
# Assert every subsequent token is -1.
subsequent_mask = torch.arange(0, k + 1).expand(
batch_size, k + 1) >= (last_accepted_indices + 2).unsqueeze(-1)
assert torch.all(output_token_ids[subsequent_mask] == -1)
@pytest.mark.parametrize("k", list(range(1, 6)))
@pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", list(range(1, 32)))
@torch.inference_mode()
def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int):
rejection_sampler = RejectionSampler()
rejection_sampler.init_gpu_tensors(rank=0)
draft_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device="cuda")
target_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device="cuda")
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64,
device="cuda")
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device="cuda")
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids)
@pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"])
@pytest.mark.parametrize("which_token_ids",
["bonus_token_ids", "draft_token_ids"])
@torch.inference_mode()
def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
which_token_ids: str):
k = 3
batch_size = 5
vocab_size = 30_000
rejection_sampler = RejectionSampler(strict_mode=True)
rejection_sampler.init_gpu_tensors(rank=0)
draft_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device="cuda")
target_probs = torch.rand(batch_size,
k,
vocab_size,
dtype=torch.float32,
device="cuda")
bonus_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, 1),
dtype=torch.int64,
device="cuda")
draft_token_ids = torch.randint(low=0,
high=vocab_size,
size=(batch_size, k),
dtype=torch.int64,
device="cuda")
oob_token_ids = None
if which_token_ids == "bonus_token_ids":
oob_token_ids = bonus_token_ids
elif which_token_ids == "draft_token_ids":
oob_token_ids = draft_token_ids
else:
raise AssertionError()
if above_or_below_vocab_range == "above":
rogue_token_id = vocab_size + 1
elif above_or_below_vocab_range == "below":
rogue_token_id = -1
else:
raise AssertionError()
oob_token_ids[0][0] = rogue_token_id
with pytest.raises(AssertionError):
rejection_sampler(target_probs, bonus_token_ids, draft_probs,
draft_token_ids)
@pytest.mark.parametrize("draft_and_target_probs_equal", [True, False])
@pytest.mark.parametrize("seed", list(range(5)))
@torch.inference_mode()
def test_rejection_sampling_approximates_target_distribution(
seed: int, draft_and_target_probs_equal: bool):
"""Verify rejection sampling approximates target distribution,
despite sampling from a potentially distinct draft distribution.
This is done by first creating a random target probability
distribution and a random draft probability distribution. We then
sample token ids from the rejection sampler using these draft
and target distributions. The samples are used to estimate
the output probability distribution, which we expect to approximate
the target distribution.
A basic distance metric is used to determine similarity between
distributions.
We expect that as we increase the number of samples,
the distance between the observed distribution and the target
distribution decreases. To measure this, we compare the distance
of the observed distribution against both the target distribution
and a uniform random distribution. We expect the distance between
the observed distribution and the target distribution to improve
much more than the distance improvement between the observed
distribution and the random distribution.
When draft_and_target_probs_equal=True, the draft and target
probabilities are exactly equal. Rejection sampling should
still work without any NaNs or exceptions.
"""
set_random_seed(seed)
helper = _CorrectnessTestHelper(
vocab_size=10,
rejection_sampler=RejectionSampler(),
)
draft_probs, target_probs, reference_probs = helper.generate_probs_for_test(
draft_and_target_probs_equal)
sample_sizes = [10, 100, 1_000, 10_000, 100_000]
distance_wrt_reference = []
distance_wrt_target = []
for num_samples in sample_sizes:
(reference_vs_rejsample_dist,
target_vs_rejsample_dist) = helper.run_and_compare_distributions(
draft_probs,
target_probs,
reference_probs,
num_samples,
)
distance_wrt_reference.append(reference_vs_rejsample_dist)
distance_wrt_target.append(target_vs_rejsample_dist)
relative_change_in_distance_wrt_target = get_ratio_first_to_last(
distance_wrt_target)
relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
distance_wrt_reference)
print(f"{num_samples=} {target_vs_rejsample_dist=:.05f} "
f"{reference_vs_rejsample_dist=:.05f}")
print(f"{num_samples=} {relative_change_in_distance_wrt_target=:.02f} "
f"{relative_change_in_distance_wrt_reference=:.02f}")
relative_change_in_distance_wrt_target = get_ratio_first_to_last(
distance_wrt_target)
relative_change_in_distance_wrt_reference = get_ratio_first_to_last(
distance_wrt_reference)
expected_improvement_multiplier = 20
assert (relative_change_in_distance_wrt_target >
relative_change_in_distance_wrt_reference *
expected_improvement_multiplier)
def get_ratio_first_to_last(elements: List[float]) -> float:
return elements[0] / elements[-1]
class _CorrectnessTestHelper:
"""Class that packages together logic required for the unit-level
rejection sampling correctness test.
"""
def __init__(self, vocab_size: int, rejection_sampler: RejectionSampler):
self.rejection_sampler = rejection_sampler
self.vocab_size = vocab_size
self.vocab_range = (0, vocab_size)
self.rejection_sampler.init_gpu_tensors(rank=0)
# Keep test simple, use k=1
self.k = 1
# Bonus tokens not used, but rejection sampler requires
# correct shape.
self.num_bonus_tokens = 1
def generate_probs_for_test(
self, draft_and_target_probs_equal: bool
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
draft_probs, target_probs = [
F.softmax(
torch.rand(self.vocab_size, dtype=torch.float32),
dim=-1,
) for _ in range(2)
]
num_reference_probs = 100
reference_probs = F.softmax(
torch.rand(num_reference_probs,
self.vocab_size,
dtype=torch.float32),
dim=-1,
)
if draft_and_target_probs_equal:
target_probs = draft_probs.clone()
return draft_probs, target_probs, reference_probs
def run_and_compare_distributions(self, draft_probs: torch.Tensor,
target_probs: torch.Tensor,
reference_probs: torch.Tensor,
num_samples: int) -> Tuple[float, float]:
# Sample using rejection sampling.
rej_sample_probs = self._estimate_rejection_sampling_pdf(
draft_probs, target_probs, num_samples)
# Average distance from reference probs.
reference_vs_rejsample_dist = torch.dist(
reference_probs,
rej_sample_probs).item() / reference_probs.shape[0]
target_vs_rejsample_dist = torch.dist(target_probs,
rej_sample_probs).item()
return reference_vs_rejsample_dist, target_vs_rejsample_dist
def _estimate_rejection_sampling_pdf(
self,
draft_probs: torch.Tensor,
target_probs: torch.Tensor,
num_samples: int,
) -> torch.Tensor:
# Repeat draft probs num_samples times.
draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat(
num_samples, 1, 1)
# Repeat target probs num_samples * k times.
# Rejection sampler requires bonus token probs, but they aren't used.
target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat(
num_samples, self.k, 1)
# Randomly sample draft token ids from draft probs.
draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
num_samples=1,
replacement=True).reshape(
num_samples, self.k)
# Bonus tokens not used but required.
bonus_token_ids = torch.zeros((1, self.num_bonus_tokens),
dtype=torch.int64,
device="cuda").repeat(num_samples, 1)
# Get output tokens via rejection sampling.
output_token_ids = self.rejection_sampler(target_probs.to("cuda"),
bonus_token_ids.to("cuda"),
draft_probs.to("cuda"),
draft_token_ids.to("cuda"))
# Remove bonus tokens
output_token_ids = output_token_ids[:, :-1].flatten()
# Estimate probability density function
hist = torch.histogram(output_token_ids.to(dtype=torch.float,
device="cpu"),
bins=self.vocab_size,
range=self.vocab_range,
density=True)
return hist.hist
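# For example, torch.histogram(torch.tensor([0., 0., 1., 2.]), bins=3,
# range=(0., 3.), density=True).hist returns tensor([0.5000, 0.2500, 0.2500]):
# counts [2, 1, 1] divided by (num_samples * bin_width) = (4 * 1.0), i.e. an
# empirical probability density over the vocab-sized bins.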
@@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest
import torch
+from transformers import GenerationConfig, GenerationMixin
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.utils import set_random_seed
@@ -74,6 +75,8 @@ def test_sampler_all_greedy(seed: int):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == expected[i].item()
+    del model_runner

@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_random(seed: int):
@@ -110,6 +113,8 @@ def test_sampler_all_random(seed: int):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token == i
+    del model_runner

@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_all_beam(seed: int):
@@ -143,6 +148,7 @@ def test_sampler_all_beam(seed: int):
    # the outputs are expected - in other words, this just tests
    # whether there are no exceptions in the sampler
    # when handling an all-beam search case.
+    del model_runner

@pytest.mark.parametrize("seed", RANDOM_SEEDS)
@@ -197,6 +203,8 @@ def test_sampler_mixed(seed: int):
        for nth_output in sequence_output.samples:
            assert nth_output.output_token in expected_tokens
+    del model_runner

@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_logits_processors(seed: int):
@@ -233,3 +241,69 @@ def test_sampler_logits_processors(seed: int):
    for _, sequence_output in enumerate(sampler_output):
        for idx, nth_output in enumerate(sequence_output.samples):
            assert nth_output.output_token == idx
+    del model_runner
@pytest.mark.parametrize("seed", RANDOM_SEEDS)
def test_sampler_top_k_top_p(seed: int):
set_random_seed(seed)
batch_size = random.randint(1, 256)
top_k = random.randint(100, 500)
top_p = random.random() * 0.1
vocab_size = 32000
input_tensor = torch.rand((batch_size, 1024),
device="cuda",
dtype=torch.float16)
fake_logits = torch.normal(0,
5,
size=(batch_size, vocab_size),
device=input_tensor.device,
dtype=input_tensor.dtype)
sampler = MockLogitsSampler(32000, fake_logits)
model_runner = ModelRunner(None, None, None)
generation_model = GenerationMixin()
generation_config = GenerationConfig(top_k=top_k,
top_p=top_p,
do_sample=True)
warpers = generation_model._get_logits_warper(generation_config)
assert len(warpers) == 2 # top_p and top_k
seq_group_metadata_list = []
prompt_lens = []
for i in range(batch_size):
seq_group_metadata_list.append(
SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
seq_data={0: SequenceData([1, 2, 3])},
sampling_params=SamplingParams(
temperature=1,
top_k=top_k,
top_p=top_p,
),
block_tables={0: [1]},
))
prompt_lens.append(seq_group_metadata_list[-1].seq_data[0].get_len())
sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list,
prompt_lens)
sample_probs = None
def mock_sample(probs, logprobs, sampling_metadata):
nonlocal sample_probs
sample_probs = probs
return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]
with patch("vllm.model_executor.layers.sampler._sample", mock_sample):
sampler(embedding=None,
hidden_states=input_tensor,
sampling_metadata=sampling_metadata)
hf_probs = warpers(torch.zeros_like(fake_logits), fake_logits.clone())
hf_probs = torch.softmax(hf_probs, dim=-1, dtype=torch.float)
assert torch.allclose(hf_probs, sample_probs, atol=1e-5)
assert torch.equal(hf_probs.eq(0), sample_probs.eq(0))
del model_runner
-from typing import List
+from collections import deque
+from typing import Deque

from vllm.sequence import SequenceGroup
@@ -15,13 +16,14 @@ class Policy:
    def sort_by_priority(
        self,
        now: float,
-        seq_groups: List[SequenceGroup],
-    ) -> List[SequenceGroup]:
-        return sorted(
+        seq_groups: Deque[SequenceGroup],
+    ) -> Deque[SequenceGroup]:
+        return deque(
+            sorted(
                seq_groups,
                key=lambda seq_group: self.get_priority(now, seq_group),
                reverse=True,
-        )
+            ))

class FCFS(Policy):
...
+from collections import deque
import enum
import time
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union

from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.block_manager import AllocStatus, BlockSpaceManager
@@ -29,7 +30,7 @@ class SchedulerOutputs:
    def __init__(
        self,
-        scheduled_seq_groups: List[SequenceGroup],
+        scheduled_seq_groups: Iterable[SequenceGroup],
        prompt_run: bool,
        num_batched_tokens: int,
        blocks_to_swap_in: Dict[int, int],
@@ -75,38 +76,52 @@ class Scheduler:
            num_cpu_blocks=self.cache_config.num_cpu_blocks,
            sliding_window=self.cache_config.sliding_window)

-        # TODO(zhuohan): Use deque instead of list for better performance.
        # Sequence groups in the WAITING state.
-        self.waiting: List[SequenceGroup] = []
+        self.waiting: Deque[SequenceGroup] = deque()
        # Sequence groups in the RUNNING state.
-        self.running: List[SequenceGroup] = []
+        self.running: Deque[SequenceGroup] = deque()
        # Sequence groups in the SWAPPED state.
-        self.swapped: List[SequenceGroup] = []
+        self.swapped: Deque[SequenceGroup] = deque()

    def add_seq_group(self, seq_group: SequenceGroup) -> None:
        # Add sequence groups to the waiting queue.
        self.waiting.append(seq_group)
    def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
+        """Aborts a sequence group with the given ID.
+
+        Check if the sequence group with the given ID
+        is present in any of the state queues.
+        If present, remove the sequence group from the state queue.
+        Also, if any of the sequences in the sequence group is not finished,
+        free the sequence with status `FINISHED_ABORTED`.
+        Otherwise, do nothing.
+
+        Args:
+            request_id: The ID(s) of the sequence group to abort.
+        """
        if isinstance(request_id, str):
            request_id = (request_id, )
        request_ids = set(request_id)
        for state_queue in [self.waiting, self.running, self.swapped]:
-            # We need to reverse the list as we are removing elements
-            # from it as we iterate over it. If we don't do it,
-            # indices will get messed up and we will skip over elements.
-            for seq_group in reversed(state_queue):
+            aborted_groups = []
+            for seq_group in state_queue:
+                if not request_ids:
+                    # Using 'break' here may add two extra iterations,
+                    # but is acceptable to reduce complexity.
+                    break
                if seq_group.request_id in request_ids:
+                    # Appending aborted group into pending list.
+                    aborted_groups.append(seq_group)
+                    request_ids.remove(seq_group.request_id)
+            for aborted_group in aborted_groups:
                # Remove the sequence group from the state queue.
-                state_queue.remove(seq_group)
+                state_queue.remove(aborted_group)
                for seq in seq_group.get_seqs():
                    if seq.is_finished():
                        continue
                    seq.status = SequenceStatus.FINISHED_ABORTED
                    self.free_seq(seq)
-                    request_ids.remove(seq_group.request_id)
-                    if not request_ids:
-                        return

    def has_unfinished_seqs(self) -> bool:
        return self.waiting or self.running or self.swapped
@@ -152,7 +167,7 @@ class Scheduler:
                for seq in waiting_seqs:
                    seq.status = SequenceStatus.FINISHED_IGNORED
                ignored_seq_groups.append(seq_group)
-                self.waiting.pop(0)
+                self.waiting.popleft()
                continue
            # If the sequence group cannot be allocated, stop.
@@ -166,7 +181,7 @@ class Scheduler:
                for seq in waiting_seqs:
                    seq.status = SequenceStatus.FINISHED_IGNORED
                ignored_seq_groups.append(seq_group)
-                self.waiting.pop(0)
+                self.waiting.popleft()
                continue
            # If the number of batched tokens exceeds the limit, stop.
@@ -188,7 +203,7 @@ class Scheduler:
                    break
                seq_lens = new_seq_lens
-            seq_group = self.waiting.pop(0)
+            seq_group = self.waiting.popleft()
            self._allocate(seq_group)
            self.running.append(seq_group)
            num_curr_seqs += num_new_seqs
@@ -214,14 +229,14 @@ class Scheduler:
        self.running = self.policy.sort_by_priority(now, self.running)
        # Reserve new token slots for the running sequence groups.
-        running: List[SequenceGroup] = []
+        running: Deque[SequenceGroup] = deque()
        preempted: List[SequenceGroup] = []
        while self.running:
-            seq_group = self.running.pop(0)
+            seq_group = self.running.popleft()
            while not self.block_manager.can_append_slot(seq_group):
                if self.running:
                    # Preempt the lowest-priority sequence groups.
-                    victim_seq_group = self.running.pop(-1)
+                    victim_seq_group = self.running.pop()
                    self._preempt(victim_seq_group, blocks_to_swap_out)
                    preempted.append(victim_seq_group)
                else:
@@ -255,7 +270,7 @@ class Scheduler:
                    self.scheduler_config.max_num_seqs):
                break
-            seq_group = self.swapped.pop(0)
+            seq_group = self.swapped.popleft()
            self._swap_in(seq_group, blocks_to_swap_in)
            self._append_slot(seq_group, blocks_to_copy)
            num_curr_seqs += num_new_seqs
@@ -376,7 +391,7 @@ class Scheduler:
            self.block_manager.free(seq)
        # NOTE: For FCFS, we insert the preempted sequence group to the front
        # of the waiting queue.
-        self.waiting.insert(0, seq_group)
+        self.waiting.appendleft(seq_group)

    def _preempt_by_swap(
        self,
...
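# A minimal sketch of why the scheduler queues above moved from list to deque
# (assumes nothing beyond the Python standard library): popping or inserting at
# the left end of a list is O(n), while deque supports both in O(1).
from collections import deque

waiting = deque(["req-0", "req-1", "req-2"])  # FIFO waiting queue
first = waiting.popleft()                     # O(1); list.pop(0) is O(n)
waiting.appendleft(first)                     # O(1); list.insert(0, x) is O(n)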
@@ -253,7 +253,8 @@ class AsyncLLMEngine:
        log_requests: Whether to log the requests.
        start_engine_loop: If True, the background task to run the engine
            will be automatically started in the generate call.
-        *args, *kwargs: Arguments for LLMEngine.
+        *args: Arguments for LLMEngine.
+        *kwargs: Arguments for LLMEngine.
    """
    _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
@@ -428,6 +429,49 @@ class AsyncLLMEngine:
        Yields:
            The output `RequestOutput` objects from the LLMEngine for the
            request.
Details:
- If the engine is not running, start the background loop,
which iteratively invokes
:meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
to process the waiting requests.
- Add the request to the engine's `RequestTracker`.
On the next background loop, this request will be sent to
the underlying engine.
Also, a corresponding `AsyncStream` will be created.
- Wait for the request outputs from `AsyncStream` and yield them.
Example:
>>> # Please refer to entrypoints/api_server.py for
>>> # the complete example.
>>>
>>> # initialize the engine and the example input
>>> engine = AsyncLLMEngine.from_engine_args(engine_args)
>>> example_input = {
>>> "prompt": "What is LLM?",
>>> "stream": False, # assume the non-streaming case
>>> "temperature": 0.0,
>>> "request_id": 0,
>>> }
>>>
>>> # start the generation
>>> results_generator = engine.generate(
>>> example_input["prompt"],
>>> SamplingParams(temperature=example_input["temperature"]),
>>> example_input["request_id"])
>>>
>>> # get the results
>>> final_output = None
>>> async for request_output in results_generator:
>>> if await request.is_disconnected():
>>> # Abort the request if the client disconnects.
>>> await engine.abort(request_id)
>>> # Return or raise an error
>>> ...
>>> final_output = request_output
>>>
>>> # Process and return the final output
>>> ...
""" """
# Preprocess the request. # Preprocess the request.
# This should not be used for logging, as it is monotonic time. # This should not be used for logging, as it is monotonic time.
...@@ -506,3 +550,9 @@ class AsyncLLMEngine: ...@@ -506,3 +550,9 @@ class AsyncLLMEngine:
max_log_len=engine_args.max_log_len, max_log_len=engine_args.max_log_len,
start_engine_loop=start_engine_loop) start_engine_loop=start_engine_loop)
return engine return engine
async def do_log_stats(self) -> None:
if self.engine_use_ray:
await self.engine.do_log_stats.remote()
else:
self.engine.do_log_stats()
@@ -257,7 +257,26 @@ class LLMEngine:
        self.cache_config.verify_with_parallel_config(self.parallel_config)

    def _init_cache(self) -> None:
-        """Profiles the memory usage and initializes the KV cache."""
+        """Profiles the memory usage and initializes the KV cache.
+
+        The engine first profiles the existing memory usage.
+        Then, it calculates the maximum possible number of GPU and CPU blocks
+        that can be allocated with the remaining free memory.
+        More details can be found in the
+        :meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method
+        from class :class:`~vllm.worker.Worker`.
+
+        Afterwards, as there may be multiple workers,
+        we take the minimum number of blocks across all workers
+        to ensure this can be applied to all of them.
+
+        Finally, the engine will initialize the KV cache
+        with the calculated number of blocks.
+
+        .. tip::
+            You may limit the usage of GPU memory
+            by adjusting the `gpu_memory_utilization` parameter.
+        """
        # Get the maximum number of blocks that can be allocated on GPU and CPU.
        num_blocks = self._run_workers(
            "profile_num_available_blocks",
@@ -334,6 +353,30 @@ class LLMEngine:
                use the tokenizer to convert the prompts to token IDs.
            arrival_time: The arrival time of the request. If None, we use
                the current monotonic time.
Details:
- Set arrival_time to the current time if it is None.
- Set prompt_token_ids to the encoded prompt if it is None.
- Create `best_of` number of :class:`~vllm.Sequence` objects.
- Create a :class:`~vllm.SequenceGroup` object
from the list of :class:`~vllm.Sequence`.
- Add the :class:`~vllm.SequenceGroup` object to the scheduler.
Example:
>>> # initialize engine
>>> engine = LLMEngine.from_engine_args(engine_args)
>>> # set request arguments
>>> example_prompt = "Who is the president of the United States?"
>>> sampling_params = SamplingParams(temperature=0.0)
>>> request_id = 0
>>>
>>> # add the request to the engine
>>> engine.add_request(
>>> str(request_id),
>>> example_prompt,
>>> SamplingParams(temperature=0.0))
>>> # continue the request processing
>>> ...
""" """
if arrival_time is None: if arrival_time is None:
arrival_time = time.monotonic() arrival_time = time.monotonic()
...@@ -358,6 +401,17 @@ class LLMEngine: ...@@ -358,6 +401,17 @@ class LLMEngine:
Args: Args:
request_id: The ID(s) of the request to abort. request_id: The ID(s) of the request to abort.
Details:
- Refer to the
:meth:`~vllm.core.scheduler.Scheduler.abort_seq_group`
from class :class:`~vllm.core.scheduler.Scheduler`.
Example:
>>> # initialize engine and add a request with request_id
>>> request_id = str(0)
>>> # abort the request
>>> engine.abort_request(request_id)
""" """
self.scheduler.abort_seq_group(request_id) self.scheduler.abort_seq_group(request_id)
...@@ -601,8 +655,10 @@ class LLMEngine: ...@@ -601,8 +655,10 @@ class LLMEngine:
# Create the outputs. # Create the outputs.
request_outputs: List[RequestOutput] = [] request_outputs: List[RequestOutput] = []
for seq_group in (scheduled_seq_groups + for seq_group in scheduled_seq_groups:
scheduler_outputs.ignored_seq_groups): request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output)
for seq_group in scheduler_outputs.ignored_seq_groups:
request_output = RequestOutput.from_seq_group(seq_group) request_output = RequestOutput.from_seq_group(seq_group)
request_outputs.append(request_output) request_outputs.append(request_output)
...@@ -615,11 +671,53 @@ class LLMEngine: ...@@ -615,11 +671,53 @@ class LLMEngine:
def step(self) -> List[RequestOutput]: def step(self) -> List[RequestOutput]:
"""Performs one decoding iteration and returns newly generated results. """Performs one decoding iteration and returns newly generated results.
This function performs one decoding iteration of the engine. It first .. figure:: https://i.imgur.com/sv2HssD.png
schedules the sequences to be executed in the next iteration and the :alt: Overview of the step function
token blocks to be swapped in/out/copy. Then, it executes the model :align: center
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results. Overview of the step function.
Details:
- Step 1: Schedules the sequences to be executed in the next
iteration and the token blocks to be swapped in/out/copy.
- Depending on the scheduling policy,
sequences may be `preempted/reordered`.
- A Sequence Group (SG) refers to a group of sequences
that are generated from the same prompt.
- Step 2: Calls the workers to execute the model.
- Step 3: Processes the model output. This mainly includes:
- Decodes the relevant outputs.
- Updates the scheduled sequence groups with model outputs
based on its `sampling parameters` (`use_beam_search` or not).
- Frees the finished sequence groups.
- Finally, it creates and returns the newly generated results.
Example:
>>> # Please see the example/ folder for more detailed examples.
>>>
>>> # initialize engine and request arguments
>>> engine = LLMEngine.from_engine_args(engine_args)
>>> example_inputs = [(0, "What is LLM?",
>>> SamplingParams(temperature=0.0))]
>>>
>>> # Start the engine with an event loop
>>> while True:
>>> if example_inputs:
>>> req_id, prompt, sampling_params = example_inputs.pop(0)
>>> engine.add_request(str(req_id), prompt, sampling_params)
>>>
>>> # continue the request processing
>>> request_outputs = engine.step()
>>> for request_output in request_outputs:
>>> if request_output.finished:
>>> # return or show the request output
>>>
>>> if not (engine.has_unfinished_requests() or example_inputs):
>>> break
""" """
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
...@@ -641,6 +739,9 @@ class LLMEngine: ...@@ -641,6 +739,9 @@ class LLMEngine:
return self._process_model_outputs(output, scheduler_outputs) return self._process_model_outputs(output, scheduler_outputs)
def do_log_stats(self) -> None:
self._log_system_stats(False, 0)
    def _log_system_stats(
        self,
        prompt_run: bool,
...
@@ -55,7 +55,7 @@ def initialize_cluster(
    parallel_config: ParallelConfig,
    engine_use_ray: bool = False,
    ray_address: Optional[str] = None,
-) -> Tuple[str, Optional["PlacementGroup"]]:
+) -> Optional["PlacementGroup"]:
    """Initialize the distributed cluster probably with Ray.
    Args:
...
@@ -74,12 +74,18 @@ if __name__ == "__main__":
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--ssl-keyfile", type=str, default=None)
    parser.add_argument("--ssl-certfile", type=str, default=None)
+    parser.add_argument(
+        "--root-path",
+        type=str,
+        default=None,
+        help="FastAPI root_path when app is behind a path based routing proxy")
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(engine_args)
+    app.root_path = args.root_path
    uvicorn.run(app,
                host=args.host,
                port=args.port,
...
@@ -6,6 +6,7 @@ import asyncio
import codecs
import json
import time
+from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Union
@@ -38,11 +39,28 @@ TIMEOUT_KEEP_ALIVE = 5  # seconds
logger = init_logger(__name__)

served_model = None
-app = fastapi.FastAPI()
+engine_args = None
engine = None
response_role = None
@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
async def _force_log():
while True:
await asyncio.sleep(10)
await engine.do_log_stats()
if not engine_args.disable_log_stats:
asyncio.create_task(_force_log())
yield
app = fastapi.FastAPI(lifespan=lifespan)
def parse_args():
    parser = argparse.ArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server.")
@@ -88,6 +106,11 @@ def parse_args():
                        type=str,
                        default=None,
                        help="The file path to the SSL cert file")
parser.add_argument(
"--root-path",
type=str,
default=None,
help="FastAPI root_path when app is behind a path based routing proxy")
    parser = AsyncEngineArgs.add_cli_args(parser)
    return parser.parse_args()
@@ -748,6 +771,7 @@ if __name__ == "__main__":
    # Register labels for metrics
    add_global_metrics_labels(model_name=engine_args.model)
+    app.root_path = args.root_path
    uvicorn.run(app,
                host=args.host,
                port=args.port,
...
@@ -156,7 +156,6 @@ class PagedAttention(nn.Module):
            output = out.view_as(query)
        else:
            # Decoding run.
-            if key_cache is not None and value_cache is not None:
            output = _paged_attention(
                query,
                key_cache,
@@ -166,10 +165,6 @@ class PagedAttention(nn.Module):
                self.scale,
                self.alibi_slopes,
            )
-            else:
-                # This happens during the initial memory profiling run for
-                # CUDA graphs.
-                output = torch.zeros_like(query)

        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)
...
@@ -423,6 +423,9 @@ class QKVParallelLinear(ColumnParallelLinear):
                shard_offset = shard_offset // param.pack_factor
            param_data = param_data.narrow(output_dim, shard_offset,
                                           shard_size)
+            if loaded_shard_id == "q":
+                shard_id = tp_rank
+            else:
                shard_id = tp_rank // self.num_kv_head_replicas
            start_idx = shard_id * shard_size
            loaded_weight = loaded_weight.narrow(output_dim, start_idx,
...
from typing import Tuple, Optional
from functools import cached_property
import torch
import torch.nn as nn
import torch.jit
class RejectionSampler(nn.Module):
"""Apply modified rejection sampling as described in "Accelerating Large
Language Model Decoding with Speculative Sampling"
https://arxiv.org/pdf/2302.01318.pdf.
"""
def __init__(self, strict_mode: bool = False):
"""Create a rejection sampler.
Args:
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
"""
super().__init__()
self.probs_dtype = torch.float32
self.token_id_dtype = torch.int64
self._strict_mode = strict_mode
# NOTE: A "bonus token" is accepted iff all proposal tokens are
# accepted. There is always only one possible bonus token. We store this
# value in a variable for readability.
self._num_bonus_tokens = 1
self.num_accepted_tokens: Optional[torch.Tensor] = None
self.num_emitted_tokens: Optional[torch.Tensor] = None
self.num_draft_tokens: int = 0
def init_gpu_tensors(self, rank: int) -> None:
assert self.num_accepted_tokens is None
device = f"cuda:{rank}"
self.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)
self.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)
def forward(
self,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> torch.Tensor:
"""Sample token ids using rejection sampling. This accepts or rejects
tokens proposed by the draft model using the probability of each token
according to the draft and target models.
In the worst case where all draft tokens are rejected, it is guaranteed
one correct token will be emitted.
In the case where all draft tokens are accepted, a bonus token will be
accepted as it's cheap to have the target model score this speculative
sequence.
Args:
target_probs: The probability distribution over token ids given
context according to the target model.
shape = [batch_size, num_speculative_tokens, vocab_size]
bonus_token_ids: The "bonus" token ids that are accepted iff all
speculative tokens in a sequence are accepted.
shape = [batch_size, num_bonus_tokens]
draft_probs: The probability distribution over token ids given
context according to the draft model.
shape = [batch_size, num_speculative_tokens, vocab_size]
draft_token_ids: The token ids that were sampled from the draft
probabilities.
shape = [batch_size, num_speculative_tokens]
Returns:
output_token_ids: The token ids sampled via rejection sampling,
or -1 if unable to sample a token because the previous token
was rejected.
shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
"""
# Only perform shape/dtype/device checking in strict mode, as it adds
# overhead.
if self._strict_mode:
self._raise_if_incorrect_shape(target_probs, bonus_token_ids,
draft_probs, draft_token_ids)
self._raise_if_incorrect_dtype(target_probs, bonus_token_ids,
draft_probs, draft_token_ids)
self._raise_if_inconsistent_device(target_probs, bonus_token_ids,
draft_probs, draft_token_ids)
self._raise_if_out_of_bounds_vocab(target_probs.shape[-1],
bonus_token_ids,
draft_token_ids)
accepted, recovered_token_ids = self._batch_modified_rejection_sampling(
target_probs,
draft_probs,
draft_token_ids,
)
output_token_ids = self._create_output(
accepted,
recovered_token_ids,
draft_token_ids,
bonus_token_ids,
)
return output_token_ids
def _batch_modified_rejection_sampling(
self,
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Perform modified rejection sampling on each sequence.
Returns:
A tuple of two tensors:
0: A bool tensor of which tokens in each sequence is accepted.
shape = [batch_size, k]
1: Token ids sampled from a recovered distribution, to be used
when a token is rejected.
shape = [batch_size, k]
"""
batch_size, k, vocab_size = draft_probs.shape
# shape [batch_size, k]
accepted = self._get_accepted(target_probs, draft_probs,
draft_token_ids)
recovered_probs = self._get_recovered_probs(
target_probs, draft_probs).reshape(batch_size * k, vocab_size)
recovered_token_ids = _multinomial(recovered_probs,
num_samples=1).reshape(
batch_size, k)
return accepted, recovered_token_ids
def _get_accepted(
self,
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
) -> torch.Tensor:
r"""Create bool matrix over the proposed draft tokens. If
True, then a token can be accepted, else it should be
rejected.
Given :math:`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of
:math:`\hat{x}_{n+1}` given context :math:`x_1, \dots, x_n` according
to the target model, and :math:`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the
same conditional probability according to the draft model, the token
is accepted with probability:
.. math::
\min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
{p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)
This implementation does not apply causality. When using the output,
if a token is rejected, subsequent tokens should not be used.
Returns a bool tensor of shape [batch_size, k] specifying which tokens
are accepted.
"""
batch_size, k, _ = draft_probs.shape
batch_indices = torch.arange(batch_size,
device=target_probs.device)[:, None]
probs_indicies = torch.arange(k, device=target_probs.device)
# shape [batch_size, k]
selected_draft_probs = draft_probs[batch_indices, probs_indicies,
draft_token_ids]
# shape [batch_size, k]
selected_target_probs = target_probs[batch_indices, probs_indicies,
draft_token_ids]
uniform_rand = torch.rand(batch_size,
k,
dtype=self.probs_dtype,
device=target_probs.device)
capped_ratio = torch.minimum(
selected_target_probs / selected_draft_probs,
torch.full((1, ), 1, device=target_probs.device))
accepted = uniform_rand < capped_ratio
return accepted
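# For example, if the draft probability of the sampled token is 0.2 and the
# target probability is 0.1, the token is accepted with probability
# min(1, 0.1 / 0.2) = 0.5; if the target probability is 0.3, the capped ratio
# is 1 and the token is always accepted.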
def _get_recovered_probs(
self,
target_probs: torch.Tensor, # [k, vocab_size]
draft_probs: torch.Tensor, # [k, vocab_size]
) -> torch.Tensor:
r"""Create a probability distribution for each proposed token which can
be sampled if the proposed token is rejected.
When this routine is applied sequentially, the true distribution of the
target model is recovered (within hardware numerics).
The probability distribution used in this rejection case is constructed
as follows. Given :math:`q(x|x_1, \dots, x_n)`, the probability of
:math:`x` given context :math:`x_1, \dots, x_n` according to the target
model and :math:`p(x|x_1, \dots, x_n)`, the same conditional probability
according to the draft model:
.. math::
x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+
where :math:`(f(x))_+` is defined as:
.. math::
(f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}
See https://github.com/vllm-project/vllm/pull/2336 for a visualization
of the draft, target, and recovered probability distributions.
Returns a tensor of shape [batch_size, k, vocab_size].
Note: This batches operations on GPU and thus constructs the recovered
distribution for all tokens, even if they are accepted. This causes
division-by-zero errors, so we use self._smallest_positive_value to
avoid that. This introduces some drift to the distribution.
"""
_, k, _ = draft_probs.shape
# shape [batch_size, k, vocab_size]
difference = target_probs - draft_probs
# TODO(cade): Can we use logprobs instead of probs, and avoid the
# division-by-zero errors without introducing distribution drift?
# shape [batch_size, k, vocab_size]
f = torch.clamp(difference, min=self._smallest_positive_value)
# shape [batch_size, k, vocab_size]
recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1)
return recovered_probs
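# Worked example for a 3-token vocabulary: with target probs q = [0.5, 0.3, 0.2]
# and draft probs p = [0.2, 0.5, 0.3], the difference q - p = [0.3, -0.2, -0.1]
# clamps to approximately [0.3, 0, 0] and normalizes to approximately [1, 0, 0],
# so a rejected draft token is resampled from the mass the draft under-covered.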
@cached_property
def _smallest_positive_value(self) -> float:
"""Return the smallest positive value representable by the probs dtype.
This value is used when constructing a distribution from which to sample
recovered tokens in the first rejection case.
See _get_recovered_probs for more details
Note that this isn't actually the smallest positive value representable
by float32, but the smallest positive normal value.
See https://en.wikipedia.org/wiki/Subnormal_number for more information.
"""
return torch.finfo(self.probs_dtype).tiny
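# For float32 probabilities this is torch.finfo(torch.float32).tiny, i.e.
# 2**-126, approximately 1.1755e-38, the smallest positive normal float32 value.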
def _create_output(
self,
accepted: torch.Tensor, # [batch_size, k]
recovered_token_ids: torch.Tensor, # [batch_size, k]
draft_token_ids: torch.Tensor, # [batch_size, k]
bonus_token_ids: torch.Tensor, # [batch_size]
) -> torch.Tensor:
"""Format output. Returns a matrix of token ids. When
a token is rejected via rejection sampling, all subsequent
token ids are set to -1 for the sequence.
shape = [batch_size, k + num_bonus_tokens]
"""
bonus_token_ids = bonus_token_ids.squeeze()
batch_size, k = recovered_token_ids.shape
# Determine the index of the first False value for each row.
limits = (accepted == 0).max(1).indices
limits[~(accepted == 0).any(1)] = k
# Create masks using the indices.
indices = torch.arange(k, device=accepted.device).unsqueeze(0)
accepted_mask = indices < limits.unsqueeze(1)
after_false_mask = indices == limits.unsqueeze(1)
# Create an extended output tensor
output_with_bonus_tokens = -torch.ones(
(batch_size, k + self._num_bonus_tokens),
dtype=self.token_id_dtype,
device=accepted.device)
output = output_with_bonus_tokens[:, :k]
# Fill in the first k columns of the output tensor using masks and data
# tensors.
output[:, :k] = torch.where(accepted_mask, draft_token_ids,
-torch.ones_like(draft_token_ids))
# Fill the last column.
# We check output directly as accepted may have True values inconsistent
# with causal acceptance.
output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
bonus_token_ids, -1)
# Fill the recovered token ids.
output.mul_(~after_false_mask).add_(
recovered_token_ids.mul(after_false_mask))
self.num_accepted_tokens += accepted.sum()
self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
self.num_draft_tokens += batch_size * k
return output_with_bonus_tokens
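# For example, with k=3 and accepted = [True, False, True], limits = 1 and the
# output row is [draft_0, recovered_1, -1, -1]: the first rejection truncates
# the sequence, the recovered token fills that slot, and the bonus slot stays
# -1 because the last of the k draft positions was not emitted.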
def _raise_if_incorrect_shape(
self,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> None:
(target_batch_size, num_target_probs,
target_vocab_size) = target_probs.shape
bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape
draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
assert draft_batch_size == target_batch_size
assert num_draft_probs == num_target_probs
assert (draft_vocab_size == target_vocab_size
), f"{draft_vocab_size=} {target_vocab_size=}"
assert draft_token_ids_batch_size == draft_batch_size
assert num_draft_token_ids == num_draft_probs
assert bonus_batch_size == target_batch_size
assert num_bonus_tokens == self._num_bonus_tokens
def _raise_if_incorrect_dtype(
self,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> None:
assert all(probs.dtype == self.probs_dtype
for probs in [target_probs, draft_probs])
assert all(token_ids.dtype == self.token_id_dtype
for token_ids in [bonus_token_ids, draft_token_ids])
def _raise_if_inconsistent_device(
self,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> None:
devices = [
t.device for t in
[target_probs, bonus_token_ids, draft_probs, draft_token_ids]
]
assert all([devices[0] == device for device in devices])
def _raise_if_out_of_bounds_vocab(
self,
vocab_size: int,
bonus_token_ids: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> None:
assert torch.all(bonus_token_ids < vocab_size)
assert torch.all(bonus_token_ids >= 0)
assert torch.all(draft_token_ids < vocab_size)
assert torch.all(draft_token_ids >= 0)
# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
@torch.jit.script
def _multinomial(
probs: torch.Tensor,
num_samples: int,
) -> torch.Tensor:
if num_samples > 1:
# This is equivalent to torch.repeat_interleaved (which also
# forces a GPU<->CPU sync).
probs = probs[:, None, :].expand(probs.shape[0], num_samples,
probs.shape[1]).contiguous().view(
-1, probs.shape[1])
q = torch.empty_like(probs).exponential_(1.0)
return probs.div_(q).argmax(dim=1).view(-1, num_samples)
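# A minimal usage sketch (illustrative only; shapes follow the docstrings above,
# and a CUDA device is assumed because init_gpu_tensors places counters on GPU):
#
#     import torch
#     from vllm.model_executor.layers.rejection_sampler import RejectionSampler
#
#     batch_size, k, vocab_size = 2, 3, 8
#     sampler = RejectionSampler()
#     sampler.init_gpu_tensors(rank=0)
#     draft_probs = torch.rand(batch_size, k, vocab_size, device="cuda").softmax(-1)
#     target_probs = torch.rand(batch_size, k, vocab_size, device="cuda").softmax(-1)
#     draft_token_ids = torch.randint(0, vocab_size, (batch_size, k), device="cuda")
#     bonus_token_ids = torch.randint(0, vocab_size, (batch_size, 1), device="cuda")
#     # Output shape: [batch_size, k + 1]; rejected positions are filled with -1.
#     out = sampler(target_probs, bonus_token_ids, draft_probs, draft_token_ids)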
@@ -76,7 +76,7 @@ class Sampler(nn.Module):
        logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))
        if do_top_p_top_k:
-            logits = _apply_top_p_top_k(logits, sampling_tensors.top_ps,
+            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                        sampling_tensors.top_ks)
        if do_min_p:
@@ -185,27 +185,27 @@ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
    return logits

-def _apply_top_p_top_k(
+def _apply_top_k_top_p(
    logits: torch.Tensor,
    p: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
-    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
+    # Apply top-k.
+    top_k_mask = logits_sort.size(1) - k.to(torch.long)
+    # Get all the top_k values.
+    top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
+    top_k_mask = logits_sort < top_k_mask
+    logits_sort.masked_fill_(top_k_mask, -float("inf"))
    # Apply top-p.
    probs_sort = logits_sort.softmax(dim=-1)
-    probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort)
-    top_p_mask = probs_sum > p.unsqueeze_(dim=1)
+    probs_sum = probs_sort.cumsum(dim=-1)
+    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
+    # at least one
+    top_p_mask[:, -1] = False
+    logits_sort.masked_fill_(top_p_mask, -float("inf"))
-    # Apply top-k.
-    # Create a mask for the top-k elements.
-    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
-    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
-    top_k_mask = top_k_mask >= k.unsqueeze_(dim=1)
-    # Final mask.
-    mask = (top_p_mask | top_k_mask)
-    logits_sort.masked_fill_(mask, -float("inf"))
    # Re-sort the probabilities.
    src = torch.arange(logits_idx.shape[-1],
...
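# Standalone sketch of the rewritten masking above (illustrative, independent of
# vLLM): with logits sorted ascending, top-k keeps the k largest entries and
# top-p keeps the smallest suffix whose cumulative probability mass reaches p.
import torch

logits = torch.tensor([[2.0, 0.5, 1.0, -1.0]])
k = torch.tensor([2])
p = torch.tensor([0.9])

logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
# Top-k: everything strictly below the k-th largest sorted value is masked.
kth_value = logits_sort.gather(1, (logits_sort.size(1) - k).unsqueeze(dim=1))
logits_sort.masked_fill_(logits_sort < kth_value, -float("inf"))
# Top-p: drop the low-probability prefix whose cumulative mass is <= 1 - p,
# always keeping at least the last (largest) entry.
probs_sort = logits_sort.softmax(dim=-1)
top_p_mask = probs_sort.cumsum(dim=-1) <= 1 - p.unsqueeze(dim=1)
top_p_mask[:, -1] = False
logits_sort.masked_fill_(top_p_mask, -float("inf"))
# Scatter the filtered values back to the original token order.
filtered = torch.gather(logits_sort, 1, logits_idx.argsort(dim=-1))
print(filtered)  # only the top-k / top-p tokens keep finite logits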
@@ -33,7 +33,7 @@ _MODELS = {
    "MptForCausalLM": ("mpt", "MPTForCausalLM"),
    "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
    "OPTForCausalLM": ("opt", "OPTForCausalLM"),
-    "PhiForCausalLM": ("phi_1_5", "PhiForCausalLM"),
+    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
    "RWForCausalLM": ("falcon", "FalconForCausalLM"),
    "YiForCausalLM": ("yi", "YiForCausalLM"),
...
...@@ -62,20 +62,6 @@ from vllm.sequence import SamplerOutput ...@@ -62,20 +62,6 @@ from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor] KVCache = Tuple[torch.Tensor, torch.Tensor]
class PhiEmbedding(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.wte = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
def forward(self, input_ids: torch.LongTensor):
return self.wte(input_ids)
class PhiAttention(nn.Module): class PhiAttention(nn.Module):
def __init__(self, def __init__(self,
...@@ -93,27 +79,22 @@ class PhiAttention(nn.Module): ...@@ -93,27 +79,22 @@ class PhiAttention(nn.Module):
tensor_model_parallel_world_size) tensor_model_parallel_world_size)
# pylint: disable=C0103 # pylint: disable=C0103
self.Wqkv = QKVParallelLinear(
self.hidden_size,
self.head_size,
self.total_num_heads,
linear_method=linear_method,
)
self.qkv_proj = QKVParallelLinear( self.qkv_proj = QKVParallelLinear(
config.hidden_size, self.hidden_size,
self.head_size, self.head_size,
self.total_num_heads, self.total_num_heads,
bias=False, bias=True,
linear_method=linear_method, linear_method=linear_method,
) )
self.out_proj = RowParallelLinear( self.dense = RowParallelLinear(
self.hidden_size, self.hidden_size,
self.hidden_size, self.hidden_size,
linear_method=linear_method, linear_method=linear_method,
) )
scaling = self.head_size**-0.5 scaling = self.head_size**-0.5
rotary_dim = config.rotary_dim rotary_dim = int(config.partial_rotary_factor *
(config.hidden_size // config.num_attention_heads))
assert rotary_dim % 2 == 0 assert rotary_dim % 2 == 0
# pylint: disable=C0301 # pylint: disable=C0301
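The rotary dimension is now derived from the HF config's partial_rotary_factor rather than a dedicated rotary_dim field. A quick arithmetic check with assumed, phi-1.5-like config values (illustrative only, not read from a real checkpoint):

hidden_size = 2048
num_attention_heads = 32
partial_rotary_factor = 0.5

head_size = hidden_size // num_attention_heads        # 64
rotary_dim = int(partial_rotary_factor * head_size)   # 32
assert rotary_dim % 2 == 0  # required by the rotary embedding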
...@@ -136,12 +117,12 @@ class PhiAttention(nn.Module): ...@@ -136,12 +117,12 @@ class PhiAttention(nn.Module):
kv_cache: KVCache, kv_cache: KVCache,
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.Wqkv(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(position_ids, q, k) q, k = self.rotary_emb(position_ids, q, k)
k_cache, v_cache = kv_cache k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.out_proj(attn_output) output, _ = self.dense(attn_output)
return output return output
...@@ -166,8 +147,7 @@ class PhiMLP(nn.Module): ...@@ -166,8 +147,7 @@ class PhiMLP(nn.Module):
linear_method=linear_method, linear_method=linear_method,
) )
quant_config = getattr(linear_method, "quant_config", None) quant_config = getattr(linear_method, "quant_config", None)
self.act = get_act_fn(config.activation_function, quant_config, self.act = get_act_fn(config.hidden_act, quant_config, n_inner)
n_inner)
def forward(self, hidden_states): def forward(self, hidden_states):
hidden_states, _ = self.fc1(hidden_states) hidden_states, _ = self.fc1(hidden_states)
...@@ -182,9 +162,9 @@ class PhiLayer(nn.Module): ...@@ -182,9 +162,9 @@ class PhiLayer(nn.Module):
config: PretrainedConfig, config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None): linear_method: Optional[LinearMethodBase] = None):
super().__init__() super().__init__()
self.ln = nn.LayerNorm(config.hidden_size, self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_epsilon) eps=config.layer_norm_eps)
self.mixer = PhiAttention(config, linear_method) self.self_attn = PhiAttention(config, linear_method)
self.mlp = PhiMLP(config, linear_method) self.mlp = PhiMLP(config, linear_method)
def forward( def forward(
...@@ -195,8 +175,8 @@ class PhiLayer(nn.Module): ...@@ -195,8 +175,8 @@ class PhiLayer(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
residual = hidden_states residual = hidden_states
hidden_states = self.ln(hidden_states) hidden_states = self.input_layernorm(hidden_states)
attn_outputs = self.mixer( attn_outputs = self.self_attn(
position_ids=position_ids, position_ids=position_ids,
hidden_states=hidden_states, hidden_states=hidden_states,
kv_cache=kv_cache, kv_cache=kv_cache,
...@@ -215,11 +195,14 @@ class PhiModel(nn.Module): ...@@ -215,11 +195,14 @@ class PhiModel(nn.Module):
super().__init__() super().__init__()
self.config = config self.config = config
self.linear_method = linear_method self.linear_method = linear_method
self.embd = PhiEmbedding(config) self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
self.h = nn.ModuleList([ config.hidden_size)
self.layers = nn.ModuleList([
PhiLayer(config, linear_method) PhiLayer(config, linear_method)
for _ in range(config.num_hidden_layers) for _ in range(config.num_hidden_layers)
]) ])
self.final_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward( def forward(
self, self,
...@@ -228,27 +211,19 @@ class PhiModel(nn.Module): ...@@ -228,27 +211,19 @@ class PhiModel(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embd(input_ids) hidden_states = self.embed_tokens(input_ids)
for i in range(self.config.num_hidden_layers): for i in range(self.config.num_hidden_layers):
layer = self.h[i] layer = self.layers[i]
hidden_states = layer( hidden_states = layer(
positions, positions,
hidden_states, hidden_states,
kv_caches[i], kv_caches[i],
input_metadata, input_metadata,
) )
return hidden_states
hidden_states = self.final_layernorm(hidden_states)
class PhiCausalLMHead(nn.Module): return hidden_states
def __init__(self, config: PretrainedConfig):
super().__init__()
self.ln = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
self.linear = ParallelLMHead(config.vocab_size,
config.hidden_size,
bias=True)
class PhiForCausalLM(nn.Module): class PhiForCausalLM(nn.Module):
...@@ -260,8 +235,11 @@ class PhiForCausalLM(nn.Module): ...@@ -260,8 +235,11 @@ class PhiForCausalLM(nn.Module):
self.config = config self.config = config
self.linear_method = linear_method self.linear_method = linear_method
self.transformer = PhiModel(config, linear_method) self.model = PhiModel(config, linear_method)
self.lm_head = PhiCausalLMHead(config)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
bias=True)
self.sampler = Sampler(config.vocab_size) self.sampler = Sampler(config.vocab_size)
def forward( def forward(
...@@ -271,9 +249,9 @@ class PhiForCausalLM(nn.Module): ...@@ -271,9 +249,9 @@ class PhiForCausalLM(nn.Module):
kv_caches: List[KVCache], kv_caches: List[KVCache],
input_metadata: InputMetadata, input_metadata: InputMetadata,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches, hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata) input_metadata)
hidden_states = self.lm_head.ln(hidden_states)
return hidden_states return hidden_states
def sample( def sample(
...@@ -281,7 +259,7 @@ class PhiForCausalLM(nn.Module): ...@@ -281,7 +259,7 @@ class PhiForCausalLM(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata, sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
head = self.lm_head.linear head = self.lm_head
next_tokens = self.sampler(head.weight, hidden_states, next_tokens = self.sampler(head.weight, hidden_states,
sampling_metadata, head.bias) sampling_metadata, head.bias)
return next_tokens return next_tokens
...@@ -291,16 +269,36 @@ class PhiForCausalLM(nn.Module): ...@@ -291,16 +269,36 @@ class PhiForCausalLM(nn.Module):
cache_dir: Optional[str] = None, cache_dir: Optional[str] = None,
load_format: str = "auto", load_format: str = "auto",
revision: Optional[str] = None): revision: Optional[str] = None):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v")
]
params_dict = dict(self.named_parameters()) params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator( for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision): model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name: if "rotary_emb.inv_freq" in name:
continue continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
# pylint: disable=E1136 # pylint: disable=E1136
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
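The new loader folds the checkpoint's separate q_proj/k_proj/v_proj tensors into the fused qkv_proj parameter via stacked_params_mapping. A simplified standalone sketch of the name remapping; the example checkpoint key is an assumption about the HF Phi layout:

stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

def remap_checkpoint_name(name: str):
    """Map a per-projection checkpoint key onto the fused parameter name.

    Returns (fused_name, shard_id), or (name, None) if no remapping applies.
    """
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None

# Example (assumed HF key layout): the key projection of layer 0 lands in
# the "k" shard of the fused parameter.
assert remap_checkpoint_name("model.layers.0.self_attn.k_proj.weight") == (
    "model.layers.0.self_attn.qkv_proj.weight", "k")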
...@@ -58,7 +58,9 @@ def in_wsl() -> bool: ...@@ -58,7 +58,9 @@ def in_wsl() -> bool:
def get_ip() -> str: def get_ip() -> str:
return socket.gethostbyname(socket.gethostname()) s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable
return s.getsockname()[0]
def get_open_port() -> int: def get_open_port() -> int:
......
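get_ip now resolves the machine's outbound IP by connecting a UDP socket toward a public address; no packet is actually sent, since UDP connect only sets the default destination and lets the OS pick the routing interface. A hedged sketch of the same trick with explicit cleanup (the localhost fallback is an addition for illustration, not part of the diff):

import socket

def get_outbound_ip() -> str:
    """Return the IP of the interface used for outbound traffic."""
    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        s.connect(("8.8.8.8", 80))  # Target does not need to be reachable.
        return s.getsockname()[0]
    except OSError:
        return "127.0.0.1"  # Illustrative fallback when no route exists.
    finally:
        s.close()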
...@@ -235,9 +235,11 @@ class ModelRunner: ...@@ -235,9 +235,11 @@ class ModelRunner:
input_block_tables[i, :len(block_table)] = block_table input_block_tables[i, :len(block_table)] = block_table
block_tables = torch.tensor(input_block_tables, device="cuda") block_tables = torch.tensor(input_block_tables, device="cuda")
else: else:
max_block_table_len = (max_context_len + self.block_size -
1) // self.block_size
block_tables = _make_tensor_with_pad( block_tables = _make_tensor_with_pad(
block_tables, block_tables,
max_len=max_context_len, max_len=max_block_table_len,
pad=0, pad=0,
dtype=torch.int, dtype=torch.int,
device="cuda", device="cuda",
...@@ -504,7 +506,9 @@ class ModelRunner: ...@@ -504,7 +506,9 @@ class ModelRunner:
"use '--enforce-eager' in the CLI.") "use '--enforce-eager' in the CLI.")
logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. " logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. "
"If you are running out of memory, consider decreasing " "If you are running out of memory, consider decreasing "
"`gpu_memory_utilization` or enforcing eager mode.") "`gpu_memory_utilization` or enforcing eager mode. "
"You can also reduce the `max_num_seqs` as needed "
"to decrease memory usage.")
start_time = time.perf_counter() start_time = time.perf_counter()
# Prepare dummy inputs. These will be reused for all batch sizes. # Prepare dummy inputs. These will be reused for all batch sizes.
...@@ -517,9 +521,15 @@ class ModelRunner: ...@@ -517,9 +521,15 @@ class ModelRunner:
context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda()
block_tables = torch.from_numpy(self.graph_block_tables).cuda() block_tables = torch.from_numpy(self.graph_block_tables).cuda()
graph_batch_size = _get_graph_batch_size(
self.scheduler_config.max_num_seqs)
batch_size_capture_list = [
bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
]
# NOTE: Capturing the largest batch size first may help reduce the # NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph. # memory usage of CUDA graph.
for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE): for batch_size in reversed(batch_size_capture_list):
# Create dummy input_metadata. # Create dummy input_metadata.
input_metadata = InputMetadata( input_metadata = InputMetadata(
is_prompt=False, is_prompt=False,
......
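CUDA graphs are now captured only for batch sizes up to the padded max_num_seqs, so graphs that could never be used are skipped. A small sketch of the filtering; both the capture list and the padding rule below are assumptions for illustration:

_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]

def _get_graph_batch_size(batch_size: int) -> int:
    # Pad up to the nearest size present in the capture list.
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return (batch_size + 7) // 8 * 8

max_num_seqs = 20
graph_batch_size = _get_graph_batch_size(max_num_seqs)   # 24
batch_size_capture_list = [
    bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
]
assert batch_size_capture_list == [1, 2, 4, 8, 16, 24]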
...@@ -87,6 +87,14 @@ class Worker: ...@@ -87,6 +87,14 @@ class Worker:
gpu_memory_utilization: float, gpu_memory_utilization: float,
cpu_swap_space: int, cpu_swap_space: int,
) -> Tuple[int, int]: ) -> Tuple[int, int]:
"""Profiles the peak memory usage of the model and returns the maximum
number of GPU and CPU cache blocks that can be allocated.
Args:
block_size: The size of the cache block.
gpu_memory_utilization: The fraction of the total GPU memory to use.
cpu_swap_space: The size of the CPU swap space in bytes.
"""
# Profile the memory usage of the model and get the maximum number of # Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory. # cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache() torch.cuda.empty_cache()
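The new docstring spells out what the profiling method returns. The block counts themselves boil down to simple arithmetic on the profiled peak memory; the helper below is a hedged sketch of that calculation (the name and exact formula are assumptions based on the surrounding code, not a copy of the worker implementation):

from typing import Tuple

def estimate_num_cache_blocks(total_gpu_memory: int,
                              peak_memory: int,
                              gpu_memory_utilization: float,
                              cpu_swap_space: int,
                              cache_block_size: int) -> Tuple[int, int]:
    """Rough block-count arithmetic; cache_block_size is the number of
    bytes one block occupies across all layers (keys and values)."""
    free_for_cache = total_gpu_memory * gpu_memory_utilization - peak_memory
    num_gpu_blocks = max(int(free_for_cache // cache_block_size), 0)
    num_cpu_blocks = max(int(cpu_swap_space // cache_block_size), 0)
    return num_gpu_blocks, num_cpu_blocks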
...@@ -231,4 +239,6 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): ...@@ -231,4 +239,6 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
raise ValueError( raise ValueError(
"Bfloat16 is only supported on GPUs with compute capability " "Bfloat16 is only supported on GPUs with compute capability "
f"of at least 8.0. Your {gpu_name} GPU has compute capability " f"of at least 8.0. Your {gpu_name} GPU has compute capability "
f"{compute_capability[0]}.{compute_capability[1]}.") f"{compute_capability[0]}.{compute_capability[1]}. "
"You can use float16 instead by explicitly setting the"
"`dtype` flag in CLI, for example: --dtype=half.")