Commit 51679bbd authored by zhuwenwen's avatar zhuwenwen
Browse files

resolve merge confilcts

parents 4095d0db 1af090b5
from typing import Tuple, Optional
from functools import cached_property
import torch
import torch.nn as nn
import torch.jit
class RejectionSampler(nn.Module):
"""Apply modified rejection sampling as described in "Accelerating Large
Language Model Decoding with Speculative Sampling"
https://arxiv.org/pdf/2302.01318.pdf.
"""
def __init__(self, strict_mode: bool = False):
"""Create a rejection sampler.
Args:
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
"""
super().__init__()
self.probs_dtype = torch.float32
self.token_id_dtype = torch.int64
self._strict_mode = strict_mode
# NOTE: A "bonus token" is accepted iff all proposal tokens are
# accepted. There is always only one possible bonus token. We store this
# value in a variable for readability.
self._num_bonus_tokens = 1
self.num_accepted_tokens: Optional[torch.Tensor] = None
self.num_emitted_tokens: Optional[torch.Tensor] = None
self.num_draft_tokens: int = 0
def init_gpu_tensors(self, rank: int) -> None:
assert self.num_accepted_tokens is None
device = f"cuda:{rank}"
self.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)
self.num_emitted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)
def forward(
self,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> torch.Tensor:
"""Sample token ids using rejection sampling. This accepts or rejects
tokens proposed by the draft model using the probability of each token
according to the draft and target models.
In the worst case where all draft tokens are rejected, it is guaranteed
one correct token will be emitted.
In the case where all draft tokens are accepted, a bonus token will be
accepted as its cheap to have the target model score this speculative
sequence.
Args:
target_probs: The probability distribution over token ids given
context according to the target model.
shape = [batch_size, num_speculative_tokens, vocab_size]
bonus_token_ids: The "bonus" token ids that are accepted iff all
speculative tokens in a sequence are accepted.
shape = [batch_size, num_bonus_tokens]
draft_probs: The probability distribution over token ids given
context according to the draft model.
shape = [batch_size, num_speculative_tokens, vocab_size]
draft_token_ids: The token ids that were sampled from the draft
probabilities.
shape = [batch_size, num_speculative_tokens]
Returns:
output_token_ids: The token ids sampled via rejection sampling,
or -1 if unable to sample a token because the previous token
was rejected.
shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
"""
# Only perform shape/dtype/device checking in strict mode, as it adds
# overhead.
if self._strict_mode:
self._raise_if_incorrect_shape(target_probs, bonus_token_ids,
draft_probs, draft_token_ids)
self._raise_if_incorrect_dtype(target_probs, bonus_token_ids,
draft_probs, draft_token_ids)
self._raise_if_inconsistent_device(target_probs, bonus_token_ids,
draft_probs, draft_token_ids)
self._raise_if_out_of_bounds_vocab(target_probs.shape[-1],
bonus_token_ids,
draft_token_ids)
accepted, recovered_token_ids = self._batch_modified_rejection_sampling(
target_probs,
draft_probs,
draft_token_ids,
)
output_token_ids = self._create_output(
accepted,
recovered_token_ids,
draft_token_ids,
bonus_token_ids,
)
return output_token_ids
def _batch_modified_rejection_sampling(
self,
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Perform modified rejection sampling on each sequence.
Returns:
A tuple of two tensors:
0: A bool tensor of which tokens in each sequence is accepted.
shape = [batch_size, k]
1: Token ids sampled from a recovered distribution, to be used
when a token is rejected.
shape = [batch_size, k]
"""
batch_size, k, vocab_size = draft_probs.shape
# shape [batch_size, k]
accepted = self._get_accepted(target_probs, draft_probs,
draft_token_ids)
recovered_probs = self._get_recovered_probs(
target_probs, draft_probs).reshape(batch_size * k, vocab_size)
recovered_token_ids = _multinomial(recovered_probs,
num_samples=1).reshape(
batch_size, k)
return accepted, recovered_token_ids
def _get_accepted(
self,
target_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_probs: torch.Tensor, # [batch_size, k, vocab_size]
draft_token_ids: torch.Tensor, # [batch_size, k]
) -> torch.Tensor:
r"""Create bool matrix over the proposed draft tokens. If
True, then a token can be accepted, else it should be
rejected.
Given :math:`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of
:math:`\hat{x}_{n+1}` given context :math:`x_1, \dots, x_n` according
to the target model, and :math:`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the
same conditional probability according to the draft model, the token
is accepted with probability:
.. math::
\min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
{p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)
This implementation does not apply causality. When using the output,
if a token is rejected, subsequent tokens should not be used.
Returns a bool tensor of shape [batch_size, k] specifying which tokens
are accepted.
"""
batch_size, k, _ = draft_probs.shape
batch_indices = torch.arange(batch_size,
device=target_probs.device)[:, None]
probs_indicies = torch.arange(k, device=target_probs.device)
# shape [batch_size, k]
selected_draft_probs = draft_probs[batch_indices, probs_indicies,
draft_token_ids]
# shape [batch_size, k]
selected_target_probs = target_probs[batch_indices, probs_indicies,
draft_token_ids]
uniform_rand = torch.rand(batch_size,
k,
dtype=self.probs_dtype,
device=target_probs.device)
capped_ratio = torch.minimum(
selected_target_probs / selected_draft_probs,
torch.full((1, ), 1, device=target_probs.device))
accepted = uniform_rand < capped_ratio
return accepted
def _get_recovered_probs(
self,
target_probs: torch.Tensor, # [k, vocab_size]
draft_probs: torch.Tensor, # [k, vocab_size]
) -> torch.Tensor:
r"""Create a probability distribution for each proposed token which can
be sampled if the proposed token is rejected.
When this routine is applied sequentially, the true distribution of the
target model is recovered (within hardware numerics).
The probability distribution used in this rejection case is constructed
as follows. Given :math:`q(x|x_1, \dots, x_n)`, the probability of
:math:`x` given context :math:`x_1, \dots, x_n` according to the target
model and :math:`p(x|x_1, \dots, x_n)`, the same conditional probability
according to the draft model:
.. math::
x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+
where :math:`(f(x))_+` is defined as:
.. math::
(f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}
See https://github.com/vllm-project/vllm/pull/2336 for a visualization
of the draft, target, and recovered probability distributions.
Returns a tensor of shape [batch_size, k, vocab_size].
Note: This batches operations on GPU and thus constructs the recovered
distribution for all tokens, even if they are accepted. This causes
division-by-zero errors, so we use self._smallest_positive_value to
avoid that. This introduces some drift to the distribution.
"""
_, k, _ = draft_probs.shape
# shape [batch_size, k, vocab_size]
difference = target_probs - draft_probs
# TODO(cade): Can we use logprobs instead of probs, and avoid the
# division-by-zero errors without introducing distribution drift?
# shape [batch_size, k, vocab_size]
f = torch.clamp(difference, min=self._smallest_positive_value)
# shape [batch_size, k, vocab_size]
recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1)
return recovered_probs
@cached_property
def _smallest_positive_value(self) -> float:
"""Return the smallest positive value representable by the probs dtype.
This value is used when constructing a distribution from which to sample
recovered tokens in the first rejection case.
See _get_recovered_probs for more details
Note that this isn't actually the smallest positive value representable
by float32, but the smallest positive normal value.
See https://en.wikipedia.org/wiki/Subnormal_number for more information.
"""
return torch.finfo(self.probs_dtype).tiny
def _create_output(
self,
accepted: torch.Tensor, # [batch_size, k]
recovered_token_ids: torch.Tensor, # [batch_size, k]
draft_token_ids: torch.Tensor, # [batch_size, k]
bonus_token_ids: torch.Tensor, # [batch_size]
) -> torch.Tensor:
"""Format output. Returns a matrix of token ids. When
a token is rejected via rejection sampling, all subsequent
token ids are set to -1 for the sequence.
shape = [batch_size, k + num_bonus_tokens]
"""
bonus_token_ids = bonus_token_ids.squeeze()
batch_size, k = recovered_token_ids.shape
# Determine the index of the first False value for each row.
limits = (accepted == 0).max(1).indices
limits[~(accepted == 0).any(1)] = k
# Create masks using the indices.
indices = torch.arange(k, device=accepted.device).unsqueeze(0)
accepted_mask = indices < limits.unsqueeze(1)
after_false_mask = indices == limits.unsqueeze(1)
# Create an extended output tensor
output_with_bonus_tokens = -torch.ones(
(batch_size, k + self._num_bonus_tokens),
dtype=self.token_id_dtype,
device=accepted.device)
output = output_with_bonus_tokens[:, :k]
# Fill in the first k columns of the output tensor using masks and data
# tensors.
output[:, :k] = torch.where(accepted_mask, draft_token_ids,
-torch.ones_like(draft_token_ids))
# Fill the last column.
# We check output directly as accepted may have True values inconsistent
# with causal acceptance.
output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
bonus_token_ids, -1)
# Fill the recovered token ids.
output.mul_(~after_false_mask).add_(
recovered_token_ids.mul(after_false_mask))
self.num_accepted_tokens += accepted.sum()
self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
self.num_draft_tokens += batch_size * k
return output_with_bonus_tokens
def _raise_if_incorrect_shape(
self,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> None:
(target_batch_size, num_target_probs,
target_vocab_size) = target_probs.shape
bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape
draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape
assert draft_batch_size == target_batch_size
assert num_draft_probs == num_target_probs
assert (draft_vocab_size == target_vocab_size
), f"{draft_vocab_size=} {target_vocab_size=}"
assert draft_token_ids_batch_size == draft_batch_size
assert num_draft_token_ids == num_draft_probs
assert bonus_batch_size == target_batch_size
assert num_bonus_tokens == self._num_bonus_tokens
def _raise_if_incorrect_dtype(
self,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> None:
assert all(probs.dtype == self.probs_dtype
for probs in [target_probs, draft_probs])
assert all(token_ids.dtype == self.token_id_dtype
for token_ids in [bonus_token_ids, draft_token_ids])
def _raise_if_inconsistent_device(
self,
target_probs: torch.Tensor,
bonus_token_ids: torch.Tensor,
draft_probs: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> None:
devices = [
t.device for t in
[target_probs, bonus_token_ids, draft_probs, draft_token_ids]
]
assert all([devices[0] == device for device in devices])
def _raise_if_out_of_bounds_vocab(
self,
vocab_size: int,
bonus_token_ids: torch.Tensor,
draft_token_ids: torch.Tensor,
) -> None:
assert torch.all(bonus_token_ids < vocab_size)
assert torch.all(bonus_token_ids >= 0)
assert torch.all(draft_token_ids < vocab_size)
assert torch.all(draft_token_ids >= 0)
# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
@torch.jit.script
def _multinomial(
probs: torch.Tensor,
num_samples: int,
) -> torch.Tensor:
if num_samples > 1:
# This is equivalent to torch.repeat_interleaved (which also
# forces a GPU<->CPU sync).
probs = probs[:, None, :].expand(probs.shape[0], num_samples,
probs.shape[1]).contiguous().view(
-1, probs.shape[1])
q = torch.empty_like(probs).exponential_(1.0)
return probs.div_(q).argmax(dim=1).view(-1, num_samples)
......@@ -27,9 +27,25 @@ class Sampler(nn.Module):
parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
"""
def __init__(self, vocab_size: int) -> None:
def __init__(self,
vocab_size: int,
org_vocab_size: Optional[int] = None) -> None:
super().__init__()
self.vocab_size = vocab_size
# original vocabulary size (without LoRA).
self.org_vocab_size = org_vocab_size or vocab_size
def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = tensor_model_parallel_gather(logits)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[:, :self.org_vocab_size]
return logits
def forward(
self,
......@@ -42,8 +58,7 @@ class Sampler(nn.Module):
hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
# Get the logits for the next tokens.
logits = _get_logits(hidden_states, embedding, embedding_bias,
self.vocab_size)
logits = self._get_logits(hidden_states, embedding, embedding_bias)
# Only perform sampling in the driver worker.
# Note: `_get_logits` is still distributed across TP workers because
......@@ -76,7 +91,7 @@ class Sampler(nn.Module):
logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))
if do_top_p_top_k:
logits = _apply_top_p_top_k(logits, sampling_tensors.top_ps,
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
sampling_tensors.top_ks)
if do_min_p:
......@@ -98,20 +113,6 @@ class Sampler(nn.Module):
prompt_logprobs, sample_logprobs)
def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
embedding_bias: Optional[torch.Tensor],
vocab_size: int) -> Optional[torch.Tensor]:
# Get the logits for the next tokens.
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias
logits = tensor_model_parallel_gather(logits)
# Remove paddings in vocab (if any).
if logits is not None:
logits = logits[:, :vocab_size]
return logits
def _prune_hidden_states(
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
......@@ -185,27 +186,27 @@ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
return logits
def _apply_top_p_top_k(
def _apply_top_k_top_p(
logits: torch.Tensor,
p: torch.Tensor,
k: torch.Tensor,
) -> torch.Tensor:
logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
# Apply top-k.
top_k_mask = logits_sort.size(1) - k.to(torch.long)
# Get all the top_k values.
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
top_k_mask = logits_sort < top_k_mask
logits_sort.masked_fill_(top_k_mask, -float("inf"))
# Apply top-p.
probs_sort = logits_sort.softmax(dim=-1)
probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort)
top_p_mask = probs_sum > p.unsqueeze_(dim=1)
# Apply top-k.
# Create a mask for the top-k elements.
top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
top_k_mask = top_k_mask >= k.unsqueeze_(dim=1)
# Final mask.
mask = (top_p_mask | top_k_mask)
logits_sort.masked_fill_(mask, -float("inf"))
probs_sum = probs_sort.cumsum(dim=-1)
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
# at least one
top_p_mask[:, -1] = False
logits_sort.masked_fill_(top_p_mask, -float("inf"))
# Re-sort the probabilities.
src = torch.arange(logits_idx.shape[-1],
......
# The kernels in this file are adapted from LightLLM's context_attention_fwd:
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py
import torch
import triton
import triton.language as tl
if triton.__version__ >= "2.1.0":
@triton.jit
def _fwd_kernel(
Q,
K,
V,
K_cache,
V_cache,
B_Loc,
sm_scale,
B_Start_Loc,
B_Seqlen,
B_Ctxlen,
block_size,
x,
Out,
stride_b_loc_b,
stride_b_loc_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_k_cache_bs,
stride_k_cache_h,
stride_k_cache_d,
stride_k_cache_bl,
stride_k_cache_x,
stride_v_cache_bs,
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
q = tl.load(
Q + off_q,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
# # initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
off_k = (bn[None, :] * stride_k_cache_bs +
cur_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) *
stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
off_v = (
bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(V_cache + off_v,
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# # update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
block_mask = tl.where(
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) <
cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
p = tl.exp(qk - m_ij[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
m_i_new = tl.maximum(m_i, m_ij)
alpha = tl.exp(m_i - m_i_new)
beta = tl.exp(m_ij - m_i_new)
l_i_new = alpha * l_i + beta * l_ij
# -- update output accumulator --
# scale p
p_scale = beta / l_i_new
p = p * p_scale[:, None]
# scale acc
acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) <
cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
return
@triton.jit
def _fwd_kernel_flash_attn_v2(
Q,
K,
V,
K_cache,
V_cache,
B_Loc,
sm_scale,
B_Start_Loc,
B_Seqlen,
B_Ctxlen,
block_size,
x,
Out,
stride_b_loc_b,
stride_b_loc_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_k_cache_bs,
stride_k_cache_h,
stride_k_cache_d,
stride_k_cache_bl,
stride_k_cache_x,
stride_v_cache_bs,
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
q = tl.load(
Q + off_q,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
# # initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
off_k = (bn[None, :] * stride_k_cache_bs +
cur_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) *
stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
off_v = (
bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(V_cache + off_v,
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
block_mask = tl.where(
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) <
cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) <
cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
# acc /= l_i[:, None]
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
return
@triton.jit
def _fwd_kernel_alibi(
Q,
K,
V,
K_cache,
V_cache,
B_Loc,
sm_scale,
B_Start_Loc,
B_Seqlen,
B_Ctxlen,
Alibi_slopes,
block_size,
x,
Out,
stride_b_loc_b,
stride_b_loc_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_obs,
stride_oh,
stride_od,
stride_k_cache_bs,
stride_k_cache_h,
stride_k_cache_d,
stride_k_cache_bl,
stride_k_cache_x,
stride_v_cache_bs,
stride_v_cache_h,
stride_v_cache_d,
stride_v_cache_bl,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
# attn_bias[]
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
start_m = tl.program_id(2)
# cur_batch_seq_len: the length of prompts
# cur_batch_ctx_len: the length of prefix
# cur_batch_in_all_start_index: the start id of the dim=0
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
block_start_loc = BLOCK_M * start_m
# initialize offsets
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
off_q = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
cur_head * stride_qh + offs_d[None, :] * stride_qd)
q = tl.load(
Q + off_q,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
# # initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
alibi_slope = tl.load(Alibi_slopes + cur_head)
alibi_start_q = tl.arange(
0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
alibi_start_k = 0
for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
((start_n + offs_n) // block_size) * stride_b_loc_s,
mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0)
off_k = (bn[None, :] * stride_k_cache_bs +
cur_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) *
stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x)
off_v = (
bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k,
mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
float("-inf"))
qk *= sm_scale
# load alibi
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
alibi_start_q[:, None]) * alibi_slope
alibi = tl.where(
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
alibi, float("-inf"))
qk += alibi
alibi_start_k += BLOCK_N
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(V_cache + off_v,
mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v, allow_tf32=False)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
offs_d[None, :] * stride_vd)
k_ptrs = K + off_k
v_ptrs = V + off_v
block_mask = tl.where(
block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)
# init alibi
alibi_slope = tl.load(Alibi_slopes + cur_head)
alibi_start_q = tl.arange(
0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
alibi_start_k = cur_batch_ctx_len
# # init debuger
# offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
# offset_db_k = tl.arange(0, BLOCK_N)
# calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL]
for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = tl.load(k_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_kbs,
mask=(start_n + offs_n[None, :]) <
cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k, allow_tf32=False)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
float("-inf"))
# load alibi
alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
alibi_start_q[:, None]) * alibi_slope
alibi = tl.where(
(alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
alibi, float("-inf"))
qk += alibi
alibi_start_k += BLOCK_N
# -- compute m_ij, p, l_ij
m_ij = tl.max(qk, 1)
m_i_new = tl.maximum(m_i, m_ij)
p = tl.math.exp(qk - m_i_new[:, None])
l_ij = tl.sum(p, 1)
# -- update m_i and l_i
alpha = tl.math.exp(m_i - m_i_new)
l_i_new = alpha * l_i + l_ij
# -- update output accumulator --
# scale p
# scale acc
acc_scale = alpha
# acc_scale = l_i / l_i_new * alpha
acc = acc * acc_scale[:, None]
# update acc
v = tl.load(v_ptrs +
(cur_batch_in_all_start_index + start_n) * stride_vbs,
mask=(start_n + offs_n[:, None]) <
cur_batch_seq_len - cur_batch_ctx_len,
other=0.0)
p = p.to(v.dtype)
acc += tl.dot(p, v, allow_tf32=False)
# update m_i and l_i
l_i = l_i_new
m_i = m_i_new
acc = acc / l_i[:, None]
# initialize pointers to output
off_o = (
(cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o
tl.store(out_ptrs,
acc,
mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
return
@torch.inference_mode()
def context_attention_fwd(q,
k,
v,
o,
k_cache,
v_cache,
b_loc,
b_start_loc,
b_seq_len,
b_ctx_len,
max_input_len,
alibi_slopes=None):
BLOCK = 128
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 128}
sm_scale = 1.0 / (Lq**0.5)
batch, head = b_seq_len.shape[0], q.shape[1]
grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head,
num_warps = 8 if Lk <= 64 else 8
if alibi_slopes is not None:
_fwd_kernel_alibi[grid](
q,
k,
v,
k_cache,
v_cache,
b_loc,
sm_scale,
b_start_loc,
b_seq_len,
b_ctx_len,
alibi_slopes,
v_cache.shape[3],
8,
o,
b_loc.stride(0),
b_loc.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
k_cache.stride(
4
), #[num_blocks, num_kv_heads, head_size/x, block_size, x]
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(
3), #[num_blocks, num_kv_heads, head_size, block_size]
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return
_fwd_kernel[grid](
q,
k,
v,
k_cache,
v_cache,
b_loc,
sm_scale,
b_start_loc,
b_seq_len,
b_ctx_len,
v_cache.shape[3],
8,
o,
b_loc.stride(0),
b_loc.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
o.stride(0),
o.stride(1),
o.stride(2),
k_cache.stride(0),
k_cache.stride(1),
k_cache.stride(2),
k_cache.stride(3),
k_cache.stride(
4), #[num_blocks, num_kv_heads, head_size/x, block_size, x]
v_cache.stride(0),
v_cache.stride(1),
v_cache.stride(2),
v_cache.stride(
3), #[num_blocks, num_kv_heads, head_size, block_size]
BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return
......@@ -13,8 +13,11 @@ from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce)
from vllm.model_executor.utils import set_weight_attrs
DEFAULT_VOCAB_PADDING_SIZE = 64
def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
def pad_vocab_size(vocab_size: int,
pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
"""Pad the vocab size to the given value."""
return ((vocab_size + pad_to - 1) // pad_to) * pad_to
......@@ -43,17 +46,23 @@ class VocabParallelEmbedding(torch.nn.Module):
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
params_dtype: type of the parameters.
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
params_dtype: Optional[torch.dtype] = None):
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
super().__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.num_embeddings_padded = pad_vocab_size(num_embeddings)
self.org_vocab_size = org_num_embeddings or num_embeddings
self.num_embeddings_padded = pad_vocab_size(num_embeddings,
padding_size)
self.embedding_dim = embedding_dim
if params_dtype is None:
params_dtype = torch.get_default_dtype()
......@@ -77,7 +86,7 @@ class VocabParallelEmbedding(torch.nn.Module):
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
parallel_dim = param.parallel_dim
assert loaded_weight.shape[parallel_dim] == self.num_embeddings
assert loaded_weight.shape[parallel_dim] == self.org_vocab_size
loaded_weight = loaded_weight[self.vocab_start_index:self.
vocab_end_index]
param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
......@@ -114,14 +123,19 @@ class ParallelLMHead(VocabParallelEmbedding):
embedding_dim: size of hidden state.
bias: whether to use bias.
params_dtype: type of the parameters.
org_num_embeddings: original vocabulary size (without LoRA).
padding_size: padding size for the vocabulary.
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
bias: bool = False,
params_dtype: Optional[torch.dtype] = None):
super().__init__(num_embeddings, embedding_dim, params_dtype)
params_dtype: Optional[torch.dtype] = None,
org_num_embeddings: Optional[int] = None,
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
super().__init__(num_embeddings, embedding_dim, params_dtype,
org_num_embeddings, padding_size)
if bias:
self.bias = Parameter(
torch.empty(self.num_embeddings_per_partition,
......
"""Utilities for selecting and loading models."""
import contextlib
from typing import Type
from typing import Optional, Type
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import ModelConfig
from vllm.config import ModelConfig, LoRAConfig
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.weight_utils import (get_quant_config,
initialize_dummy_weights)
......@@ -21,8 +20,14 @@ def _set_default_torch_dtype(dtype: torch.dtype):
torch.set_default_dtype(old_dtype)
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
architectures = getattr(config, "architectures", [])
def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
architectures = getattr(model_config.hf_config, "architectures", [])
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
if (model_config.quantization is not None
and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"]
for arch in architectures:
model_cls = ModelRegistry.load_model_cls(arch)
if model_cls is not None:
......@@ -32,8 +37,9 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def get_model(model_config: ModelConfig) -> nn.Module:
model_class = _get_model_architecture(model_config.hf_config)
def get_model(model_config: ModelConfig,
lora_config: Optional[LoRAConfig] = None) -> nn.Module:
model_class = _get_model_architecture(model_config)
# Get the (maybe quantized) linear method.
linear_method = None
......@@ -62,7 +68,17 @@ def get_model(model_config: ModelConfig) -> nn.Module:
# Create a model instance.
# The weights will be initialized as empty tensors.
with torch.device("cuda"):
model = model_class(model_config.hf_config, linear_method)
if getattr(model_class, "supports_lora", False):
model = model_class(model_config.hf_config, linear_method,
lora_config)
elif lora_config:
raise ValueError(
f"Model {model_class.__name__} does not support LoRA, "
"but LoRA is enabled. Support for this model may "
"be added in the future. If this is important to you, "
"please open an issue on github.")
else:
model = model_class(model_config.hf_config, linear_method)
if model_config.load_format == "dummy":
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
......
......@@ -18,6 +18,7 @@ _MODELS = {
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
"ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
......@@ -29,14 +30,17 @@ _MODELS = {
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MistralForCausalLM": ("mistral", "MistralForCausalLM"),
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
# transformers's mpt class has lower case
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"PhiForCausalLM": ("phi_1_5", "PhiForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
"YiForCausalLM": ("yi", "YiForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"YiForCausalLM": ("yi", "YiForCausalLM")
}
# Models not supported by ROCm.
......@@ -45,6 +49,8 @@ _ROCM_UNSUPPORTED_MODELS = []
# Models partially supported by ROCm.
# Architecture -> Reason.
_ROCM_PARTIALLY_SUPPORTED_MODELS = {
"Qwen2ForCausalLM":
"Sliding window attention is not yet supported in ROCm's flash attention",
"MistralForCausalLM":
"Sliding window attention is not yet supported in ROCm's flash attention",
"MixtralForCausalLM":
......
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""
from typing import Any, Dict, List, Optional, Tuple
import torch
from torch import nn
import torch.nn.functional as F
from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
ReplicatedLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class DeepseekMLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
reduce_results: bool = True,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
linear_method=linear_method)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
linear_method=linear_method,
reduce_results=reduce_results)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class DeepseekMoE(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.n_routed_experts = config.n_routed_experts
self.top_k = config.num_experts_per_tok
if self.tp_size > self.n_routed_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {self.n_routed_experts}.")
self.experts = nn.ModuleList([
DeepseekMLP(hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
reduce_results=False)
for idx in range(self.n_routed_experts)
])
self.pack_params()
self.gate = ReplicatedLinear(config.hidden_size,
self.n_routed_experts,
bias=False,
linear_method=None)
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekMLP(
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
reduce_results=False,
)
def pack_params(self):
w1 = []
w2 = []
for expert in self.experts:
w1.append(expert.gate_up_proj.weight)
w2.append(expert.down_proj.weight)
self.w1 = torch._utils._flatten_dense_tensors(w1)
w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
for data, param in zip(w1s, w1):
param.data = data
self.w1 = self.w1.view(len(w1), *w1s[0].shape)
self.w2 = torch._utils._flatten_dense_tensors(w2)
w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
for data, param in zip(w2s, w2):
param.data = data
self.w2 = self.w2.view(len(w2), *w2s[0].shape)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
if self.config.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
# router_logits: (batch * sequence_length, n_experts)
router_logits, _ = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights,
self.top_k,
dim=-1)
if self.config.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
final_hidden_states = fused_moe(hidden_states,
self.w1,
self.w2,
routing_weights,
selected_experts,
inplace=True)
if self.config.n_shared_experts is not None:
final_hidden_states = final_hidden_states + shared_output
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return final_hidden_states.view(batch_size, sequence_length,
hidden_dim)
class DeepseekAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
linear_method=linear_method,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output)
return output
class DeepseekDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
layer_idx: int,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = DeepseekAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
linear_method=linear_method,
)
if (config.n_routed_experts is not None and \
layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0):
self.mlp = DeepseekMoE(config=config, linear_method=linear_method)
else:
self.mlp = DeepseekMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class DeepseekModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
DeepseekDecoderLayer(config,
layer_idx,
linear_method=linear_method)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i], input_metadata,
residual)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class DeepseekForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = DeepseekModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata)
return hidden_states
def sample(
self,
hidden_states: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path,
cache_dir,
load_format,
revision,
fall_back_to_pt=False):
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if (("mlp.experts." in name or "mlp.shared_experts." in name)
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if (("mlp.experts." in name or "mlp.shared_experts." in name)
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
......@@ -38,13 +38,14 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.config import LoRAConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
......@@ -225,14 +226,19 @@ class LlamaModel(nn.Module):
self,
config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self.layers = nn.ModuleList([
LlamaDecoderLayer(config, linear_method)
......@@ -263,18 +269,31 @@ class LlamaModel(nn.Module):
class LlamaForCausalLM(nn.Module):
supports_lora = True
def __init__(
self,
config: LlamaConfig,
linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = LlamaModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
self.model = LlamaModel(config, linear_method, lora_config=lora_config)
unpadded_vocab_size = config.vocab_size
if lora_config:
unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
)
self.sampler = Sampler(unpadded_vocab_size, config.vocab_size)
def forward(
self,
......
......@@ -38,13 +38,14 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
from vllm.config import LoRAConfig
KVCache = Tuple[torch.Tensor, torch.Tensor]
......@@ -220,15 +221,20 @@ class MistralModel(nn.Module):
self,
config: MistralConfig,
linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
lora_vocab = (lora_config.lora_extra_vocab_size *
(lora_config.max_loras or 1)) if lora_config else 0
self.vocab_size = config.vocab_size + lora_vocab
self.org_vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
self.vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
)
self.layers = nn.ModuleList([
MistralDecoderLayer(config, linear_method)
......@@ -259,18 +265,33 @@ class MistralModel(nn.Module):
class MistralForCausalLM(nn.Module):
supports_lora = True
def __init__(
self,
config: MistralConfig,
linear_method: Optional[LinearMethodBase] = None,
lora_config: Optional[LoRAConfig] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = MistralModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
self.model = MistralModel(config,
linear_method,
lora_config=lora_config)
unpadded_vocab_size = config.vocab_size
if lora_config:
unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.lm_head = ParallelLMHead(
unpadded_vocab_size,
config.hidden_size,
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE
# We need bigger padding if using lora for kernel
# compatibility
if not lora_config else lora_config.lora_vocab_padding_size,
)
self.sampler = Sampler(unpadded_vocab_size, config.vocab_size)
def forward(
self,
......
......@@ -23,8 +23,6 @@
"""Inference-only Mixtral model."""
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
......@@ -33,10 +31,11 @@ from transformers import MixtralConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
ReplicatedLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
......@@ -47,6 +46,7 @@ from vllm.model_executor.parallel_utils.communication_op import (
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
......@@ -54,85 +54,77 @@ from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class MixtralMLP(nn.Module):
class MixtralMoE(nn.Module):
"""A tensor-parallel MoE implementation for Mixtral that shards each expert
across all ranks.
Each expert's weights are sharded across all ranks and a fused MoE
kernel is used for the forward pass, and finally we reduce the outputs
across ranks.
"""
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
params_dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.num_experts = num_experts
self.ffn_dim = intermediate_size
self.hidden_dim = hidden_size
self.w1 = ReplicatedLinear(self.hidden_dim,
self.ffn_dim,
bias=False,
linear_method=linear_method)
self.w2 = ReplicatedLinear(self.ffn_dim,
self.hidden_dim,
bias=False,
linear_method=linear_method)
self.w3 = ReplicatedLinear(self.hidden_dim,
self.ffn_dim,
bias=False,
linear_method=linear_method)
# TODO: Use vllm's SiluAndMul
self.act_fn = nn.SiLU()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
w1_out, _ = self.w1(hidden_states)
w1_out = self.act_fn(w1_out)
w3_out, _ = self.w3(hidden_states)
current_hidden_states = w1_out * w3_out
current_hidden_states, _ = self.w2(current_hidden_states)
return current_hidden_states
tp_size = get_tensor_model_parallel_world_size()
self.num_total_experts = num_experts
self.top_k = top_k
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size // tp_size
class MixtralMoE(nn.Module):
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.num_total_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
if self.tp_size > self.num_total_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {self.num_total_experts}.")
# Split experts equally between ranks
self.expert_indicies = np.array_split(range(
self.num_total_experts), self.tp_size)[self.rank].tolist()
if not self.expert_indicies:
raise ValueError(
f"Rank {self.rank} has no experts assigned to it.")
self.experts = nn.ModuleList([
MixtralMLP(self.num_total_experts,
config.hidden_size,
config.intermediate_size,
linear_method=linear_method)
if idx in self.expert_indicies else None
for idx in range(self.num_total_experts)
])
self.gate = ReplicatedLinear(config.hidden_size,
self.gate = ReplicatedLinear(self.hidden_size,
self.num_total_experts,
bias=False,
params_dtype=self.params_dtype,
linear_method=None)
self.ws = nn.Parameter(
torch.empty(self.num_total_experts,
2 * self.intermediate_size,
self.hidden_size,
device="cuda",
dtype=self.params_dtype))
self.w2s = nn.Parameter(
torch.empty(self.num_total_experts,
self.hidden_size,
self.intermediate_size,
device="cuda",
dtype=self.params_dtype))
set_weight_attrs(self.ws, {
"weight_loader": self.weight_loader,
})
set_weight_attrs(self.w2s, {
"weight_loader": self.weight_loader,
})
def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
weight_name: str, expert_id: int):
tp_rank = get_tensor_model_parallel_rank()
param_data = param.data
shard_size = self.intermediate_size
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
if weight_name.endswith("w1.weight"):
param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w3.weight"):
param_data[expert_id,
shard_size:2 * shard_size, :] = loaded_weight[shard, :]
if weight_name.endswith("w2.weight"):
param_data[expert_id, :, :] = loaded_weight[:, shard]
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
batch_size, sequence_length, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (batch * sequence_length, n_experts)
router_logits, _ = self.gate(hidden_states)
......@@ -142,22 +134,18 @@ class MixtralMoE(nn.Module):
dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
final_hidden_states = None
for expert_idx in self.expert_indicies:
expert_layer = self.experts[expert_idx]
expert_mask = (selected_experts == expert_idx)
expert_weights = (routing_weights * expert_mask).sum(dim=-1,
keepdim=True)
current_hidden_states = expert_layer(hidden_states).mul_(
expert_weights)
if final_hidden_states is None:
final_hidden_states = current_hidden_states
else:
final_hidden_states.add_(current_hidden_states)
final_hidden_states = fused_moe(hidden_states,
self.ws,
self.w2s,
routing_weights,
selected_experts,
inplace=True)
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return tensor_model_parallel_all_reduce(final_hidden_states).view(
batch_size, sequence_length, hidden_dim)
return final_hidden_states.view(batch_size, sequence_length,
hidden_size)
class MixtralAttention(nn.Module):
......@@ -257,8 +245,11 @@ class MixtralDecoderLayer(nn.Module):
rope_theta=rope_theta,
sliding_window=config.sliding_window,
linear_method=linear_method)
self.block_sparse_moe = MixtralMoE(config=config,
linear_method=linear_method)
self.block_sparse_moe = MixtralMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
......@@ -378,6 +369,14 @@ class MixtralForCausalLM(nn.Module):
("qkv_proj", "v_proj", "v"),
]
expert_params_mapping = [
# (param_name, weight_name, expert_id)
("ws" if weight_name in ["w1", "w3"] else "w2s",
f"experts.{expert_id}.{weight_name}.weight", expert_id)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path,
......@@ -387,6 +386,7 @@ class MixtralForCausalLM(nn.Module):
fall_back_to_pt=False):
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
......@@ -399,14 +399,22 @@ class MixtralForCausalLM(nn.Module):
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if ("block_sparse_moe.experts." in name
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
for param_name, weight_name, expert_id in expert_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
weight_name,
expert_id=expert_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from typing import List, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import MixtralConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
ReplicatedLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.communication_op import (
tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class MixtralMLP(nn.Module):
def __init__(
self,
num_experts: int,
hidden_size: int,
intermediate_size: int,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.num_experts = num_experts
self.ffn_dim = intermediate_size
self.hidden_dim = hidden_size
self.w1 = ReplicatedLinear(self.hidden_dim,
self.ffn_dim,
bias=False,
linear_method=linear_method)
self.w2 = ReplicatedLinear(self.ffn_dim,
self.hidden_dim,
bias=False,
linear_method=linear_method)
self.w3 = ReplicatedLinear(self.hidden_dim,
self.ffn_dim,
bias=False,
linear_method=linear_method)
# TODO: Use vllm's SiluAndMul
self.act_fn = nn.SiLU()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
w1_out, _ = self.w1(hidden_states)
w1_out = self.act_fn(w1_out)
w3_out, _ = self.w3(hidden_states)
current_hidden_states = w1_out * w3_out
current_hidden_states, _ = self.w2(current_hidden_states)
return current_hidden_states
class MixtralMoE(nn.Module):
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
):
super().__init__()
self.config = config
self.rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.num_total_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
if self.tp_size > self.num_total_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {self.num_total_experts}.")
# Split experts equally between ranks
self.expert_indicies = np.array_split(range(
self.num_total_experts), self.tp_size)[self.rank].tolist()
if not self.expert_indicies:
raise ValueError(
f"Rank {self.rank} has no experts assigned to it.")
self.experts = nn.ModuleList([
MixtralMLP(self.num_total_experts,
config.hidden_size,
config.intermediate_size,
linear_method=linear_method)
if idx in self.expert_indicies else None
for idx in range(self.num_total_experts)
])
self.gate = ReplicatedLinear(config.hidden_size,
self.num_total_experts,
bias=False,
linear_method=None)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits, _ = self.gate(hidden_states)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights,
self.top_k,
dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
final_hidden_states = None
for expert_idx in self.expert_indicies:
expert_layer = self.experts[expert_idx]
expert_mask = (selected_experts == expert_idx)
expert_weights = (routing_weights * expert_mask).sum(dim=-1,
keepdim=True)
current_hidden_states = expert_layer(hidden_states).mul_(
expert_weights)
if final_hidden_states is None:
final_hidden_states = current_hidden_states
else:
final_hidden_states.add_(current_hidden_states)
return tensor_model_parallel_all_reduce(final_hidden_states).view(
batch_size, sequence_length, hidden_dim)
class MixtralAttention(nn.Module):
def __init__(self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
linear_method: Optional[LinearMethodBase] = None,
sliding_window: Optional[int] = None) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.sliding_window = sliding_window
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
linear_method=linear_method,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = PagedAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output)
return output
class MixtralDecoderLayer(nn.Module):
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 10000)
self.self_attn = MixtralAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
sliding_window=config.sliding_window,
linear_method=linear_method)
self.block_sparse_moe = MixtralMoE(config=config,
linear_method=linear_method)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.block_sparse_moe(hidden_states)
return hidden_states, residual
class MixtralModel(nn.Module):
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
MixtralDecoderLayer(config, linear_method=linear_method)
for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(positions, hidden_states,
kv_caches[i], input_metadata,
residual)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class MixtralForCausalLM(nn.Module):
def __init__(
self,
config: MixtralConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = MixtralModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata)
return hidden_states
def sample(
self,
hidden_states: Optional[torch.Tensor],
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path,
cache_dir,
load_format,
revision,
fall_back_to_pt=False):
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip experts that are not assigned to this worker.
if ("block_sparse_moe.experts." in name
and name not in params_dict):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
......@@ -62,20 +62,6 @@ from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class PhiEmbedding(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
self.wte = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
def forward(self, input_ids: torch.LongTensor):
return self.wte(input_ids)
class PhiAttention(nn.Module):
def __init__(self,
......@@ -93,27 +79,22 @@ class PhiAttention(nn.Module):
tensor_model_parallel_world_size)
# pylint: disable=C0103
self.Wqkv = QKVParallelLinear(
self.hidden_size,
self.head_size,
self.total_num_heads,
linear_method=linear_method,
)
self.qkv_proj = QKVParallelLinear(
config.hidden_size,
self.hidden_size,
self.head_size,
self.total_num_heads,
bias=False,
bias=True,
linear_method=linear_method,
)
self.out_proj = RowParallelLinear(
self.dense = RowParallelLinear(
self.hidden_size,
self.hidden_size,
linear_method=linear_method,
)
scaling = self.head_size**-0.5
rotary_dim = config.rotary_dim
rotary_dim = int(config.partial_rotary_factor *
(config.hidden_size // config.num_attention_heads))
assert rotary_dim % 2 == 0
# pylint: disable=C0301
......@@ -136,12 +117,12 @@ class PhiAttention(nn.Module):
kv_cache: KVCache,
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv, _ = self.Wqkv(hidden_states)
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(position_ids, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.out_proj(attn_output)
output, _ = self.dense(attn_output)
return output
......@@ -166,8 +147,7 @@ class PhiMLP(nn.Module):
linear_method=linear_method,
)
quant_config = getattr(linear_method, "quant_config", None)
self.act = get_act_fn(config.activation_function, quant_config,
n_inner)
self.act = get_act_fn(config.hidden_act, quant_config, n_inner)
def forward(self, hidden_states):
hidden_states, _ = self.fc1(hidden_states)
......@@ -182,9 +162,9 @@ class PhiLayer(nn.Module):
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None):
super().__init__()
self.ln = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
self.mixer = PhiAttention(config, linear_method)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
self.self_attn = PhiAttention(config, linear_method)
self.mlp = PhiMLP(config, linear_method)
def forward(
......@@ -195,8 +175,8 @@ class PhiLayer(nn.Module):
input_metadata: InputMetadata,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln(hidden_states)
attn_outputs = self.mixer(
hidden_states = self.input_layernorm(hidden_states)
attn_outputs = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
kv_cache=kv_cache,
......@@ -215,11 +195,14 @@ class PhiModel(nn.Module):
super().__init__()
self.config = config
self.linear_method = linear_method
self.embd = PhiEmbedding(config)
self.h = nn.ModuleList([
self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
config.hidden_size)
self.layers = nn.ModuleList([
PhiLayer(config, linear_method)
for _ in range(config.num_hidden_layers)
])
self.final_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_eps)
def forward(
self,
......@@ -228,27 +211,19 @@ class PhiModel(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.embd(input_ids)
hidden_states = self.embed_tokens(input_ids)
for i in range(self.config.num_hidden_layers):
layer = self.h[i]
layer = self.layers[i]
hidden_states = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
)
return hidden_states
class PhiCausalLMHead(nn.Module):
hidden_states = self.final_layernorm(hidden_states)
def __init__(self, config: PretrainedConfig):
super().__init__()
self.ln = nn.LayerNorm(config.hidden_size,
eps=config.layer_norm_epsilon)
self.linear = ParallelLMHead(config.vocab_size,
config.hidden_size,
bias=True)
return hidden_states
class PhiForCausalLM(nn.Module):
......@@ -260,8 +235,11 @@ class PhiForCausalLM(nn.Module):
self.config = config
self.linear_method = linear_method
self.transformer = PhiModel(config, linear_method)
self.lm_head = PhiCausalLMHead(config)
self.model = PhiModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
bias=True)
self.sampler = Sampler(config.vocab_size)
def forward(
......@@ -271,9 +249,9 @@ class PhiForCausalLM(nn.Module):
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, kv_caches,
input_metadata)
hidden_states = self.lm_head.ln(hidden_states)
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata)
return hidden_states
def sample(
......@@ -281,7 +259,7 @@ class PhiForCausalLM(nn.Module):
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
head = self.lm_head.linear
head = self.lm_head
next_tokens = self.sampler(head.weight, hidden_states,
sampling_metadata, head.bias)
return next_tokens
......@@ -291,17 +269,37 @@ class PhiForCausalLM(nn.Module):
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v")
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# pylint: disable=E1136
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# pylint: disable=E1136
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple
import torch
from torch import nn
from transformers import Qwen2Config
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class Qwen2MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
linear_method=linear_method)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
linear_method=linear_method)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class Qwen2Attention(nn.Module):
def __init__(self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
max_position: int = 4096 * 32,
rope_theta: float = 10000,
use_sliding_window: bool = False,
linear_method: Optional[LinearMethodBase] = None,
sliding_window: Optional[int] = None) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.sliding_window = sliding_window if use_sliding_window else None
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=True,
linear_method=linear_method,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
linear_method=linear_method,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=self.rope_theta,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output)
return output
class Qwen2DecoderLayer(nn.Module):
def __init__(
self,
config: Qwen2Config,
layer_idx: int,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
# Requires transformers > 4.32.0
rope_theta = getattr(config, "rope_theta", 1000000)
use_sliding_window = config.use_sliding_window and layer_idx < config.max_window_layers
self.self_attn = Qwen2Attention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
max_position=config.max_position_embeddings,
num_kv_heads=config.num_key_value_heads,
rope_theta=rope_theta,
use_sliding_window=use_sliding_window,
linear_method=linear_method,
sliding_window=config.sliding_window)
self.mlp = Qwen2MLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
linear_method=linear_method,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class Qwen2Model(nn.Module):
def __init__(
self,
config: Qwen2Config,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
Qwen2DecoderLayer(config, layer_idx, linear_method)
for layer_idx in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class Qwen2ForCausalLM(nn.Module):
def __init__(
self,
config: Qwen2Config,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = Qwen2Model(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name:
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
# coding=utf-8
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based off the following work:
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.linear import (LinearMethodBase,
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
KVCache = Tuple[torch.Tensor, torch.Tensor]
class StablelmMLP(nn.Module):
def __init__(self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_up_proj = MergedColumnParallelLinear(
config.hidden_size, [config.intermediate_size] * 2,
bias=False,
linear_method=linear_method)
self.down_proj = RowParallelLinear(config.intermediate_size,
config.hidden_size,
bias=False)
self.act_fn = SiluAndMul()
def forward(self, x: torch.Tensor) -> torch.Tensor:
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class StablelmAttention(nn.Module):
def __init__(self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = config.num_attention_heads
self.num_heads = self.total_num_heads // tp_size
self.total_num_key_value_heads = config.num_key_value_heads
if self.total_num_key_value_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_key_value_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_key_value_heads == 0
self.num_key_value_heads = max(
1, self.total_num_key_value_heads // tp_size)
self.head_dim = self.hidden_size // self.total_num_heads
self.max_position_embeddings = config.max_position_embeddings
self.rotary_ndims = int(self.head_dim * self.config.rope_pct)
self.scaling = self.head_dim**-0.5
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_key_value_heads * self.head_dim
self.qkv_bias = getattr(config, "use_qkv_bias", False)
if (self.head_dim * self.num_heads * tp_size) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads}).")
self.qkv_proj = QKVParallelLinear(self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_key_value_heads,
self.qkv_bias,
linear_method=linear_method)
self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim,
self.hidden_size,
bias=False,
linear_method=linear_method)
self.rotary_ndims = int(self.head_dim * self.config.rope_pct)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_ndims,
max_position=self.config.max_position_embeddings,
base=self.config.rope_theta,
)
self.attn = PagedAttention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_key_value_heads)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
k_cache, v_cache = kv_cache
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
output, _ = self.o_proj(attn_output)
return output
class StablelmDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.self_attn = StablelmAttention(config)
self.mlp = StablelmMLP(config, linear_method)
self.input_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.norm_eps)
self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
eps=config.norm_eps)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states, residual
class StableLMEpochModel(nn.Module):
def __init__(self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None) -> None:
super().__init__()
# self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
StablelmDecoderLayer(config, linear_method)
for _ in range(config.num_hidden_layers)
])
self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
input_metadata,
)
hidden_states = self.norm(hidden_states)
return hidden_states
class StablelmForCausalLM(nn.Module):
def __init__(
self,
config: PretrainedConfig,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
self.config = config
self.linear_method = linear_method
self.model = StableLMEpochModel(config, linear_method)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.sampler = Sampler(config.vocab_size)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[KVCache],
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, kv_caches,
input_metadata)
return hidden_states
def sample(
self,
hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> Optional[SamplerOutput]:
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
or "rotary_emb.sin_cached" in name):
# Models trained using ColossalAI may include these tensors in
# the checkpoint. Skip them.
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
from collections import namedtuple
from typing import Any, Dict, List, Optional, Union
from torch.distributed import ProcessGroup
import torch
from vllm.model_executor.parallel_utils.parallel_state import (
......@@ -5,23 +10,34 @@ from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size,
get_tensor_model_parallel_group,
)
from vllm.model_executor.parallel_utils.custom_all_reduce import custom_all_reduce
def tensor_model_parallel_all_reduce(input_):
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group.
NOTE: This operation is applied in-place on the input tensor.
NOTE: This operation will be applied in-place on the input tensor if
disable_custom_all_reduce is set to True. Otherwise, this operation may or
may not be applied in place depending on whether custom all reduce is
invoked for a particular tensor, which further depends on the tensor size
and GPU topology.
TLDR: always assume this function modifies its input, but use the return
value as the output.
"""
# Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size() == 1:
return input_
# All-reduce.
out = custom_all_reduce(input_)
if out is not None:
return out
torch.distributed.all_reduce(input_,
group=get_tensor_model_parallel_group())
return input_
def tensor_model_parallel_all_gather(input_, dim=-1):
def tensor_model_parallel_all_gather(input_: torch.Tensor,
dim: int = -1) -> torch.Tensor:
"""All-gather the input tensor across model parallel group."""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
......@@ -48,7 +64,9 @@ def tensor_model_parallel_all_gather(input_, dim=-1):
return output_tensor
def tensor_model_parallel_gather(input_, dst=0, dim=-1):
def tensor_model_parallel_gather(input_: torch.Tensor,
dst: int = 0,
dim: int = -1) -> torch.Tensor:
"""Gather the input tensor across model parallel group.
NOTE: We assume that the input tensor is on the same device across
......@@ -80,27 +98,101 @@ def tensor_model_parallel_gather(input_, dst=0, dim=-1):
return output_tensor
def broadcast(input_, src=0):
def broadcast(input_: torch.Tensor,
src: int = 0,
group: Optional[ProcessGroup] = None):
"""Broadcast the input tensor."""
world_size = torch.distributed.get_world_size()
assert 0 <= src < world_size, f"Invalid src rank ({src})"
group = group or torch.distributed.group.WORLD
ranks = torch.distributed.get_process_group_ranks(group)
assert src in ranks, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
world_size = torch.distributed.get_world_size(group=group)
if world_size == 1:
return input_
# Broadcast.
torch.distributed.broadcast(input_, src=src)
torch.distributed.broadcast(input_, src=src, group=group)
return input_
def broadcast_object_list(obj_list, src=0):
def broadcast_object_list(obj_list: List[Any],
src: int = 0,
group: Optional[ProcessGroup] = None):
"""Broadcast the input object list."""
world_size = torch.distributed.get_world_size()
assert 0 <= src < world_size, f"Invalid src rank ({src})"
group = group or torch.distributed.group.WORLD
ranks = torch.distributed.get_process_group_ranks(group)
assert src in ranks, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
world_size = torch.distributed.get_world_size(group=group)
if world_size == 1:
return obj_list
# Broadcast.
torch.distributed.broadcast_object_list(obj_list, src=src)
torch.distributed.broadcast_object_list(obj_list, src=src, group=group)
return obj_list
TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])
def broadcast_tensor_dict(
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
src: int = 0,
group: Optional[ProcessGroup] = None,
) -> Dict[Any, Union[torch.Tensor, Any]]:
"""Broadcast the input tensor dictionary."""
group = group or torch.distributed.group.WORLD
ranks = torch.distributed.get_process_group_ranks(group)
assert src in ranks, f"Invalid src rank ({src})"
# Bypass the function if we are using only 1 GPU.
world_size = torch.distributed.get_world_size(group=group)
if world_size == 1:
return tensor_dict
rank = torch.distributed.get_rank()
if rank == src:
assert isinstance(
tensor_dict,
dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
metadata_list = []
for key, value in tensor_dict.items():
if isinstance(value, torch.Tensor):
assert value.is_cuda, (
f"Tensor {key}: {value} is not on cuda. Currently we only "
f"support broadcasting tensors on cuda.")
metadata_list.append(
(key, TensorMetadata(value.dtype, value.size())))
else:
metadata_list.append((key, value))
torch.distributed.broadcast_object_list([metadata_list],
src=src,
group=group)
for key, value in metadata_list:
if isinstance(value, TensorMetadata):
tensor = tensor_dict[key]
torch.distributed.broadcast(tensor, src=src)
else:
recv_metadata_list = [None]
torch.distributed.broadcast_object_list(recv_metadata_list,
src=src,
group=group)
metadata_list = recv_metadata_list[0]
tensor_dict = {}
async_handles = []
for key, value in metadata_list:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
dtype=value.dtype,
device="cuda")
async_handle = torch.distributed.broadcast(tensor,
src=src,
async_op=True,
group=group)
async_handles.append(async_handle)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value
for async_handle in async_handles:
async_handle.wait()
return tensor_dict
from contextlib import contextmanager
from typing import Optional
import torch
import torch.distributed as dist
from vllm.logger import init_logger
from vllm.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank)
try:
from vllm._C import custom_ar
import pynvml
except ImportError:
# For AMD GPUs
custom_ar = None
pynvml = None
logger = init_logger(__name__)
_CA_HANDLE = None
_IS_CAPTURING = False
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
def init_custom_ar() -> None:
global _CA_HANDLE
if _CA_HANDLE is not None:
return
rank = get_tensor_model_parallel_rank()
world_size = get_tensor_model_parallel_world_size()
if world_size == 1:
# No need to initialize custom allreduce for single GPU case.
return
if world_size not in _SUPPORTED_WORLD_SIZES:
logger.warn(
"Custom allreduce is disabled due to an unsupported world size: "
"%d. Supported world sizes: %s. To slience this warning, specify"
"disable_custom_all_reduce=True explicitly.", world_size,
str(_SUPPORTED_WORLD_SIZES))
return
if not _can_p2p(rank, world_size):
logger.warn(
"Custom allreduce is disabled because your platform lacks GPU P2P"
" capability. To slience this warning, specify"
"disable_custom_all_reduce=True explicitly.")
return
_CA_HANDLE = CustomAllreduce(rank, world_size)
def begin_capture() -> None:
global _IS_CAPTURING
_IS_CAPTURING = True
def end_capture() -> None:
global _IS_CAPTURING
_IS_CAPTURING = False
def is_capturing() -> bool:
return _IS_CAPTURING and _CA_HANDLE is not None
def get_handle() -> Optional["CustomAllreduce"]:
return _CA_HANDLE
@contextmanager
def capture():
try:
begin_capture()
yield
finally:
end_capture()
handle = get_handle()
if handle is not None:
handle.register_graph_buffers()
def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
ca_handle = get_handle()
# when custom allreduce is disabled, this will be None
if ca_handle is None:
return
if is_capturing():
if torch.cuda.is_current_stream_capturing():
if ca_handle.should_custom_ar(input):
return ca_handle.all_reduce_reg(input)
else:
if ca_handle.should_custom_ar(input):
# if warm up, mimic the allocation pattern
# since custom allreduce is out-of-place
return torch.empty_like(input)
else:
# note: outside of cuda graph context,
# custom allreduce incurs a cost of cudaMemcpy, which should
# be small(<=1% of overall latency) compared to the performance
# gains of using custom kernels
if ca_handle.should_custom_ar(input):
return ca_handle.all_reduce_unreg(input)
@contextmanager
def _nvml():
try:
pynvml.nvmlInit()
yield
finally:
pynvml.nvmlShutdown()
# query if the set of gpus are fully connected by nvlink (1 hop)
@_nvml()
def _is_full_nvlink(rank, world_size):
handle = pynvml.nvmlDeviceGetHandleByIndex(rank)
for i in range(world_size):
if i != rank:
try:
link_state = pynvml.nvmlDeviceGetNvLinkState(handle, i)
if not link_state:
return False
except pynvml.NVMLError as error:
logger.info(
f"NVLink detection failed with message \"{str(error)}\". "
"This is normal if your machine has no NVLink equipped")
return False
return True
def _can_p2p(rank: int, world_size: int) -> bool:
for i in range(world_size):
if i == rank:
continue
if not torch.cuda.can_device_access_peer(rank, i):
return False
return True
class CustomAllreduce:
# max_size: max supported allreduce size
def __init__(self, rank, world_size, max_size=8192 * 1024) -> None:
# buffers memory are owned by this Python class and passed to C++
# meta data composes of two parts: meta data for synchronization
# (256 bytes) and a temporary buffer for storing intermediate
# allreduce results.
self.meta = torch.zeros(custom_ar.meta_size() + max_size,
dtype=torch.uint8,
device="cuda")
# This is a pre-registered IPC buffer. In eager mode, input tensors
# are first copied into this buffer before allreduce is performed
self.buffer = torch.empty(max_size, dtype=torch.uint8, device="cuda")
# This is a buffer for storing the tuples of pointers pointing to
# IPC buffers from all ranks. Each registered tuple has size of
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
# is enough for 131072 such tuples. The largest model I've seen only
# needs less than 10000 of registered tuples.
self.rank_data = torch.empty(8 * 1024 * 1024,
dtype=torch.uint8,
device="cuda")
self.max_size = max_size
self.world_size = world_size
handles, offsets = self._get_ipc_meta(self.meta)
self.full_nvlink = _is_full_nvlink(rank, world_size)
self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data,
handles, offsets, rank,
self.full_nvlink)
self.fast_cond = self.full_nvlink or world_size <= 2
self.register_buffer(self.buffer)
def _get_ipc_meta(self, inp: torch.Tensor):
data = inp.untyped_storage()._share_cuda_()
shard_data = (
data[1], # ipc handle to base ptr
data[3], # offset of base ptr
)
return self._gather_ipc_meta(shard_data)
def _gather_ipc_meta(self, shard_data):
all_data = [None] * self.world_size
dist.all_gather_object(all_data, shard_data)
handles = []
offsets = []
for i in range(len(all_data)):
handles.append(all_data[i][0])
offsets.append(all_data[i][1])
return handles, offsets
def register_buffer(self, inp: torch.Tensor):
handles, offsets = self._get_ipc_meta(inp)
custom_ar.register_buffer(self._ptr, inp, handles, offsets)
def register_graph_buffers(self):
handle, offset = custom_ar.get_graph_buffer_ipc_meta(self._ptr)
handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
logger.info("Registering %d cuda graph addresses", len(offset))
custom_ar.register_graph_buffers(self._ptr, handles, offsets)
def should_custom_ar(self, inp: torch.Tensor):
return custom_ar.should_custom_ar(inp, self.max_size, self.world_size,
self.full_nvlink)
# all reduce, assuming inp tensor is IPC registered with register_buffer,
# or, in the context of cuda graphs, register_graph_buffers
def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
if out is None:
out = torch.empty_like(inp)
custom_ar.all_reduce_reg(self._ptr, inp, out)
return out
# all reduce, assuming inp tensor is NOT IPC registered
def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
if out is None:
out = torch.empty_like(inp)
custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out)
return out
def close(self):
if self._ptr:
custom_ar.dispose(self._ptr)
self._ptr = 0
def __del__(self):
self.close()
......@@ -83,6 +83,31 @@ def initialize_model_parallel(
_PIPELINE_GLOBAL_RANKS = ranks
def ensure_model_parallel_initialized(
tensor_model_parallel_size: int,
pipeline_model_parallel_size: int,
) -> None:
"""Helper to initialize model parallel groups if they are not initialized,
or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
values if the model parallel groups are initialized.
"""
if not model_parallel_is_initialized():
initialize_model_parallel(tensor_model_parallel_size,
pipeline_model_parallel_size)
return
assert (
get_tensor_model_parallel_world_size() == tensor_model_parallel_size
), ("tensor parallel group already initialized, but of unexpected size: "
f"{get_tensor_model_parallel_world_size()=} vs. "
f"{tensor_model_parallel_size=}")
assert (get_pipeline_model_parallel_world_size(
) == pipeline_model_parallel_size), (
"pipeline parallel group already initialized, but of unexpected size: "
f"{get_pipeline_model_parallel_world_size()=} vs. "
f"{pipeline_model_parallel_size=}")
def model_parallel_is_initialized():
"""Check if tensor and pipeline parallel groups are initialized."""
return (_TENSOR_MODEL_PARALLEL_GROUP is not None
......@@ -92,7 +117,7 @@ def model_parallel_is_initialized():
def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to."""
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
"tenosr model parallel group is not initialized")
"tensor model parallel group is not initialized")
return _TENSOR_MODEL_PARALLEL_GROUP
......@@ -170,10 +195,14 @@ def get_pipeline_model_parallel_prev_rank():
def destroy_model_parallel():
"""Set the groups to none."""
"""Set the groups to none and destroy them."""
global _TENSOR_MODEL_PARALLEL_GROUP
if _TENSOR_MODEL_PARALLEL_GROUP:
torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP)
_TENSOR_MODEL_PARALLEL_GROUP = None
global _PIPELINE_MODEL_PARALLEL_GROUP
if _PIPELINE_MODEL_PARALLEL_GROUP:
torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)
_PIPELINE_MODEL_PARALLEL_GROUP = None
global _PIPELINE_GLOBAL_RANKS
_PIPELINE_GLOBAL_RANKS = None
"""Utilities for downloading and initializing model weights."""
import filelock
import glob
import fnmatch
import json
import os
from collections import defaultdict
from typing import Any, Iterator, List, Optional, Tuple
from huggingface_hub import snapshot_download
from huggingface_hub import snapshot_download, HfFileSystem
import numpy as np
from safetensors.torch import load_file, save_file, safe_open
import torch
......@@ -149,6 +150,18 @@ def prepare_hf_model_weights(
allow_patterns += ["*.pt"]
if not is_local:
# Before we download we look at that is available:
fs = HfFileSystem()
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
# depending on what is available we download different things
for pattern in allow_patterns:
matching = fnmatch.filter(file_list, pattern)
if len(matching) > 0:
allow_patterns = [pattern]
break
logger.info(f"Using model weights format {allow_patterns}")
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir):
......
......@@ -2,6 +2,7 @@ from typing import List, Optional
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup,
SequenceStatus)
from vllm.lora.request import LoRARequest
class CompletionOutput:
......@@ -16,6 +17,7 @@ class CompletionOutput:
logprobs: The log probabilities of the top probability words at each
position if the logprobs are requested.
finish_reason: The reason why the sequence is finished.
lora_request: The LoRA request that was used to generate the output.
"""
def __init__(
......@@ -26,6 +28,7 @@ class CompletionOutput:
cumulative_logprob: float,
logprobs: Optional[SampleLogprobs],
finish_reason: Optional[str] = None,
lora_request: Optional[LoRARequest] = None,
) -> None:
self.index = index
self.text = text
......@@ -33,6 +36,7 @@ class CompletionOutput:
self.cumulative_logprob = cumulative_logprob
self.logprobs = logprobs
self.finish_reason = finish_reason
self.lora_request = lora_request
def finished(self) -> bool:
return self.finish_reason is not None
......@@ -56,6 +60,7 @@ class RequestOutput:
prompt_logprobs: The log probabilities to return per prompt token.
outputs: The output sequences of the request.
finished: Whether the whole request is finished.
lora_request: The LoRA request that was used to generate the output.
"""
def __init__(
......@@ -66,6 +71,7 @@ class RequestOutput:
prompt_logprobs: Optional[PromptLogprobs],
outputs: List[CompletionOutput],
finished: bool,
lora_request: Optional[LoRARequest] = None,
) -> None:
self.request_id = request_id
self.prompt = prompt
......@@ -73,6 +79,7 @@ class RequestOutput:
self.prompt_logprobs = prompt_logprobs
self.outputs = outputs
self.finished = finished
self.lora_request = lora_request
@classmethod
def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
......@@ -108,8 +115,13 @@ class RequestOutput:
prompt_token_ids = seq_group.prompt_token_ids
prompt_logprobs = seq_group.prompt_logprobs
finished = seq_group.is_finished()
return cls(seq_group.request_id, prompt, prompt_token_ids,
prompt_logprobs, outputs, finished)
return cls(seq_group.request_id,
prompt,
prompt_token_ids,
prompt_logprobs,
outputs,
finished,
lora_request=seq_group.lora_request)
def __repr__(self) -> str:
return (f"RequestOutput(request_id={self.request_id}, "
......@@ -117,4 +129,5 @@ class RequestOutput:
f"prompt_token_ids={self.prompt_token_ids}, "
f"prompt_logprobs={self.prompt_logprobs}, "
f"outputs={self.outputs}, "
f"finished={self.finished})")
f"finished={self.finished}, "
f"lora_request={self.lora_request})")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment