norm/vllm · Commit 51679bbd

Commit 51679bbd, authored Feb 01, 2024 by zhuwenwen
    resolve merge conflicts

Parents: 4095d0db, 1af090b5
Changes: 170 files in the merge; showing 20 changed files with 3313 additions and 236 deletions (+3313 −236).
vllm/model_executor/layers/rejection_sampler.py             +392   -0
vllm/model_executor/layers/sampler.py                       +33    -32
vllm/model_executor/layers/triton_kernel/__init__.py        +0     -0
vllm/model_executor/layers/triton_kernel/prefix_prefill.py  +728   -0
vllm/model_executor/layers/vocab_parallel_embedding.py      +20    -6
vllm/model_executor/model_loader.py                         +24    -8
vllm/model_executor/models/__init__.py                      +8     -2
vllm/model_executor/models/deepseek.py                      +453   -0
vllm/model_executor/models/llama.py                         +25    -6
vllm/model_executor/models/mistral.py                       +27    -6
vllm/model_executor/models/mixtral.py                       +104   -96
vllm/model_executor/models/mixtral_quant.py                 +412   -0
vllm/model_executor/models/phi.py                           +59    -61
vllm/model_executor/models/qwen2.py                         +336   -0
vllm/model_executor/models/stablelm.py                      +299   -0
vllm/model_executor/parallel_utils/communication_op.py      +105   -13
vllm/model_executor/parallel_utils/custom_all_reduce.py     +227   -0
vllm/model_executor/parallel_utils/parallel_state.py        +31    -2
vllm/model_executor/weight_utils.py                         +14    -1
vllm/outputs.py                                             +16    -3
vllm/model_executor/layers/rejection_sampler.py (new file, mode 100644)
from typing import Tuple, Optional
from functools import cached_property

import torch
import torch.nn as nn
import torch.jit


class RejectionSampler(nn.Module):
    """Apply modified rejection sampling as described in "Accelerating Large
    Language Model Decoding with Speculative Sampling"
    https://arxiv.org/pdf/2302.01318.pdf.
    """

    def __init__(self, strict_mode: bool = False):
        """Create a rejection sampler.

        Args:
            strict_mode: Whether or not to perform shape/device/dtype checks
                during sampling. This catches correctness issues but adds
                nontrivial latency.
        """
        super().__init__()
        self.probs_dtype = torch.float32
        self.token_id_dtype = torch.int64
        self._strict_mode = strict_mode

        # NOTE: A "bonus token" is accepted iff all proposal tokens are
        # accepted. There is always only one possible bonus token. We store this
        # value in a variable for readability.
        self._num_bonus_tokens = 1

        self.num_accepted_tokens: Optional[torch.Tensor] = None
        self.num_emitted_tokens: Optional[torch.Tensor] = None
        self.num_draft_tokens: int = 0

    def init_gpu_tensors(self, rank: int) -> None:
        assert self.num_accepted_tokens is None
        device = f"cuda:{rank}"
        self.num_accepted_tokens = torch.tensor(0,
                                                dtype=torch.long,
                                                device=device)
        self.num_emitted_tokens = torch.tensor(0,
                                               dtype=torch.long,
                                               device=device)

    def forward(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Sample token ids using rejection sampling. This accepts or rejects
        tokens proposed by the draft model using the probability of each token
        according to the draft and target models.

        In the worst case where all draft tokens are rejected, it is guaranteed
        one correct token will be emitted.

        In the case where all draft tokens are accepted, a bonus token will be
        accepted as its cheap to have the target model score this speculative
        sequence.

        Args:
            target_probs: The probability distribution over token ids given
                context according to the target model.
                shape = [batch_size, num_speculative_tokens, vocab_size]
            bonus_token_ids: The "bonus" token ids that are accepted iff all
                speculative tokens in a sequence are accepted.
                shape = [batch_size, num_bonus_tokens]
            draft_probs: The probability distribution over token ids given
                context according to the draft model.
                shape = [batch_size, num_speculative_tokens, vocab_size]
            draft_token_ids: The token ids that were sampled from the draft
                probabilities.
                shape = [batch_size, num_speculative_tokens]

        Returns:
            output_token_ids: The token ids sampled via rejection sampling,
                or -1 if unable to sample a token because the previous token
                was rejected.
                shape = [batch_size, num_speculative_tokens + num_bonus_tokens]
        """
        # Only perform shape/dtype/device checking in strict mode, as it adds
        # overhead.
        if self._strict_mode:
            self._raise_if_incorrect_shape(target_probs, bonus_token_ids,
                                           draft_probs, draft_token_ids)
            self._raise_if_incorrect_dtype(target_probs, bonus_token_ids,
                                           draft_probs, draft_token_ids)
            self._raise_if_inconsistent_device(target_probs, bonus_token_ids,
                                               draft_probs, draft_token_ids)
            self._raise_if_out_of_bounds_vocab(target_probs.shape[-1],
                                               bonus_token_ids,
                                               draft_token_ids)

        accepted, recovered_token_ids = self._batch_modified_rejection_sampling(
            target_probs,
            draft_probs,
            draft_token_ids,
        )

        output_token_ids = self._create_output(
            accepted,
            recovered_token_ids,
            draft_token_ids,
            bonus_token_ids,
        )
        return output_token_ids

    def _batch_modified_rejection_sampling(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Perform modified rejection sampling on each sequence.

        Returns:
            A tuple of two tensors:
            0: A bool tensor of which tokens in each sequence is accepted.
                shape = [batch_size, k]
            1: Token ids sampled from a recovered distribution, to be used
                when a token is rejected.
                shape = [batch_size, k]
        """
        batch_size, k, vocab_size = draft_probs.shape

        # shape [batch_size, k]
        accepted = self._get_accepted(target_probs, draft_probs,
                                      draft_token_ids)

        recovered_probs = self._get_recovered_probs(
            target_probs, draft_probs).reshape(batch_size * k, vocab_size)

        recovered_token_ids = _multinomial(recovered_probs,
                                           num_samples=1).reshape(
                                               batch_size, k)
        return accepted, recovered_token_ids

    def _get_accepted(
        self,
        target_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_probs: torch.Tensor,  # [batch_size, k, vocab_size]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
    ) -> torch.Tensor:
        r"""Create bool matrix over the proposed draft tokens. If
        True, then a token can be accepted, else it should be
        rejected.

        Given :math:`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of
        :math:`\hat{x}_{n+1}` given context :math:`x_1, \dots, x_n` according
        to the target model, and :math:`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the
        same conditional probability according to the draft model, the token
        is accepted with probability:

        .. math::
            \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)}
                           {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right)

        This implementation does not apply causality. When using the output,
        if a token is rejected, subsequent tokens should not be used.

        Returns a bool tensor of shape [batch_size, k] specifying which tokens
        are accepted.
        """
        batch_size, k, _ = draft_probs.shape
        batch_indices = torch.arange(batch_size,
                                     device=target_probs.device)[:, None]
        probs_indicies = torch.arange(k, device=target_probs.device)

        # shape [batch_size, k]
        selected_draft_probs = draft_probs[batch_indices, probs_indicies,
                                           draft_token_ids]

        # shape [batch_size, k]
        selected_target_probs = target_probs[batch_indices, probs_indicies,
                                             draft_token_ids]

        uniform_rand = torch.rand(batch_size,
                                  k,
                                  dtype=self.probs_dtype,
                                  device=target_probs.device)
        capped_ratio = torch.minimum(
            selected_target_probs / selected_draft_probs,
            torch.full((1, ), 1, device=target_probs.device))
        accepted = uniform_rand < capped_ratio

        return accepted

    def _get_recovered_probs(
        self,
        target_probs: torch.Tensor,  # [k, vocab_size]
        draft_probs: torch.Tensor,  # [k, vocab_size]
    ) -> torch.Tensor:
        r"""Create a probability distribution for each proposed token which can
        be sampled if the proposed token is rejected.

        When this routine is applied sequentially, the true distribution of the
        target model is recovered (within hardware numerics).

        The probability distribution used in this rejection case is constructed
        as follows. Given :math:`q(x|x_1, \dots, x_n)`, the probability of
        :math:`x` given context :math:`x_1, \dots, x_n` according to the target
        model and :math:`p(x|x_1, \dots, x_n)`, the same conditional probability
        according to the draft model:

        .. math::
            x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+

        where :math:`(f(x))_+` is defined as:

        .. math::
            (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))}

        See https://github.com/vllm-project/vllm/pull/2336 for a visualization
        of the draft, target, and recovered probability distributions.

        Returns a tensor of shape [batch_size, k, vocab_size].

        Note: This batches operations on GPU and thus constructs the recovered
        distribution for all tokens, even if they are accepted. This causes
        division-by-zero errors, so we use self._smallest_positive_value to
        avoid that. This introduces some drift to the distribution.
        """
        _, k, _ = draft_probs.shape

        # shape [batch_size, k, vocab_size]
        difference = target_probs - draft_probs

        # TODO(cade): Can we use logprobs instead of probs, and avoid the
        # division-by-zero errors without introducing distribution drift?

        # shape [batch_size, k, vocab_size]
        f = torch.clamp(difference, min=self._smallest_positive_value)

        # shape [batch_size, k, vocab_size]
        recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1)

        return recovered_probs

    @cached_property
    def _smallest_positive_value(self) -> float:
        """Return the smallest positive value representable by the probs dtype.
        This value is used when constructing a distribution from which to sample
        recovered tokens in the first rejection case.

        See _get_recovered_probs for more details

        Note that this isn't actually the smallest positive value representable
        by float32, but the smallest positive normal value.
        See https://en.wikipedia.org/wiki/Subnormal_number for more information.
        """
        return torch.finfo(self.probs_dtype).tiny

    def _create_output(
        self,
        accepted: torch.Tensor,  # [batch_size, k]
        recovered_token_ids: torch.Tensor,  # [batch_size, k]
        draft_token_ids: torch.Tensor,  # [batch_size, k]
        bonus_token_ids: torch.Tensor,  # [batch_size]
    ) -> torch.Tensor:
        """Format output. Returns a matrix of token ids. When
        a token is rejected via rejection sampling, all subsequent
        token ids are set to -1 for the sequence.

        shape = [batch_size, k + num_bonus_tokens]
        """
        bonus_token_ids = bonus_token_ids.squeeze()
        batch_size, k = recovered_token_ids.shape

        # Determine the index of the first False value for each row.
        limits = (accepted == 0).max(1).indices
        limits[~(accepted == 0).any(1)] = k

        # Create masks using the indices.
        indices = torch.arange(k, device=accepted.device).unsqueeze(0)
        accepted_mask = indices < limits.unsqueeze(1)
        after_false_mask = indices == limits.unsqueeze(1)

        # Create an extended output tensor
        output_with_bonus_tokens = -torch.ones(
            (batch_size, k + self._num_bonus_tokens),
            dtype=self.token_id_dtype,
            device=accepted.device)
        output = output_with_bonus_tokens[:, :k]

        # Fill in the first k columns of the output tensor using masks and data
        # tensors.
        output[:, :k] = torch.where(accepted_mask, draft_token_ids,
                                    -torch.ones_like(draft_token_ids))

        # Fill the last column.
        # We check output directly as accepted may have True values inconsistent
        # with causal acceptance.
        output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
                                                      bonus_token_ids, -1)

        # Fill the recovered token ids.
        output.mul_(~after_false_mask).add_(
            recovered_token_ids.mul(after_false_mask))

        self.num_accepted_tokens += accepted.sum()
        self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum()
        self.num_draft_tokens += batch_size * k

        return output_with_bonus_tokens

    def _raise_if_incorrect_shape(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> None:
        (target_batch_size, num_target_probs,
         target_vocab_size) = target_probs.shape
        bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape
        draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape
        draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape

        assert draft_batch_size == target_batch_size
        assert num_draft_probs == num_target_probs
        assert (draft_vocab_size == target_vocab_size
                ), f"{draft_vocab_size=} {target_vocab_size=}"

        assert draft_token_ids_batch_size == draft_batch_size
        assert num_draft_token_ids == num_draft_probs

        assert bonus_batch_size == target_batch_size
        assert num_bonus_tokens == self._num_bonus_tokens

    def _raise_if_incorrect_dtype(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> None:
        assert all(probs.dtype == self.probs_dtype
                   for probs in [target_probs, draft_probs])
        assert all(token_ids.dtype == self.token_id_dtype
                   for token_ids in [bonus_token_ids, draft_token_ids])

    def _raise_if_inconsistent_device(
        self,
        target_probs: torch.Tensor,
        bonus_token_ids: torch.Tensor,
        draft_probs: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> None:
        devices = [
            t.device for t in
            [target_probs, bonus_token_ids, draft_probs, draft_token_ids]
        ]
        assert all([devices[0] == device for device in devices])

    def _raise_if_out_of_bounds_vocab(
        self,
        vocab_size: int,
        bonus_token_ids: torch.Tensor,
        draft_token_ids: torch.Tensor,
    ) -> None:
        assert torch.all(bonus_token_ids < vocab_size)
        assert torch.all(bonus_token_ids >= 0)
        assert torch.all(draft_token_ids < vocab_size)
        assert torch.all(draft_token_ids >= 0)


# torch.multinomial forces a GPU<->CPU sync.
# Therefore, we use an optimized implementation instead that skips the sync.
# Note that we always sample with replacement.
# probs will be modified in place, but this is fine, as we pass
# in a copy already.
@torch.jit.script
def _multinomial(
    probs: torch.Tensor,
    num_samples: int,
) -> torch.Tensor:
    if num_samples > 1:
        # This is equivalent to torch.repeat_interleaved (which also
        # forces a GPU<->CPU sync).
        probs = probs[:, None, :].expand(probs.shape[0], num_samples,
                                         probs.shape[1]).contiguous().view(
                                             -1, probs.shape[1])
    q = torch.empty_like(probs).exponential_(1.0)
    return probs.div_(q).argmax(dim=1).view(-1, num_samples)
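As a quick illustration of the interface above, the following sketch (not part of this commit) drives RejectionSampler on CPU with tiny random tensors. Setting the counters by hand is a hypothetical shortcut standing in for init_gpu_tensors, which expects a CUDA rank.

    # Illustrative only: tiny shapes, CPU tensors; the real caller lives in the
    # speculative-decoding worker and runs on CUDA.
    import torch

    batch_size, k, vocab_size = 2, 3, 8
    sampler = RejectionSampler(strict_mode=True)
    # Hypothetical shortcut instead of sampler.init_gpu_tensors(rank).
    sampler.num_accepted_tokens = torch.tensor(0)
    sampler.num_emitted_tokens = torch.tensor(0)

    draft_probs = torch.softmax(torch.rand(batch_size, k, vocab_size), dim=-1)
    target_probs = torch.softmax(torch.rand(batch_size, k, vocab_size), dim=-1)
    draft_token_ids = torch.multinomial(
        draft_probs.view(-1, vocab_size), num_samples=1).view(batch_size, k)
    bonus_token_ids = torch.randint(0, vocab_size, (batch_size, 1))

    out = sampler(target_probs, bonus_token_ids, draft_probs, draft_token_ids)
    print(out.shape)  # [batch_size, k + 1]; -1 marks positions after a rejection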
vllm/model_executor/layers/sampler.py
@@ -27,9 +27,25 @@ class Sampler(nn.Module):
     parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
     """

-    def __init__(self, vocab_size: int) -> None:
+    def __init__(self,
+                 vocab_size: int,
+                 org_vocab_size: Optional[int] = None) -> None:
         super().__init__()
         self.vocab_size = vocab_size
+        # original vocabulary size (without LoRA).
+        self.org_vocab_size = org_vocab_size or vocab_size
+
+    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
+                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
+        # Get the logits for the next tokens.
+        logits = torch.matmul(hidden_states, embedding.t())
+        if embedding_bias is not None:
+            logits += embedding_bias
+        logits = tensor_model_parallel_gather(logits)
+        # Remove paddings in vocab (if any).
+        if logits is not None:
+            logits = logits[:, :self.org_vocab_size]
+        return logits

     def forward(
         self,
@@ -42,8 +58,7 @@ class Sampler(nn.Module):
         hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)

         # Get the logits for the next tokens.
-        logits = _get_logits(hidden_states, embedding, embedding_bias,
-                             self.vocab_size)
+        logits = self._get_logits(hidden_states, embedding, embedding_bias)

         # Only perform sampling in the driver worker.
         # Note: `_get_logits` is still distributed across TP workers because
@@ -76,7 +91,7 @@ class Sampler(nn.Module):
         logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1))

         if do_top_p_top_k:
-            logits = _apply_top_p_top_k(logits, sampling_tensors.top_ps,
+            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
                                         sampling_tensors.top_ks)

         if do_min_p:
@@ -98,20 +113,6 @@ class Sampler(nn.Module):
                                  prompt_logprobs, sample_logprobs)


-def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
-                embedding_bias: Optional[torch.Tensor],
-                vocab_size: int) -> Optional[torch.Tensor]:
-    # Get the logits for the next tokens.
-    logits = torch.matmul(hidden_states, embedding.t())
-    if embedding_bias is not None:
-        logits += embedding_bias
-    logits = tensor_model_parallel_gather(logits)
-    # Remove paddings in vocab (if any).
-    if logits is not None:
-        logits = logits[:, :vocab_size]
-    return logits
-
-
 def _prune_hidden_states(
     hidden_states: torch.Tensor,
     sampling_metadata: SamplingMetadata,
@@ -185,27 +186,27 @@ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
     return logits


-def _apply_top_p_top_k(
+def _apply_top_k_top_p(
     logits: torch.Tensor,
     p: torch.Tensor,
     k: torch.Tensor,
 ) -> torch.Tensor:
-    logits_sort, logits_idx = logits.sort(dim=-1, descending=True)
+    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

-    # Apply top-p.
-    probs_sort = logits_sort.softmax(dim=-1)
-    probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort)
-    top_p_mask = probs_sum > p.unsqueeze_(dim=1)
-
-    # Apply top-k.
-    # Create a mask for the top-k elements.
-    top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device)
-    top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1)
-    top_k_mask = top_k_mask >= k.unsqueeze_(dim=1)
-
-    # Final mask.
-    mask = (top_p_mask | top_k_mask)
-    logits_sort.masked_fill_(mask, -float("inf"))
+    # Apply top-k.
+    top_k_mask = logits_sort.size(1) - k.to(torch.long)
+    # Get all the top_k values.
+    top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
+    top_k_mask = logits_sort < top_k_mask
+    logits_sort.masked_fill_(top_k_mask, -float("inf"))
+
+    # Apply top-p.
+    probs_sort = logits_sort.softmax(dim=-1)
+    probs_sum = probs_sort.cumsum(dim=-1)
+    top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
+    # at least one
+    top_p_mask[:, -1] = False
+    logits_sort.masked_fill_(top_p_mask, -float("inf"))

     # Re-sort the probabilities.
     src = torch.arange(logits_idx.shape[-1],
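The hunk above switches the sampler to an ascending sort, applies top-k via a gathered per-row threshold, then applies top-p while always keeping at least one token. The following standalone sketch (not the vLLM function itself; names and the tiny demo are illustrative) mirrors that masking scheme end to end, including the inverse-permutation gather that undoes the sort.

    import torch

    def apply_top_k_top_p_sketch(logits: torch.Tensor, p: torch.Tensor,
                                 k: torch.Tensor) -> torch.Tensor:
        logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

        # Top-k: with an ascending sort, column (V - k) holds each row's k-th
        # largest logit; everything strictly below it is masked out.
        kth_largest = logits_sort.gather(
            1, (logits_sort.size(1) - k.to(torch.long)).unsqueeze(dim=1))
        logits_sort.masked_fill_(logits_sort < kth_largest, -float("inf"))

        # Top-p: drop the low-probability prefix whose cumulative mass is
        # <= 1 - p, always keeping the last (largest) column.
        probs_sort = logits_sort.softmax(dim=-1)
        probs_sum = probs_sort.cumsum(dim=-1)
        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
        top_p_mask[:, -1] = False
        logits_sort.masked_fill_(top_p_mask, -float("inf"))

        # Undo the sort with the inverse permutation of logits_idx.
        src = torch.arange(logits_idx.shape[-1],
                           device=logits_idx.device).expand_as(logits_idx)
        inv_idx = torch.empty_like(logits_idx).scatter_(dim=-1,
                                                        index=logits_idx,
                                                        src=src)
        return torch.gather(logits_sort, dim=-1, index=inv_idx)

    out = apply_top_k_top_p_sketch(torch.randn(2, 10),
                                   p=torch.tensor([0.9, 1.0]),
                                   k=torch.tensor([3, 10]))
    print((out != -float("inf")).sum(dim=-1))  # at most 3 and 10 survivors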
vllm/model_executor/layers/triton_kernel/__init__.py (new empty file, mode 100644)
vllm/model_executor/layers/triton_kernel/prefix_prefill.py (new file, mode 100644)
# The kernels in this file are adapted from LightLLM's context_attention_fwd:
# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py

import torch
import triton
import triton.language as tl

if triton.__version__ >= "2.1.0":

    @triton.jit
    def _fwd_kernel(
        Q, K, V, K_cache, V_cache, B_Loc, sm_scale, B_Start_Loc, B_Seqlen,
        B_Ctxlen, block_size, x, Out,
        stride_b_loc_b, stride_b_loc_s,
        stride_qbs, stride_qh, stride_qd,
        stride_kbs, stride_kh, stride_kd,
        stride_vbs, stride_vh, stride_vd,
        stride_obs, stride_oh, stride_od,
        stride_k_cache_bs, stride_k_cache_h, stride_k_cache_d,
        stride_k_cache_bl, stride_k_cache_x,
        stride_v_cache_bs, stride_v_cache_h, stride_v_cache_d,
        stride_v_cache_bl,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):
        cur_batch = tl.program_id(0)
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

        block_start_loc = BLOCK_M * start_m

        # initialize offsets
        offs_n = tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, BLOCK_DMODEL)
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
        off_q = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
            cur_head * stride_qh + offs_d[None, :] * stride_qd)

        q = tl.load(Q + off_q,
                    mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
                    other=0.0)

        # # initialize pointer to m and l
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

        for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                         ((start_n + offs_n) // block_size) * stride_b_loc_s,
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)
            off_k = (bn[None, :] * stride_k_cache_bs +
                     cur_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
            off_v = (
                bn[:, None] * stride_v_cache_bs +
                cur_head * stride_v_cache_h +
                offs_d[None, :] * stride_v_cache_d +
                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
            k = tl.load(K_cache + off_k,
                        mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
                        other=0.0)
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                          float("-inf"))
            qk *= sm_scale

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            p = tl.exp(qk - m_ij[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            m_i_new = tl.maximum(m_i, m_ij)
            alpha = tl.exp(m_i - m_i_new)
            beta = tl.exp(m_ij - m_i_new)
            l_i_new = alpha * l_i + beta * l_ij
            # -- update output accumulator --
            # scale p
            p_scale = beta / l_i_new
            p = p * p_scale[:, None]
            # scale acc
            acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(V_cache + off_v,
                        mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
                        other=0.0)

            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
                 offs_d[:, None] * stride_kd)
        off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v

        block_mask = tl.where(
            block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            k = tl.load(k_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_kbs,
                        mask=(start_n + offs_n[None, :]) <
                        cur_batch_seq_len - cur_batch_ctx_len,
                        other=0.0)

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk *= sm_scale
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                          float("-inf"))

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            p = tl.exp(qk - m_ij[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            m_i_new = tl.maximum(m_i, m_ij)
            alpha = tl.exp(m_i - m_i_new)
            beta = tl.exp(m_ij - m_i_new)
            l_i_new = alpha * l_i + beta * l_ij
            # -- update output accumulator --
            # scale p
            p_scale = beta / l_i_new
            p = p * p_scale[:, None]
            # scale acc
            acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=(start_n + offs_n[:, None]) <
                        cur_batch_seq_len - cur_batch_ctx_len,
                        other=0.0)

            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new
        # initialize pointers to output
        off_o = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
            cur_head * stride_oh + offs_d[None, :] * stride_od)
        out_ptrs = Out + off_o
        tl.store(out_ptrs,
                 acc,
                 mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
        return

    @triton.jit
    def _fwd_kernel_flash_attn_v2(
        Q, K, V, K_cache, V_cache, B_Loc, sm_scale, B_Start_Loc, B_Seqlen,
        B_Ctxlen, block_size, x, Out,
        stride_b_loc_b, stride_b_loc_s,
        stride_qbs, stride_qh, stride_qd,
        stride_kbs, stride_kh, stride_kd,
        stride_vbs, stride_vh, stride_vd,
        stride_obs, stride_oh, stride_od,
        stride_k_cache_bs, stride_k_cache_h, stride_k_cache_d,
        stride_k_cache_bl, stride_k_cache_x,
        stride_v_cache_bs, stride_v_cache_h, stride_v_cache_d,
        stride_v_cache_bl,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):
        cur_batch = tl.program_id(0)
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

        block_start_loc = BLOCK_M * start_m

        # initialize offsets
        offs_n = tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, BLOCK_DMODEL)
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
        off_q = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
            cur_head * stride_qh + offs_d[None, :] * stride_qd)

        q = tl.load(Q + off_q,
                    mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
                    other=0.0)

        # # initialize pointer to m and l
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

        for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                         ((start_n + offs_n) // block_size) * stride_b_loc_s,
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)
            off_k = (bn[None, :] * stride_k_cache_bs +
                     cur_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
            off_v = (
                bn[:, None] * stride_v_cache_bs +
                cur_head * stride_v_cache_h +
                offs_d[None, :] * stride_v_cache_d +
                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
            k = tl.load(K_cache + off_k,
                        mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
                        other=0.0)
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                          float("-inf"))
            qk *= sm_scale

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # scale p
            # scale acc
            acc_scale = alpha
            # acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(V_cache + off_v,
                        mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
                        other=0.0)

            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
                 offs_d[:, None] * stride_kd)
        off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v

        block_mask = tl.where(
            block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            k = tl.load(k_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_kbs,
                        mask=(start_n + offs_n[None, :]) <
                        cur_batch_seq_len - cur_batch_ctx_len,
                        other=0.0)

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk *= sm_scale
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                          float("-inf"))

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # scale p
            # scale acc
            acc_scale = alpha
            # acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=(start_n + offs_n[:, None]) <
                        cur_batch_seq_len - cur_batch_ctx_len,
                        other=0.0)

            p = p.to(v.dtype)
            acc += tl.dot(p, v)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        # acc /= l_i[:, None]
        # initialize pointers to output
        off_o = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
            cur_head * stride_oh + offs_d[None, :] * stride_od)
        out_ptrs = Out + off_o
        tl.store(out_ptrs,
                 acc,
                 mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
        return

    @triton.jit
    def _fwd_kernel_alibi(
        Q, K, V, K_cache, V_cache, B_Loc, sm_scale, B_Start_Loc, B_Seqlen,
        B_Ctxlen, Alibi_slopes, block_size, x, Out,
        stride_b_loc_b, stride_b_loc_s,
        stride_qbs, stride_qh, stride_qd,
        stride_kbs, stride_kh, stride_kd,
        stride_vbs, stride_vh, stride_vd,
        stride_obs, stride_oh, stride_od,
        stride_k_cache_bs, stride_k_cache_h, stride_k_cache_d,
        stride_k_cache_bl, stride_k_cache_x,
        stride_v_cache_bs, stride_v_cache_h, stride_v_cache_d,
        stride_v_cache_bl,
        BLOCK_M: tl.constexpr,
        BLOCK_DMODEL: tl.constexpr,
        BLOCK_N: tl.constexpr,
    ):
        # attn_bias[]
        cur_batch = tl.program_id(0)
        cur_head = tl.program_id(1)
        start_m = tl.program_id(2)

        # cur_batch_seq_len: the length of prompts
        # cur_batch_ctx_len: the length of prefix
        # cur_batch_in_all_start_index: the start id of the dim=0
        cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
        cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
        cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)

        block_start_loc = BLOCK_M * start_m

        # initialize offsets
        offs_n = tl.arange(0, BLOCK_N)
        offs_d = tl.arange(0, BLOCK_DMODEL)
        offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
        off_q = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs +
            cur_head * stride_qh + offs_d[None, :] * stride_qd)

        q = tl.load(Q + off_q,
                    mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len,
                    other=0.0)

        # # initialize pointer to m and l
        m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
        l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
        acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)

        alibi_slope = tl.load(Alibi_slopes + cur_head)
        alibi_start_q = tl.arange(
            0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
        alibi_start_k = 0
        for start_n in range(0, cur_batch_ctx_len, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            bn = tl.load(B_Loc + cur_batch * stride_b_loc_b +
                         ((start_n + offs_n) // block_size) * stride_b_loc_s,
                         mask=(start_n + offs_n) < cur_batch_ctx_len,
                         other=0)
            off_k = (bn[None, :] * stride_k_cache_bs +
                     cur_head * stride_k_cache_h +
                     (offs_d[:, None] // x) * stride_k_cache_d +
                     ((start_n + offs_n[None, :]) % block_size) *
                     stride_k_cache_bl +
                     (offs_d[:, None] % x) * stride_k_cache_x)
            off_v = (
                bn[:, None] * stride_v_cache_bs +
                cur_head * stride_v_cache_h +
                offs_d[None, :] * stride_v_cache_d +
                (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
            k = tl.load(K_cache + off_k,
                        mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len,
                        other=0.0)
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k)
            qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk,
                          float("-inf"))
            qk *= sm_scale

            # load alibi
            alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                     alibi_start_q[:, None]) * alibi_slope
            alibi = tl.where(
                (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
                alibi, float("-inf"))
            qk += alibi
            alibi_start_k += BLOCK_N

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # scale p
            # scale acc
            acc_scale = alpha
            # acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(V_cache + off_v,
                        mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len,
                        other=0.0)

            p = p.to(v.dtype)
            acc += tl.dot(p, v, allow_tf32=False)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh +
                 offs_d[:, None] * stride_kd)
        off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh +
                 offs_d[None, :] * stride_vd)
        k_ptrs = K + off_k
        v_ptrs = V + off_v

        block_mask = tl.where(
            block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0)

        # init alibi
        alibi_slope = tl.load(Alibi_slopes + cur_head)
        alibi_start_q = tl.arange(
            0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
        alibi_start_k = cur_batch_ctx_len
        # # init debuger
        # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
        # offset_db_k = tl.arange(0, BLOCK_N)
        # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL]
        for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):
            start_n = tl.multiple_of(start_n, BLOCK_N)
            # -- compute qk ----
            k = tl.load(k_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_kbs,
                        mask=(start_n + offs_n[None, :]) <
                        cur_batch_seq_len - cur_batch_ctx_len,
                        other=0.0)

            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, k, allow_tf32=False)
            qk *= sm_scale
            qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk,
                          float("-inf"))

            # load alibi
            alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k -
                     alibi_start_q[:, None]) * alibi_slope
            alibi = tl.where(
                (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len),
                alibi, float("-inf"))
            qk += alibi
            alibi_start_k += BLOCK_N

            # -- compute m_ij, p, l_ij
            m_ij = tl.max(qk, 1)
            m_i_new = tl.maximum(m_i, m_ij)
            p = tl.math.exp(qk - m_i_new[:, None])
            l_ij = tl.sum(p, 1)
            # -- update m_i and l_i
            alpha = tl.math.exp(m_i - m_i_new)
            l_i_new = alpha * l_i + l_ij
            # -- update output accumulator --
            # scale p
            # scale acc
            acc_scale = alpha
            # acc_scale = l_i / l_i_new * alpha
            acc = acc * acc_scale[:, None]
            # update acc
            v = tl.load(v_ptrs +
                        (cur_batch_in_all_start_index + start_n) * stride_vbs,
                        mask=(start_n + offs_n[:, None]) <
                        cur_batch_seq_len - cur_batch_ctx_len,
                        other=0.0)

            p = p.to(v.dtype)
            acc += tl.dot(p, v, allow_tf32=False)
            # update m_i and l_i
            l_i = l_i_new
            m_i = m_i_new

        acc = acc / l_i[:, None]

        # initialize pointers to output
        off_o = (
            (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
            cur_head * stride_oh + offs_d[None, :] * stride_od)
        out_ptrs = Out + off_o
        tl.store(out_ptrs,
                 acc,
                 mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len)
        return

    @torch.inference_mode()
    def context_attention_fwd(q,
                              k,
                              v,
                              o,
                              k_cache,
                              v_cache,
                              b_loc,
                              b_start_loc,
                              b_seq_len,
                              b_ctx_len,
                              max_input_len,
                              alibi_slopes=None):
        BLOCK = 128
        # shape constraints
        Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
        assert Lq == Lk and Lk == Lv
        assert Lk in {16, 32, 64, 128}

        sm_scale = 1.0 / (Lq**0.5)
        batch, head = b_seq_len.shape[0], q.shape[1]

        grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # batch, head,

        num_warps = 8 if Lk <= 64 else 8
        if alibi_slopes is not None:
            _fwd_kernel_alibi[grid](
                q, k, v, k_cache, v_cache, b_loc, sm_scale, b_start_loc,
                b_seq_len, b_ctx_len, alibi_slopes,
                v_cache.shape[3], 8, o,
                b_loc.stride(0), b_loc.stride(1),
                q.stride(0), q.stride(1), q.stride(2),
                k.stride(0), k.stride(1), k.stride(2),
                v.stride(0), v.stride(1), v.stride(2),
                o.stride(0), o.stride(1), o.stride(2),
                k_cache.stride(0), k_cache.stride(1), k_cache.stride(2),
                k_cache.stride(3),
                k_cache.stride(4),  #[num_blocks, num_kv_heads, head_size/x, block_size, x]
                v_cache.stride(0), v_cache.stride(1), v_cache.stride(2),
                v_cache.stride(3),  #[num_blocks, num_kv_heads, head_size, block_size]
                BLOCK_M=BLOCK,
                BLOCK_DMODEL=Lk,
                BLOCK_N=BLOCK,
                num_warps=num_warps,
                num_stages=1,
            )
            return

        _fwd_kernel[grid](
            q, k, v, k_cache, v_cache, b_loc, sm_scale, b_start_loc,
            b_seq_len, b_ctx_len,
            v_cache.shape[3], 8, o,
            b_loc.stride(0), b_loc.stride(1),
            q.stride(0), q.stride(1), q.stride(2),
            k.stride(0), k.stride(1), k.stride(2),
            v.stride(0), v.stride(1), v.stride(2),
            o.stride(0), o.stride(1), o.stride(2),
            k_cache.stride(0), k_cache.stride(1), k_cache.stride(2),
            k_cache.stride(3),
            k_cache.stride(4),  #[num_blocks, num_kv_heads, head_size/x, block_size, x]
            v_cache.stride(0), v_cache.stride(1), v_cache.stride(2),
            v_cache.stride(3),  #[num_blocks, num_kv_heads, head_size, block_size]
            BLOCK_M=BLOCK,
            BLOCK_DMODEL=Lk,
            BLOCK_N=BLOCK,
            num_warps=num_warps,
            num_stages=1,
        )
        return
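The kernels above keep a running row maximum m_i, a running normalizer l_i, and a rescaled accumulator while streaming over key/value blocks. The following CPU-only sketch (plain PyTorch, not the Triton kernel; shapes and the block size are illustrative) implements that online-softmax update and checks it against a direct causal attention computation.

    import torch

    def online_softmax_attention(q, k, v, block_n=16, causal=True):
        # q: [M, D], k: [N, D], v: [N, Dv]
        scale = q.shape[-1] ** -0.5
        m_i = torch.full((q.shape[0],), float("-inf"))
        l_i = torch.zeros(q.shape[0])
        acc = torch.zeros(q.shape[0], v.shape[-1])
        for start in range(0, k.shape[0], block_n):
            kb, vb = k[start:start + block_n], v[start:start + block_n]
            qk = (q @ kb.T) * scale
            if causal:
                cols = torch.arange(start, start + kb.shape[0])
                qk = qk.masked_fill(
                    torch.arange(q.shape[0])[:, None] < cols[None, :],
                    float("-inf"))
            m_new = torch.maximum(m_i, qk.max(dim=-1).values)
            alpha = torch.exp(m_i - m_new)          # rescale old stats
            p = torch.exp(qk - m_new[:, None])
            l_i = alpha * l_i + p.sum(dim=-1)
            acc = acc * alpha[:, None] + p @ vb
            m_i = m_new
        return acc / l_i[:, None]

    q, k, v = torch.randn(8, 32), torch.randn(8, 32), torch.randn(8, 16)
    ref = torch.softmax((q @ k.T) * 32 ** -0.5 +
                        torch.triu(torch.full((8, 8), float("-inf")), 1),
                        dim=-1) @ v
    assert torch.allclose(online_softmax_attention(q, k, v, block_n=4), ref,
                          atol=1e-4)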
vllm/model_executor/layers/vocab_parallel_embedding.py
@@ -13,8 +13,11 @@ from vllm.model_executor.parallel_utils.communication_op import (
     tensor_model_parallel_all_reduce)
 from vllm.model_executor.utils import set_weight_attrs

+DEFAULT_VOCAB_PADDING_SIZE = 64

-def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
+
+def pad_vocab_size(vocab_size: int,
+                   pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
     """Pad the vocab size to the given value."""
     return ((vocab_size + pad_to - 1) // pad_to) * pad_to

@@ -43,17 +46,23 @@ class VocabParallelEmbedding(torch.nn.Module):
         num_embeddings: vocabulary size.
         embedding_dim: size of hidden state.
         params_dtype: type of the parameters.
+        org_num_embeddings: original vocabulary size (without LoRA).
+        padding_size: padding size for the vocabulary.
     """

     def __init__(self,
                  num_embeddings: int,
                  embedding_dim: int,
-                 params_dtype: Optional[torch.dtype] = None):
+                 params_dtype: Optional[torch.dtype] = None,
+                 org_num_embeddings: Optional[int] = None,
+                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
         super().__init__()

         # Keep the input dimensions.
         self.num_embeddings = num_embeddings
-        self.num_embeddings_padded = pad_vocab_size(num_embeddings)
+        self.org_vocab_size = org_num_embeddings or num_embeddings
+        self.num_embeddings_padded = pad_vocab_size(num_embeddings,
+                                                    padding_size)
         self.embedding_dim = embedding_dim
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()

@@ -77,7 +86,7 @@ class VocabParallelEmbedding(torch.nn.Module):
     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
         parallel_dim = param.parallel_dim
-        assert loaded_weight.shape[parallel_dim] == self.num_embeddings
+        assert loaded_weight.shape[parallel_dim] == self.org_vocab_size
         loaded_weight = loaded_weight[self.vocab_start_index:self.
                                       vocab_end_index]
         param[:loaded_weight.shape[0]].data.copy_(loaded_weight)

@@ -114,14 +123,19 @@ class ParallelLMHead(VocabParallelEmbedding):
         embedding_dim: size of hidden state.
         bias: whether to use bias.
         params_dtype: type of the parameters.
+        org_num_embeddings: original vocabulary size (without LoRA).
+        padding_size: padding size for the vocabulary.
     """

     def __init__(self,
                  num_embeddings: int,
                  embedding_dim: int,
                  bias: bool = False,
-                 params_dtype: Optional[torch.dtype] = None):
-        super().__init__(num_embeddings, embedding_dim, params_dtype)
+                 params_dtype: Optional[torch.dtype] = None,
+                 org_num_embeddings: Optional[int] = None,
+                 padding_size: int = DEFAULT_VOCAB_PADDING_SIZE):
+        super().__init__(num_embeddings, embedding_dim, params_dtype,
+                         org_num_embeddings, padding_size)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition,
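As a quick check of the padding arithmetic introduced above (shown as a standalone copy of the one-line helper, with the default padding size of 64 assumed), pad_vocab_size simply rounds the vocabulary up to the next multiple of pad_to.

    def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int:
        # Round vocab_size up to the nearest multiple of pad_to.
        return ((vocab_size + pad_to - 1) // pad_to) * pad_to

    assert pad_vocab_size(32000) == 32000       # already a multiple of 64
    assert pad_vocab_size(32001) == 32064       # rounded up to 64 * 501
    assert pad_vocab_size(151936, pad_to=128) == 151936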
vllm/model_executor/model_loader.py
"""Utilities for selecting and loading models."""
import
contextlib
from
typing
import
Type
from
typing
import
Optional
,
Type
import
torch
import
torch.nn
as
nn
from
transformers
import
PretrainedConfig
from
vllm.config
import
ModelConfig
from
vllm.config
import
ModelConfig
,
LoRAConfig
from
vllm.model_executor.models
import
ModelRegistry
from
vllm.model_executor.weight_utils
import
(
get_quant_config
,
initialize_dummy_weights
)
...
...
@@ -21,8 +20,14 @@ def _set_default_torch_dtype(dtype: torch.dtype):
torch
.
set_default_dtype
(
old_dtype
)
def
_get_model_architecture
(
config
:
PretrainedConfig
)
->
Type
[
nn
.
Module
]:
architectures
=
getattr
(
config
,
"architectures"
,
[])
def
_get_model_architecture
(
model_config
:
ModelConfig
)
->
Type
[
nn
.
Module
]:
architectures
=
getattr
(
model_config
.
hf_config
,
"architectures"
,
[])
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
if
(
model_config
.
quantization
is
not
None
and
"MixtralForCausalLM"
in
architectures
):
architectures
=
[
"QuantMixtralForCausalLM"
]
for
arch
in
architectures
:
model_cls
=
ModelRegistry
.
load_model_cls
(
arch
)
if
model_cls
is
not
None
:
...
...
@@ -32,8 +37,9 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
f
"Supported architectures:
{
ModelRegistry
.
get_supported_archs
()
}
"
)
def
get_model
(
model_config
:
ModelConfig
)
->
nn
.
Module
:
model_class
=
_get_model_architecture
(
model_config
.
hf_config
)
def
get_model
(
model_config
:
ModelConfig
,
lora_config
:
Optional
[
LoRAConfig
]
=
None
)
->
nn
.
Module
:
model_class
=
_get_model_architecture
(
model_config
)
# Get the (maybe quantized) linear method.
linear_method
=
None
...
...
@@ -62,7 +68,17 @@ def get_model(model_config: ModelConfig) -> nn.Module:
# Create a model instance.
# The weights will be initialized as empty tensors.
with
torch
.
device
(
"cuda"
):
model
=
model_class
(
model_config
.
hf_config
,
linear_method
)
if
getattr
(
model_class
,
"supports_lora"
,
False
):
model
=
model_class
(
model_config
.
hf_config
,
linear_method
,
lora_config
)
elif
lora_config
:
raise
ValueError
(
f
"Model
{
model_class
.
__name__
}
does not support LoRA, "
"but LoRA is enabled. Support for this model may "
"be added in the future. If this is important to you, "
"please open an issue on github."
)
else
:
model
=
model_class
(
model_config
.
hf_config
,
linear_method
)
if
model_config
.
load_format
==
"dummy"
:
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
...
...
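The hunk above dispatches on a supports_lora class attribute when constructing the model. A minimal, self-contained sketch of that pattern follows; the class and function names here are illustrative stand-ins, not the vLLM API.

    class PlainModel:
        def __init__(self, hf_config, linear_method):
            self.args = (hf_config, linear_method)

    class LoRAModel:
        supports_lora = True

        def __init__(self, hf_config, linear_method, lora_config):
            self.args = (hf_config, linear_method, lora_config)

    def build(model_class, hf_config, linear_method, lora_config=None):
        # Classes opt in to the three-argument constructor via supports_lora;
        # everything else falls back to the plain two-argument constructor.
        if getattr(model_class, "supports_lora", False):
            return model_class(hf_config, linear_method, lora_config)
        if lora_config:
            raise ValueError(f"{model_class.__name__} does not support LoRA.")
        return model_class(hf_config, linear_method)

    assert build(LoRAModel, "cfg", None, "lora").args == ("cfg", None, "lora")
    assert build(PlainModel, "cfg", None).args == ("cfg", None)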
vllm/model_executor/models/__init__.py
@@ -18,6 +18,7 @@ _MODELS = {
     "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
     "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
     "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
+    "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
     "FalconForCausalLM": ("falcon", "FalconForCausalLM"),
     "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
     "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),

@@ -29,14 +30,17 @@ _MODELS = {
     "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
     "MistralForCausalLM": ("mistral", "MistralForCausalLM"),
     "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
+    "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
     # transformers's mpt class has lower case
     "MptForCausalLM": ("mpt", "MPTForCausalLM"),
     "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
     "OPTForCausalLM": ("opt", "OPTForCausalLM"),
-    "PhiForCausalLM": ("phi_1_5", "PhiForCausalLM"),
+    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
     "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
+    "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "RWForCausalLM": ("falcon", "FalconForCausalLM"),
-    "YiForCausalLM": ("yi", "YiForCausalLM"),
+    "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
+    "YiForCausalLM": ("yi", "YiForCausalLM")
 }

 # Models not supported by ROCm.

@@ -45,6 +49,8 @@ _ROCM_UNSUPPORTED_MODELS = []
 # Models partially supported by ROCm.
 # Architecture -> Reason.
 _ROCM_PARTIALLY_SUPPORTED_MODELS = {
+    "Qwen2ForCausalLM":
+    "Sliding window attention is not yet supported in ROCm's flash attention",
     "MistralForCausalLM":
     "Sliding window attention is not yet supported in ROCm's flash attention",
     "MixtralForCausalLM":
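The _MODELS table above maps a Hugging Face architecture name to a (module, class) pair that is imported lazily. The sketch below shows how such a registry can be resolved with importlib; the package path argument is an assumption used only to make the example self-contained, not a confirmed vLLM internal.

    import importlib

    _SKETCH_MODELS = {
        "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
        "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
    }

    def load_model_cls_sketch(arch: str,
                              package: str = "vllm.model_executor.models"):
        # Return the model class for `arch`, or None if it is not registered.
        if arch not in _SKETCH_MODELS:
            return None
        module_name, cls_name = _SKETCH_MODELS[arch]
        module = importlib.import_module(f"{package}.{module_name}")
        return getattr(module, cls_name, None)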
vllm/model_executor/models/deepseek.py (new file, mode 100644)
# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Deepseek model."""
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import nn
import torch.nn.functional as F

from transformers import PretrainedConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               ReplicatedLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class DeepseekMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
        reduce_results: bool = True,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            linear_method=linear_method)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           linear_method=linear_method,
                                           reduce_results=reduce_results)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class DeepseekMoE(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.n_routed_experts = config.n_routed_experts
        self.top_k = config.num_experts_per_tok
        if self.tp_size > self.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {self.n_routed_experts}.")

        self.experts = nn.ModuleList([
            DeepseekMLP(hidden_size=config.hidden_size,
                        intermediate_size=config.moe_intermediate_size,
                        hidden_act=config.hidden_act,
                        linear_method=linear_method,
                        reduce_results=False)
            for idx in range(self.n_routed_experts)
        ])
        self.pack_params()

        self.gate = ReplicatedLinear(config.hidden_size,
                                     self.n_routed_experts,
                                     bias=False,
                                     linear_method=None)

        if config.n_shared_experts is not None:
            intermediate_size = (config.moe_intermediate_size *
                                 config.n_shared_experts)
            self.shared_experts = DeepseekMLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                linear_method=linear_method,
                reduce_results=False,
            )

    def pack_params(self):
        w1 = []
        w2 = []
        for expert in self.experts:
            w1.append(expert.gate_up_proj.weight)
            w2.append(expert.down_proj.weight)
        self.w1 = torch._utils._flatten_dense_tensors(w1)
        w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
        for data, param in zip(w1s, w1):
            param.data = data
        self.w1 = self.w1.view(len(w1), *w1s[0].shape)

        self.w2 = torch._utils._flatten_dense_tensors(w2)
        w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
        for data, param in zip(w2s, w2):
            param.data = data
        self.w2 = self.w2.view(len(w2), *w2s[0].shape)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        if self.config.n_shared_experts is not None:
            shared_output = self.shared_experts(hidden_states)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits, _ = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights,
                                                       self.top_k,
                                                       dim=-1)

        if self.config.norm_topk_prob:
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

        final_hidden_states = fused_moe(hidden_states,
                                        self.w1,
                                        self.w2,
                                        routing_weights,
                                        selected_experts,
                                        inplace=True)

        if self.config.n_shared_experts is not None:
            final_hidden_states = final_hidden_states + shared_output
        final_hidden_states = tensor_model_parallel_all_reduce(
            final_hidden_states)

        return final_hidden_states.view(batch_size, sequence_length,
                                        hidden_dim)


class DeepseekAttention(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   self.scaling,
                                   num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class DeepseekDecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        layer_idx: int,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = DeepseekAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            linear_method=linear_method,
        )
        if (config.n_routed_experts is not None
                and layer_idx >= config.first_k_dense_replace
                and layer_idx % config.moe_layer_freq == 0):
            self.mlp = DeepseekMoE(config=config, linear_method=linear_method)
        else:
            self.mlp = DeepseekMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                linear_method=linear_method,
            )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class DeepseekModel(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            DeepseekDecoderLayer(config,
                                 layer_idx,
                                 linear_method=linear_method)
            for layer_idx in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i], input_metadata,
                                            residual)
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class DeepseekForCausalLM(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = DeepseekModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path,
                cache_dir,
                load_format,
                revision,
                fall_back_to_pt=False):
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip experts that are not assigned to this worker.
                if (("mlp.experts." in name or "mlp.shared_experts." in name)
                        and name not in params_dict):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
# Skip experts that are not assigned to this worker.
if
((
"mlp.experts."
in
name
or
"mlp.shared_experts."
in
name
)
and
name
not
in
params_dict
):
continue
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
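
For readers skimming the load_weights() logic above: the stacked-parameter mapping only rewrites checkpoint tensor names onto the fused vLLM parameters before a shard-aware loader is called. The standalone sketch below is not part of the commit; the checkpoint names are made-up examples used purely for illustration.

# Illustrative only: mimics the name rewriting performed in load_weights() above.
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def remap(checkpoint_name: str):
    """Return (vllm_param_name, shard_id) for a checkpoint tensor name."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in checkpoint_name:
            return checkpoint_name.replace(weight_name, param_name), shard_id
    return checkpoint_name, None  # falls through to default_weight_loader

print(remap("model.layers.0.self_attn.q_proj.weight"))
# -> ('model.layers.0.self_attn.qkv_proj.weight', 'q')
print(remap("model.layers.1.mlp.up_proj.weight"))
# -> ('model.layers.1.mlp.gate_up_proj.weight', 1)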
vllm/model_executor/models/llama.py  View file @ 51679bbd

@@ -38,13 +38,14 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.sampler import Sampler
  from vllm.model_executor.layers.vocab_parallel_embedding import (
-     VocabParallelEmbedding, ParallelLMHead)
+     VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
  from vllm.model_executor.parallel_utils.parallel_state import (
      get_tensor_model_parallel_world_size)
  from vllm.model_executor.sampling_metadata import SamplingMetadata
  from vllm.model_executor.weight_utils import (default_weight_loader,
                                                hf_model_weights_iterator)
  from vllm.sequence import SamplerOutput
+ from vllm.config import LoRAConfig

  KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -225,14 +226,19 @@ class LlamaModel(nn.Module):
          self,
          config: LlamaConfig,
          linear_method: Optional[LinearMethodBase] = None,
+         lora_config: Optional[LoRAConfig] = None,
      ) -> None:
          super().__init__()
          self.config = config
          self.padding_idx = config.pad_token_id
-         self.vocab_size = config.vocab_size
+         lora_vocab = (lora_config.lora_extra_vocab_size *
+                       (lora_config.max_loras or 1)) if lora_config else 0
+         self.vocab_size = config.vocab_size + lora_vocab
+         self.org_vocab_size = config.vocab_size
          self.embed_tokens = VocabParallelEmbedding(
-             config.vocab_size,
+             self.vocab_size,
              config.hidden_size,
+             org_num_embeddings=config.vocab_size,
          )
          self.layers = nn.ModuleList([
              LlamaDecoderLayer(config, linear_method)

@@ -263,18 +269,31 @@ class LlamaModel(nn.Module):
  class LlamaForCausalLM(nn.Module):
+     supports_lora = True

      def __init__(
          self,
          config: LlamaConfig,
          linear_method: Optional[LinearMethodBase] = None,
+         lora_config: Optional[LoRAConfig] = None,
      ) -> None:
          super().__init__()
          self.config = config
          self.linear_method = linear_method
-         self.model = LlamaModel(config, linear_method)
-         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-         self.sampler = Sampler(config.vocab_size)
+         self.model = LlamaModel(config, linear_method, lora_config=lora_config)
+         unpadded_vocab_size = config.vocab_size
+         if lora_config:
+             unpadded_vocab_size += lora_config.lora_extra_vocab_size
+         self.lm_head = ParallelLMHead(
+             unpadded_vocab_size,
+             config.hidden_size,
+             org_num_embeddings=config.vocab_size,
+             padding_size=DEFAULT_VOCAB_PADDING_SIZE
+             # We need bigger padding if using lora for kernel
+             # compatibility
+             if not lora_config else lora_config.lora_vocab_padding_size,
+         )
+         self.sampler = Sampler(unpadded_vocab_size, config.vocab_size)

      def forward(
          self,
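
The lora_vocab arithmetic in the hunk above is easy to verify by hand. The numbers in the sketch below are illustrative only (they are not taken from the commit); it simply reproduces how the embedding table and the unpadded LM head grow when LoRA is enabled.

# Illustrative only: vocabulary growth when a LoRAConfig is present.
config_vocab_size = 32000        # example tokenizer size
lora_extra_vocab_size = 256      # example value
max_loras = 4                    # example value

lora_vocab = lora_extra_vocab_size * (max_loras or 1)
embed_rows = config_vocab_size + lora_vocab                    # embed_tokens rows
unpadded_vocab_size = config_vocab_size + lora_extra_vocab_size  # lm_head rows
print(embed_rows, unpadded_vocab_size)  # 33024 32256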
vllm/model_executor/models/mistral.py  View file @ 51679bbd

@@ -38,13 +38,14 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.sampler import Sampler
  from vllm.model_executor.layers.vocab_parallel_embedding import (
-     VocabParallelEmbedding, ParallelLMHead)
+     VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
  from vllm.model_executor.parallel_utils.parallel_state import (
      get_tensor_model_parallel_world_size)
  from vllm.model_executor.sampling_metadata import SamplingMetadata
  from vllm.model_executor.weight_utils import (default_weight_loader,
                                                hf_model_weights_iterator)
  from vllm.sequence import SamplerOutput
+ from vllm.config import LoRAConfig

  KVCache = Tuple[torch.Tensor, torch.Tensor]

@@ -220,15 +221,20 @@ class MistralModel(nn.Module):
          self,
          config: MistralConfig,
          linear_method: Optional[LinearMethodBase] = None,
+         lora_config: Optional[LoRAConfig] = None,
      ) -> None:
          super().__init__()
          self.config = config
          self.padding_idx = config.pad_token_id
-         self.vocab_size = config.vocab_size
+         lora_vocab = (lora_config.lora_extra_vocab_size *
+                       (lora_config.max_loras or 1)) if lora_config else 0
+         self.vocab_size = config.vocab_size + lora_vocab
+         self.org_vocab_size = config.vocab_size
          self.embed_tokens = VocabParallelEmbedding(
-             config.vocab_size,
+             self.vocab_size,
              config.hidden_size,
+             org_num_embeddings=config.vocab_size,
          )
          self.layers = nn.ModuleList([
              MistralDecoderLayer(config, linear_method)

@@ -259,18 +265,33 @@ class MistralModel(nn.Module):
  class MistralForCausalLM(nn.Module):
+     supports_lora = True

      def __init__(
          self,
          config: MistralConfig,
          linear_method: Optional[LinearMethodBase] = None,
+         lora_config: Optional[LoRAConfig] = None,
      ) -> None:
          super().__init__()
          self.config = config
          self.linear_method = linear_method
-         self.model = MistralModel(config, linear_method)
-         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-         self.sampler = Sampler(config.vocab_size)
+         self.model = MistralModel(config, linear_method, lora_config=lora_config)
+         unpadded_vocab_size = config.vocab_size
+         if lora_config:
+             unpadded_vocab_size += lora_config.lora_extra_vocab_size
+         self.lm_head = ParallelLMHead(
+             unpadded_vocab_size,
+             config.hidden_size,
+             org_num_embeddings=config.vocab_size,
+             padding_size=DEFAULT_VOCAB_PADDING_SIZE
+             # We need bigger padding if using lora for kernel
+             # compatibility
+             if not lora_config else lora_config.lora_vocab_padding_size,
+         )
+         self.sampler = Sampler(unpadded_vocab_size, config.vocab_size)

      def forward(
          self,
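
The padding_size argument above controls how the LM head's vocabulary dimension is rounded up to a kernel-friendly multiple; the actual helper and padding constants live in vocab_parallel_embedding.py, so the sketch below is only the underlying round-up arithmetic with example values, not the library code itself.

# Illustrative only: pad-to-multiple arithmetic for the vocab dimension.
def pad_to_multiple(vocab_size: int, pad_to: int) -> int:
    """Round vocab_size up to the nearest multiple of pad_to."""
    return ((vocab_size + pad_to - 1) // pad_to) * pad_to

print(pad_to_multiple(32000, 64))    # 32000 (already aligned)
print(pad_to_multiple(32003, 64))    # 32064
print(pad_to_multiple(32256, 256))   # 32256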
vllm/model_executor/models/mixtral.py  View file @ 51679bbd

@@ -23,8 +23,6 @@
  """Inference-only Mixtral model."""
  from typing import List, Optional, Tuple

- import numpy as np
  import torch
  import torch.nn.functional as F

@@ -33,10 +31,11 @@ from transformers import MixtralConfig
  from vllm.model_executor.input_metadata import InputMetadata
  from vllm.model_executor.layers.attention import PagedAttention
+ from vllm.model_executor.layers.fused_moe import fused_moe
  from vllm.model_executor.layers.layernorm import RMSNorm
  from vllm.model_executor.layers.linear import (LinearMethodBase,
-                                                ReplicatedLinear,
                                                 QKVParallelLinear,
+                                                ReplicatedLinear,
                                                 RowParallelLinear)
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.sampler import Sampler

@@ -47,6 +46,7 @@ from vllm.model_executor.parallel_utils.communication_op import (
  from vllm.model_executor.parallel_utils.parallel_state import (
      get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
  from vllm.model_executor.sampling_metadata import SamplingMetadata
+ from vllm.model_executor.utils import set_weight_attrs
  from vllm.model_executor.weight_utils import (default_weight_loader,
                                                hf_model_weights_iterator)
  from vllm.sequence import SamplerOutput

@@ -54,85 +54,77 @@ from vllm.sequence import SamplerOutput
  KVCache = Tuple[torch.Tensor, torch.Tensor]


- class MixtralMLP(nn.Module):
+ class MixtralMoE(nn.Module):
+     """A tensor-parallel MoE implementation for Mixtral that shards each expert
+     across all ranks.
+
+     Each expert's weights are sharded across all ranks and a fused MoE
+     kernel is used for the forward pass, and finally we reduce the outputs
+     across ranks.
+     """

      def __init__(
          self,
          num_experts: int,
+         top_k: int,
          hidden_size: int,
          intermediate_size: int,
-         linear_method: Optional[LinearMethodBase] = None,
-     ) -> None:
+         params_dtype: Optional[torch.dtype] = None,
+     ):
          super().__init__()
-         self.num_experts = num_experts
-         self.ffn_dim = intermediate_size
-         self.hidden_dim = hidden_size
-
-         self.w1 = ReplicatedLinear(self.hidden_dim, self.ffn_dim,
-                                    bias=False, linear_method=linear_method)
-         self.w2 = ReplicatedLinear(self.ffn_dim, self.hidden_dim,
-                                    bias=False, linear_method=linear_method)
-         self.w3 = ReplicatedLinear(self.hidden_dim, self.ffn_dim,
-                                    bias=False, linear_method=linear_method)
-
-         # TODO: Use vllm's SiluAndMul
-         self.act_fn = nn.SiLU()
-
-     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-         w1_out, _ = self.w1(hidden_states)
-         w1_out = self.act_fn(w1_out)
-         w3_out, _ = self.w3(hidden_states)
-         current_hidden_states = w1_out * w3_out
-         current_hidden_states, _ = self.w2(current_hidden_states)
-         return current_hidden_states
-
-
- class MixtralMoE(nn.Module):
-
-     def __init__(
-         self,
-         config: MixtralConfig,
-         linear_method: Optional[LinearMethodBase] = None,
-     ):
-         super().__init__()
-         self.config = config
-         self.rank = get_tensor_model_parallel_rank()
-         self.tp_size = get_tensor_model_parallel_world_size()
-         self.num_total_experts = config.num_local_experts
-         self.top_k = config.num_experts_per_tok
-         if self.tp_size > self.num_total_experts:
-             raise ValueError(
-                 f"Tensor parallel size {self.tp_size} is greater than "
-                 f"the number of experts {self.num_total_experts}.")
-         # Split experts equally between ranks
-         self.expert_indicies = np.array_split(range(
-             self.num_total_experts), self.tp_size)[self.rank].tolist()
-         if not self.expert_indicies:
-             raise ValueError(
-                 f"Rank {self.rank} has no experts assigned to it.")
-
-         self.experts = nn.ModuleList([
-             MixtralMLP(self.num_total_experts, config.hidden_size,
-                        config.intermediate_size, linear_method=linear_method)
-             if idx in self.expert_indicies else None
-             for idx in range(self.num_total_experts)
-         ])
-         self.gate = ReplicatedLinear(config.hidden_size,
+         tp_size = get_tensor_model_parallel_world_size()
+         self.num_total_experts = num_experts
+         self.top_k = top_k
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size // tp_size
+
+         if params_dtype is None:
+             params_dtype = torch.get_default_dtype()
+         self.params_dtype = params_dtype
+
+         self.gate = ReplicatedLinear(self.hidden_size,
                                       self.num_total_experts,
                                       bias=False,
+                                      params_dtype=self.params_dtype,
                                       linear_method=None)
+
+         self.ws = nn.Parameter(
+             torch.empty(self.num_total_experts, 2 * self.intermediate_size,
+                         self.hidden_size, device="cuda",
+                         dtype=self.params_dtype))
+         self.w2s = nn.Parameter(
+             torch.empty(self.num_total_experts, self.hidden_size,
+                         self.intermediate_size, device="cuda",
+                         dtype=self.params_dtype))
+
+         set_weight_attrs(self.ws, {"weight_loader": self.weight_loader,})
+         set_weight_attrs(self.w2s, {"weight_loader": self.weight_loader,})
+
+     def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor,
+                       weight_name: str, expert_id: int):
+         tp_rank = get_tensor_model_parallel_rank()
+         param_data = param.data
+         shard_size = self.intermediate_size
+         shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
+         if weight_name.endswith("w1.weight"):
+             param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :]
+         if weight_name.endswith("w3.weight"):
+             param_data[expert_id,
+                        shard_size:2 * shard_size, :] = loaded_weight[shard, :]
+         if weight_name.endswith("w2.weight"):
+             param_data[expert_id, :, :] = loaded_weight[:, shard]

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-         batch_size, sequence_length, hidden_dim = hidden_states.shape
-         hidden_states = hidden_states.view(-1, hidden_dim)
+         batch_size, sequence_length, hidden_size = hidden_states.shape
+         hidden_states = hidden_states.view(-1, self.hidden_size)
          # router_logits: (batch * sequence_length, n_experts)
          router_logits, _ = self.gate(hidden_states)

@@ -142,22 +134,18 @@ class MixtralMoE(nn.Module):
                                                         dim=-1)
          routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

-         final_hidden_states = None
-         for expert_idx in self.expert_indicies:
-             expert_layer = self.experts[expert_idx]
-             expert_mask = (selected_experts == expert_idx)
-             expert_weights = (routing_weights * expert_mask).sum(dim=-1,
-                                                                  keepdim=True)
-
-             current_hidden_states = expert_layer(hidden_states).mul_(
-                 expert_weights)
-             if final_hidden_states is None:
-                 final_hidden_states = current_hidden_states
-             else:
-                 final_hidden_states.add_(current_hidden_states)
-
-         return tensor_model_parallel_all_reduce(final_hidden_states).view(
-             batch_size, sequence_length, hidden_dim)
+         final_hidden_states = fused_moe(hidden_states,
+                                         self.ws,
+                                         self.w2s,
+                                         routing_weights,
+                                         selected_experts,
+                                         inplace=True)
+
+         final_hidden_states = tensor_model_parallel_all_reduce(
+             final_hidden_states)
+
+         return final_hidden_states.view(batch_size, sequence_length,
+                                         hidden_size)


  class MixtralAttention(nn.Module):

@@ -257,8 +245,11 @@ class MixtralDecoderLayer(nn.Module):
              rope_theta=rope_theta,
              sliding_window=config.sliding_window,
              linear_method=linear_method)
-         self.block_sparse_moe = MixtralMoE(config=config,
-                                            linear_method=linear_method)
+         self.block_sparse_moe = MixtralMoE(
+             num_experts=config.num_local_experts,
+             top_k=config.num_experts_per_tok,
+             hidden_size=config.hidden_size,
+             intermediate_size=config.intermediate_size)
          self.input_layernorm = RMSNorm(config.hidden_size,
                                         eps=config.rms_norm_eps)
          self.post_attention_layernorm = RMSNorm(config.hidden_size,

@@ -378,6 +369,14 @@ class MixtralForCausalLM(nn.Module):
              ("qkv_proj", "v_proj", "v"),
          ]

+         expert_params_mapping = [
+             # (param_name, weight_name, expert_id)
+             ("ws" if weight_name in ["w1", "w3"] else "w2s",
+              f"experts.{expert_id}.{weight_name}.weight", expert_id)
+             for expert_id in range(self.config.num_local_experts)
+             for weight_name in ["w1", "w2", "w3"]
+         ]
+
          params_dict = dict(self.named_parameters())
          for name, loaded_weight in hf_model_weights_iterator(
                  model_name_or_path,

@@ -387,6 +386,7 @@ class MixtralForCausalLM(nn.Module):
                  fall_back_to_pt=False):
              if "rotary_emb.inv_freq" in name:
                  continue
+
              for (param_name, weight_name, shard_id) in stacked_params_mapping:
                  if weight_name not in name:
                      continue

@@ -399,14 +399,22 @@ class MixtralForCausalLM(nn.Module):
                  weight_loader(param, loaded_weight, shard_id)
                  break
              else:
-                 # Skip loading extra bias for GPTQ models.
-                 if name.endswith(".bias") and name not in params_dict:
-                     continue
-                 # Skip experts that are not assigned to this worker.
-                 if ("block_sparse_moe.experts." in name
-                         and name not in params_dict):
-                     continue
-                 param = params_dict[name]
-                 weight_loader = getattr(param, "weight_loader",
-                                         default_weight_loader)
-                 weight_loader(param, loaded_weight)
+                 for param_name, weight_name, expert_id in expert_params_mapping:
+                     if weight_name not in name:
+                         continue
+                     name = name.replace(weight_name, param_name)
+                     param = params_dict[name]
+                     weight_loader = param.weight_loader
+                     weight_loader(param,
+                                   loaded_weight,
+                                   weight_name,
+                                   expert_id=expert_id)
+                     break
+                 else:
+                     # Skip loading extra bias for GPTQ models.
+                     if name.endswith(".bias") and name not in params_dict:
+                         continue
+                     param = params_dict[name]
+                     weight_loader = getattr(param, "weight_loader",
+                                             default_weight_loader)
+                     weight_loader(param, loaded_weight)
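
The new weight_loader above packs each expert's w1 and w3 into a single stacked ws tensor and w2 into w2s, sharded along the FFN dimension per tensor-parallel rank. The standalone sketch below is not the commit's code (toy sizes, no vLLM imports); it only mirrors the slicing logic for an assumed rank 0 of 2 so the memory layout the fused kernel consumes is easier to picture.

# Illustrative only: expert weight layout used by the fused MoE path.
import torch

num_experts, hidden, inter, tp_size, tp_rank = 4, 8, 16, 2, 0
shard = inter // tp_size                         # per-rank slice of the FFN dim

ws = torch.empty(num_experts, 2 * shard, hidden)   # [w1; w3] stacked per expert
w2s = torch.empty(num_experts, hidden, shard)

# Toy checkpoint tensors for one expert (full, unsharded shapes).
w1 = torch.randn(inter, hidden)
w3 = torch.randn(inter, hidden)
w2 = torch.randn(hidden, inter)

expert_id = 2
rows = slice(tp_rank * shard, (tp_rank + 1) * shard)
ws[expert_id, 0:shard, :] = w1[rows, :]            # first half of ws  <- w1 shard
ws[expert_id, shard:2 * shard, :] = w3[rows, :]    # second half of ws <- w3 shard
w2s[expert_id, :, :] = w2[:, rows]                 # w2 sharded along its input dim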
vllm/model_executor/models/mixtral_quant.py  0 → 100644  View file @ 51679bbd

# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Mixtral model."""
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import MixtralConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               ReplicatedLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_reduce)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class MixtralMLP(nn.Module):

    def __init__(
        self,
        num_experts: int,
        hidden_size: int,
        intermediate_size: int,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.num_experts = num_experts
        self.ffn_dim = intermediate_size
        self.hidden_dim = hidden_size

        self.w1 = ReplicatedLinear(self.hidden_dim, self.ffn_dim,
                                   bias=False, linear_method=linear_method)
        self.w2 = ReplicatedLinear(self.ffn_dim, self.hidden_dim,
                                   bias=False, linear_method=linear_method)
        self.w3 = ReplicatedLinear(self.hidden_dim, self.ffn_dim,
                                   bias=False, linear_method=linear_method)

        # TODO: Use vllm's SiluAndMul
        self.act_fn = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        w1_out, _ = self.w1(hidden_states)
        w1_out = self.act_fn(w1_out)
        w3_out, _ = self.w3(hidden_states)
        current_hidden_states = w1_out * w3_out
        current_hidden_states, _ = self.w2(current_hidden_states)
        return current_hidden_states


class MixtralMoE(nn.Module):

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ):
        super().__init__()
        self.config = config
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_total_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok
        if self.tp_size > self.num_total_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {self.num_total_experts}.")
        # Split experts equally between ranks
        self.expert_indicies = np.array_split(range(
            self.num_total_experts), self.tp_size)[self.rank].tolist()
        if not self.expert_indicies:
            raise ValueError(
                f"Rank {self.rank} has no experts assigned to it.")

        self.experts = nn.ModuleList([
            MixtralMLP(self.num_total_experts, config.hidden_size,
                       config.intermediate_size, linear_method=linear_method)
            if idx in self.expert_indicies else None
            for idx in range(self.num_total_experts)
        ])
        self.gate = ReplicatedLinear(config.hidden_size,
                                     self.num_total_experts,
                                     bias=False,
                                     linear_method=None)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits, _ = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights,
                                                       self.top_k,
                                                       dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)

        final_hidden_states = None
        for expert_idx in self.expert_indicies:
            expert_layer = self.experts[expert_idx]
            expert_mask = (selected_experts == expert_idx)
            expert_weights = (routing_weights * expert_mask).sum(dim=-1,
                                                                 keepdim=True)

            current_hidden_states = expert_layer(hidden_states).mul_(
                expert_weights)
            if final_hidden_states is None:
                final_hidden_states = current_hidden_states
            else:
                final_hidden_states.add_(current_hidden_states)

        return tensor_model_parallel_all_reduce(final_hidden_states).view(
            batch_size, sequence_length, hidden_dim)


class MixtralAttention(nn.Module):

    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 max_position: int = 4096 * 32,
                 rope_theta: float = 10000,
                 linear_method: Optional[LinearMethodBase] = None,
                 sliding_window: Optional[int] = None) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=int(self.rope_theta),
            is_neox_style=True,
        )
        self.attn = PagedAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            sliding_window=self.sliding_window,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class MixtralDecoderLayer(nn.Module):

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 10000)
        self.self_attn = MixtralAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            sliding_window=config.sliding_window,
            linear_method=linear_method)
        self.block_sparse_moe = MixtralMoE(config=config,
                                           linear_method=linear_method)
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.block_sparse_moe(hidden_states)
        return hidden_states, residual


class MixtralModel(nn.Module):

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            MixtralDecoderLayer(config, linear_method=linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(positions, hidden_states,
                                            kv_caches[i], input_metadata,
                                            residual)
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class MixtralForCausalLM(nn.Module):

    def __init__(
        self,
        config: MixtralConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = MixtralModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: Optional[torch.Tensor],
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path,
                cache_dir,
                load_format,
                revision,
                fall_back_to_pt=False):
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip experts that are not assigned to this worker.
                if ("block_sparse_moe.experts." in name
                        and name not in params_dict):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
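
The expert_indicies computation kept in this quantized fallback is just np.array_split over the expert ids. The short sketch below uses example values (eight experts, four ranks) only to show how the experts land on each rank.

# Illustrative only: how experts are split equally between ranks.
import numpy as np

num_total_experts, tp_size = 8, 4
for rank in range(tp_size):
    expert_indicies = np.array_split(range(num_total_experts),
                                     tp_size)[rank].tolist()
    print(rank, expert_indicies)
# 0 [0, 1]
# 1 [2, 3]
# 2 [4, 5]
# 3 [6, 7]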
vllm/model_executor/models/phi_1_5.py → vllm/model_executor/models/phi.py  View file @ 51679bbd

@@ -62,20 +62,6 @@ from vllm.sequence import SamplerOutput
  KVCache = Tuple[torch.Tensor, torch.Tensor]


- class PhiEmbedding(nn.Module):
-
-     def __init__(self, config: PretrainedConfig):
-         super().__init__()
-         self.wte = VocabParallelEmbedding(
-             config.vocab_size,
-             config.hidden_size,
-         )
-
-     def forward(self, input_ids: torch.LongTensor):
-         return self.wte(input_ids)
-
-
  class PhiAttention(nn.Module):

      def __init__(self,

@@ -93,27 +79,22 @@ class PhiAttention(nn.Module):
              tensor_model_parallel_world_size)

-         # pylint: disable=C0103
-         self.Wqkv = QKVParallelLinear(
-             self.hidden_size,
-             self.head_size,
-             self.total_num_heads,
-             bias=False,
-             linear_method=linear_method,
-         )
+         self.qkv_proj = QKVParallelLinear(
+             config.hidden_size,
+             self.head_size,
+             self.total_num_heads,
+             bias=True,
+             linear_method=linear_method,
+         )
-         self.out_proj = RowParallelLinear(
+         self.dense = RowParallelLinear(
              self.hidden_size,
              self.hidden_size,
              linear_method=linear_method,
          )

          scaling = self.head_size**-0.5
-         rotary_dim = config.rotary_dim
+         rotary_dim = int(config.partial_rotary_factor *
+                          (config.hidden_size // config.num_attention_heads))
          assert rotary_dim % 2 == 0

          # pylint: disable=C0301

@@ -136,12 +117,12 @@ class PhiAttention(nn.Module):
          kv_cache: KVCache,
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
-         qkv, _ = self.Wqkv(hidden_states)
+         qkv, _ = self.qkv_proj(hidden_states)
          q, k, v = qkv.chunk(chunks=3, dim=-1)
          q, k = self.rotary_emb(position_ids, q, k)
          k_cache, v_cache = kv_cache
          attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
-         output, _ = self.out_proj(attn_output)
+         output, _ = self.dense(attn_output)
          return output

@@ -166,8 +147,7 @@ class PhiMLP(nn.Module):
              linear_method=linear_method,
          )
          quant_config = getattr(linear_method, "quant_config", None)
-         self.act = get_act_fn(config.activation_function, quant_config,
-                               n_inner)
+         self.act = get_act_fn(config.hidden_act, quant_config, n_inner)

      def forward(self, hidden_states):
          hidden_states, _ = self.fc1(hidden_states)

@@ -182,9 +162,9 @@ class PhiLayer(nn.Module):
                   config: PretrainedConfig,
                   linear_method: Optional[LinearMethodBase] = None):
          super().__init__()
-         self.ln = nn.LayerNorm(config.hidden_size,
-                                eps=config.layer_norm_epsilon)
-         self.mixer = PhiAttention(config, linear_method)
+         self.input_layernorm = nn.LayerNorm(config.hidden_size,
+                                             eps=config.layer_norm_eps)
+         self.self_attn = PhiAttention(config, linear_method)
          self.mlp = PhiMLP(config, linear_method)

      def forward(

@@ -195,8 +175,8 @@ class PhiLayer(nn.Module):
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
          residual = hidden_states
-         hidden_states = self.ln(hidden_states)
-         attn_outputs = self.mixer(
+         hidden_states = self.input_layernorm(hidden_states)
+         attn_outputs = self.self_attn(
              position_ids=position_ids,
              hidden_states=hidden_states,
              kv_cache=kv_cache,

@@ -215,11 +195,14 @@ class PhiModel(nn.Module):
          super().__init__()
          self.config = config
          self.linear_method = linear_method
-         self.embd = PhiEmbedding(config)
-         self.h = nn.ModuleList([
+         self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
+                                                    config.hidden_size)
+         self.layers = nn.ModuleList([
              PhiLayer(config, linear_method)
              for _ in range(config.num_hidden_layers)
          ])
+         self.final_layernorm = nn.LayerNorm(config.hidden_size,
+                                             eps=config.layer_norm_eps)

      def forward(
          self,

@@ -228,27 +211,19 @@ class PhiModel(nn.Module):
          kv_caches: List[KVCache],
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
-         hidden_states = self.embd(input_ids)
+         hidden_states = self.embed_tokens(input_ids)
          for i in range(self.config.num_hidden_layers):
-             layer = self.h[i]
+             layer = self.layers[i]
              hidden_states = layer(
                  positions,
                  hidden_states,
                  kv_caches[i],
                  input_metadata,
              )
-         return hidden_states
-
-
- class PhiCausalLMHead(nn.Module):
-
-     def __init__(self, config: PretrainedConfig):
-         super().__init__()
-         self.ln = nn.LayerNorm(config.hidden_size,
-                                eps=config.layer_norm_epsilon)
-         self.linear = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      bias=True)
+         hidden_states = self.final_layernorm(hidden_states)
+         return hidden_states


  class PhiForCausalLM(nn.Module):

@@ -260,8 +235,11 @@ class PhiForCausalLM(nn.Module):
          self.config = config
          self.linear_method = linear_method

-         self.transformer = PhiModel(config, linear_method)
-         self.lm_head = PhiCausalLMHead(config)
+         self.model = PhiModel(config, linear_method)
+         self.lm_head = ParallelLMHead(config.vocab_size,
+                                       config.hidden_size,
+                                       bias=True)
          self.sampler = Sampler(config.vocab_size)

      def forward(

@@ -271,9 +249,9 @@ class PhiForCausalLM(nn.Module):
          kv_caches: List[KVCache],
          input_metadata: InputMetadata,
      ) -> torch.Tensor:
-         hidden_states = self.transformer(input_ids, positions, kv_caches,
-                                          input_metadata)
-         hidden_states = self.lm_head.ln(hidden_states)
+         hidden_states = self.model(input_ids, positions, kv_caches,
+                                    input_metadata)
          return hidden_states

      def sample(

@@ -281,7 +259,7 @@ class PhiForCausalLM(nn.Module):
          hidden_states: torch.Tensor,
          sampling_metadata: SamplingMetadata,
      ) -> Optional[SamplerOutput]:
-         head = self.lm_head.linear
+         head = self.lm_head
          next_tokens = self.sampler(head.weight, hidden_states,
                                     sampling_metadata, head.bias)
          return next_tokens

@@ -291,17 +269,37 @@ class PhiForCausalLM(nn.Module):
          cache_dir: Optional[str] = None,
          load_format: str = "auto",
          revision: Optional[str] = None):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v")
+         ]
          params_dict = dict(self.named_parameters())

          for name, loaded_weight in hf_model_weights_iterator(
                  model_name_or_path, cache_dir, load_format, revision):
              if "rotary_emb.inv_freq" in name:
                  continue
-             # Skip loading extra bias for GPTQ models.
-             if name.endswith(".bias") and name not in params_dict:
-                 continue
-             # pylint: disable=E1136
-             param = params_dict[name]
-             weight_loader = getattr(param, "weight_loader",
-                                     default_weight_loader)
-             weight_loader(param, loaded_weight)
+             for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 # pylint: disable=E1136
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader",
+                                         default_weight_loader)
+                 weight_loader(param, loaded_weight)
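
The new rotary_dim expression in PhiAttention replaces the old config.rotary_dim field with a fraction of the head size. A quick numeric check follows; the config values are illustrative only and are not taken from the commit.

# Illustrative only: partial rotary dimension from HF config fields.
hidden_size = 2048            # example value
num_attention_heads = 32      # example value
partial_rotary_factor = 0.5   # example value

head_size = hidden_size // num_attention_heads        # 64
rotary_dim = int(partial_rotary_factor * head_size)   # 32 dims get RoPE
assert rotary_dim % 2 == 0
print(head_size, rotary_dim)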
vllm/model_executor/models/qwen2.py  0 → 100644  View file @ 51679bbd

# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
# Copyright 2024 The Qwen team.
# Copyright 2023 The vLLM team.
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
from typing import List, Optional, Tuple

import torch
from torch import nn
from transformers import Qwen2Config

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import PagedAttention
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
                                               MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
    VocabParallelEmbedding, ParallelLMHead)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
                                              hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

KVCache = Tuple[torch.Tensor, torch.Tensor]


class Qwen2MLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            linear_method=linear_method)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           linear_method=linear_method)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x


class Qwen2Attention(nn.Module):

    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 num_kv_heads: int,
                 max_position: int = 4096 * 32,
                 rope_theta: float = 10000,
                 use_sliding_window: bool = False,
                 linear_method: Optional[LinearMethodBase] = None,
                 sliding_window: Optional[int] = None) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window if use_sliding_window else None

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=True,
            linear_method=linear_method,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            linear_method=linear_method,
        )
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position,
            base=self.rope_theta,
        )
        self.attn = PagedAttention(self.num_heads,
                                   self.head_dim,
                                   self.scaling,
                                   num_kv_heads=self.num_kv_heads,
                                   sliding_window=self.sliding_window)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        k_cache, v_cache = kv_cache
        attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata)
        output, _ = self.o_proj(attn_output)
        return output


class Qwen2DecoderLayer(nn.Module):

    def __init__(
        self,
        config: Qwen2Config,
        layer_idx: int,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 1000000)
        use_sliding_window = (config.use_sliding_window
                              and layer_idx < config.max_window_layers)
        self.self_attn = Qwen2Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            use_sliding_window=use_sliding_window,
            linear_method=linear_method,
            sliding_window=config.sliding_window)
        self.mlp = Qwen2MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            linear_method=linear_method,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class Qwen2Model(nn.Module):

    def __init__(
        self,
        config: Qwen2Config,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            Qwen2DecoderLayer(config, layer_idx, linear_method)
            for layer_idx in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                residual,
            )
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class Qwen2ForCausalLM(nn.Module):

    def __init__(
        self,
        config: Qwen2Config,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = Qwen2Model(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
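
Whether a given Qwen2 layer attends with a sliding window depends on use_sliding_window and max_window_layers, as wired in Qwen2DecoderLayer above. The sketch below is not part of the file; it only evaluates that per-layer predicate with example config values.

# Illustrative only: which layers get a sliding-window attention mask.
use_sliding_window = True      # example config value
max_window_layers = 2          # example config value
num_hidden_layers = 4          # example config value
sliding_window = 4096          # example config value

for layer_idx in range(num_hidden_layers):
    enabled = use_sliding_window and layer_idx < max_window_layers
    window = sliding_window if enabled else None
    print(layer_idx, window)
# 0 4096
# 1 4096
# 2 None
# 3 None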
vllm/model_executor/models/stablelm.py
0 → 100644
View file @
51679bbd
# coding=utf-8
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This code is based off the following work:
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py
# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights."""
from
typing
import
List
,
Optional
,
Tuple
import
torch
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
vllm.model_executor.layers.attention
import
PagedAttention
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
MergedColumnParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.weight_utils
import
(
default_weight_loader
,
hf_model_weights_iterator
)
from
vllm.sequence
import
SamplerOutput
KVCache
=
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]
class
StablelmMLP
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
hidden_size
=
config
.
hidden_size
self
.
intermediate_size
=
config
.
intermediate_size
self
.
gate_up_proj
=
MergedColumnParallelLinear
(
config
.
hidden_size
,
[
config
.
intermediate_size
]
*
2
,
bias
=
False
,
linear_method
=
linear_method
)
self
.
down_proj
=
RowParallelLinear
(
config
.
intermediate_size
,
config
.
hidden_size
,
bias
=
False
)
self
.
act_fn
=
SiluAndMul
()
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
gate_up
,
_
=
self
.
gate_up_proj
(
x
)
x
=
self
.
act_fn
(
gate_up
)
x
,
_
=
self
.
down_proj
(
x
)
return
x
class
StablelmAttention
(
nn
.
Module
):
def
__init__
(
self
,
config
:
PretrainedConfig
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
)
->
None
:
super
().
__init__
()
self
.
config
=
config
self
.
hidden_size
=
config
.
hidden_size
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
total_num_heads
=
config
.
num_attention_heads
self
.
num_heads
=
self
.
total_num_heads
//
tp_size
self
.
total_num_key_value_heads
=
config
.
num_key_value_heads
if
self
.
total_num_key_value_heads
>=
tp_size
:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert
self
.
total_num_key_value_heads
%
tp_size
==
0
else
:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert
tp_size
%
self
.
total_num_key_value_heads
==
0
self
.
num_key_value_heads
=
max
(
1
,
self
.
total_num_key_value_heads
//
tp_size
)
self
.
head_dim
=
self
.
hidden_size
//
self
.
total_num_heads
self
.
max_position_embeddings
=
config
.
max_position_embeddings
self
.
rotary_ndims
=
int
(
self
.
head_dim
*
self
.
config
.
rope_pct
)
self
.
scaling
=
self
.
head_dim
**-
0.5
self
.
q_size
=
self
.
num_heads
*
self
.
head_dim
self
.
kv_size
=
self
.
num_key_value_heads
*
self
.
head_dim
self
.
qkv_bias
=
getattr
(
config
,
"use_qkv_bias"
,
False
)
if
(
self
.
head_dim
*
self
.
num_heads
*
tp_size
)
!=
self
.
hidden_size
:
raise
ValueError
(
f
"hidden_size must be divisible by num_heads (got `hidden_size`:
{
self
.
hidden_size
}
"
f
" and `num_heads`:
{
self
.
num_heads
}
)."
)
self
.
qkv_proj
=
QKVParallelLinear
(
self
.
hidden_size
,
self
.
head_dim
,
self
.
total_num_heads
,
self
.
total_num_key_value_heads
,
self
.
qkv_bias
,
linear_method
=
linear_method
)
self
.
o_proj
=
RowParallelLinear
(
self
.
total_num_heads
*
self
.
head_dim
,
self
.
hidden_size
,
bias
=
False
,
linear_method
=
linear_method
)
self
.
rotary_ndims
=
int
(
self
.
head_dim
*
self
.
config
.
rope_pct
)
self
.
rotary_emb
=
get_rope
(
self
.
head_dim
,
rotary_dim
=
self
.
rotary_ndims
,
max_position
=
self
.
config
.
max_position_embeddings
,
base
=
self
.
config
.
rope_theta
,
)
self
.
attn
=
PagedAttention
(
self
.
num_heads
,
self
.
head_dim
,
self
.
scaling
,
num_kv_heads
=
self
.
num_key_value_heads
)
def
forward
(
self
,
positions
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
kv_cache
:
KVCache
,
input_metadata
:
InputMetadata
,
)
->
torch
.
Tensor
:
qkv
,
_
=
self
.
qkv_proj
(
hidden_states
)
q
,
k
,
v
=
qkv
.
split
([
self
.
q_size
,
self
.
kv_size
,
self
.
kv_size
],
dim
=-
1
)
q
,
k
=
self
.
rotary_emb
(
positions
,
q
,
k
)
k_cache
,
v_cache
=
kv_cache
attn_output
=
self
.
attn
(
q
,
k
,
v
,
k_cache
,
v_cache
,
input_metadata
)
output
,
_
=
self
.
o_proj
(
attn_output
)
return
output
class StablelmDecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.self_attn = StablelmAttention(config)
        self.mlp = StablelmMLP(config, linear_method)
        self.input_layernorm = nn.LayerNorm(config.hidden_size,
                                            eps=config.norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(config.hidden_size,
                                                     eps=config.norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, residual
class StableLMEpochModel(nn.Module):

    def __init__(self,
                 config: PretrainedConfig,
                 linear_method: Optional[LinearMethodBase] = None) -> None:
        super().__init__()
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.layers = nn.ModuleList([
            StablelmDecoderLayer(config, linear_method)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        for i in range(len(self.layers)):
            layer = self.layers[i]
            hidden_states, residual = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states
class StablelmForCausalLM(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        linear_method: Optional[LinearMethodBase] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.linear_method = linear_method
        self.model = StableLMEpochModel(config, linear_method)
        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, kv_caches,
                                   input_metadata)
        return hidden_states

    def sample(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   sampling_metadata)
        return next_tokens

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     load_format: str = "auto",
                     revision: Optional[str] = None):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, load_format, revision):
            if "rotary_emb.inv_freq" in name:
                continue
            if ("rotary_emb.cos_cached" in name
                    or "rotary_emb.sin_cached" in name):
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
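The `stacked_params_mapping` loop above folds per-projection checkpoint tensors into the fused vLLM parameters. A standalone sketch of just the renaming step, using a hypothetical checkpoint tensor name:

# Hypothetical checkpoint name; only the rename step of load_weights is shown.
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]
name = "model.layers.0.self_attn.k_proj.weight"
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in name:
        name = name.replace(weight_name, param_name)
        # The fused qkv_proj parameter receives this tensor as its "k" shard.
        print(name, shard_id)  # model.layers.0.self_attn.qkv_proj.weight k
        break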
vllm/model_executor/parallel_utils/communication_op.py
from collections import namedtuple
from typing import Any, Dict, List, Optional, Union

from torch.distributed import ProcessGroup
import torch

from vllm.model_executor.parallel_utils.parallel_state import (
...
...
@@ -5,23 +10,34 @@ from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size,
    get_tensor_model_parallel_group,
)
+from vllm.model_executor.parallel_utils.custom_all_reduce import custom_all_reduce


-def tensor_model_parallel_all_reduce(input_):
+def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
    """All-reduce the input tensor across model parallel group.

-    NOTE: This operation is applied in-place on the input tensor.
+    NOTE: This operation will be applied in-place on the input tensor if
+    disable_custom_all_reduce is set to True. Otherwise, this operation may or
+    may not be applied in place depending on whether custom all reduce is
+    invoked for a particular tensor, which further depends on the tensor size
+    and GPU topology.
+
+    TLDR: always assume this function modifies its input, but use the return
+    value as the output.
    """
    # Bypass the function if we are using only 1 GPU.
    if get_tensor_model_parallel_world_size() == 1:
        return input_
-    # All-reduce.
+    out = custom_all_reduce(input_)
+    if out is not None:
+        return out
    torch.distributed.all_reduce(input_,
                                 group=get_tensor_model_parallel_group())
    return input_
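Because the custom-allreduce path is out-of-place while the NCCL fallback is in-place, callers should follow the docstring and always rebind to the return value. A minimal usage sketch, assuming torch.distributed and vLLM's model-parallel groups are already initialized:

# Minimal usage sketch; assumes process groups are already set up on each rank.
x = torch.randn(1024, 4096, device="cuda")
x = tensor_model_parallel_all_reduce(x)  # always use the returned tensor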
-def tensor_model_parallel_all_gather(input_, dim=-1):
+def tensor_model_parallel_all_gather(input_: torch.Tensor,
+                                     dim: int = -1) -> torch.Tensor:
    """All-gather the input tensor across model parallel group."""
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
...
...
@@ -48,7 +64,9 @@ def tensor_model_parallel_all_gather(input_, dim=-1):
    return output_tensor


-def tensor_model_parallel_gather(input_, dst=0, dim=-1):
+def tensor_model_parallel_gather(input_: torch.Tensor,
+                                 dst: int = 0,
+                                 dim: int = -1) -> torch.Tensor:
    """Gather the input tensor across model parallel group.

    NOTE: We assume that the input tensor is on the same device across
...
...
@@ -80,27 +98,101 @@ def tensor_model_parallel_gather(input_, dst=0, dim=-1):
    return output_tensor
-def broadcast(input_, src=0):
+def broadcast(input_: torch.Tensor,
+              src: int = 0,
+              group: Optional[ProcessGroup] = None):
    """Broadcast the input tensor."""
-    world_size = torch.distributed.get_world_size()
-    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    group = group or torch.distributed.group.WORLD
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
+    world_size = torch.distributed.get_world_size(group=group)
    if world_size == 1:
        return input_
    # Broadcast.
-    torch.distributed.broadcast(input_, src=src)
+    torch.distributed.broadcast(input_, src=src, group=group)
    return input_


-def broadcast_object_list(obj_list, src=0):
+def broadcast_object_list(obj_list: List[Any],
+                          src: int = 0,
+                          group: Optional[ProcessGroup] = None):
    """Broadcast the input object list."""
-    world_size = torch.distributed.get_world_size()
-    assert 0 <= src < world_size, f"Invalid src rank ({src})"
+    group = group or torch.distributed.group.WORLD
+    ranks = torch.distributed.get_process_group_ranks(group)
+    assert src in ranks, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
+    world_size = torch.distributed.get_world_size(group=group)
    if world_size == 1:
        return obj_list
    # Broadcast.
-    torch.distributed.broadcast_object_list(obj_list, src=src)
+    torch.distributed.broadcast_object_list(obj_list, src=src, group=group)
    return obj_list
TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])


def broadcast_tensor_dict(
    tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
    src: int = 0,
    group: Optional[ProcessGroup] = None,
) -> Dict[Any, Union[torch.Tensor, Any]]:
    """Broadcast the input tensor dictionary."""
    group = group or torch.distributed.group.WORLD
    ranks = torch.distributed.get_process_group_ranks(group)
    assert src in ranks, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
    world_size = torch.distributed.get_world_size(group=group)
    if world_size == 1:
        return tensor_dict

    rank = torch.distributed.get_rank()
    if rank == src:
        assert isinstance(
            tensor_dict,
            dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
        metadata_list = []
        for key, value in tensor_dict.items():
            if isinstance(value, torch.Tensor):
                assert value.is_cuda, (
                    f"Tensor {key}: {value} is not on cuda. Currently we only "
                    f"support broadcasting tensors on cuda.")
                metadata_list.append(
                    (key, TensorMetadata(value.dtype, value.size())))
            else:
                metadata_list.append((key, value))
        torch.distributed.broadcast_object_list([metadata_list],
                                                src=src,
                                                group=group)
        for key, value in metadata_list:
            if isinstance(value, TensorMetadata):
                tensor = tensor_dict[key]
                torch.distributed.broadcast(tensor, src=src)
    else:
        recv_metadata_list = [None]
        torch.distributed.broadcast_object_list(recv_metadata_list,
                                                src=src,
                                                group=group)
        metadata_list = recv_metadata_list[0]
        tensor_dict = {}
        async_handles = []
        for key, value in metadata_list:
            if isinstance(value, TensorMetadata):
                tensor = torch.empty(value.size,
                                     dtype=value.dtype,
                                     device="cuda")
                async_handle = torch.distributed.broadcast(tensor,
                                                           src=src,
                                                           async_op=True,
                                                           group=group)
                async_handles.append(async_handle)
                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value
        for async_handle in async_handles:
            async_handle.wait()
    return tensor_dict
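A minimal sketch of how `broadcast_tensor_dict` is intended to be called: the source rank passes the dictionary, the other ranks pass nothing and get a reconstructed dictionary back (the payload contents are made up, and the default process group is assumed):

# Sketch only: payload contents are made up; process groups are assumed set up.
if torch.distributed.get_rank() == 0:
    payload = {
        "input_ids": torch.arange(8, device="cuda"),
        "is_prompt": True,  # non-tensor values travel via the metadata list
    }
    payload = broadcast_tensor_dict(payload, src=0)
else:
    payload = broadcast_tensor_dict(src=0)  # receives the same keys and values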
vllm/model_executor/parallel_utils/custom_all_reduce.py
0 → 100644
from contextlib import contextmanager
from typing import Optional

import torch
import torch.distributed as dist

from vllm.logger import init_logger
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank)

try:
    from vllm._C import custom_ar
    import pynvml
except ImportError:
    # For AMD GPUs
    custom_ar = None
    pynvml = None

logger = init_logger(__name__)

_CA_HANDLE = None
_IS_CAPTURING = False
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]


def init_custom_ar() -> None:
    global _CA_HANDLE
    if _CA_HANDLE is not None:
        return
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    if world_size == 1:
        # No need to initialize custom allreduce for the single GPU case.
        return
    if world_size not in _SUPPORTED_WORLD_SIZES:
        logger.warn(
            "Custom allreduce is disabled due to an unsupported world size: "
            "%d. Supported world sizes: %s. To silence this warning, specify "
            "disable_custom_all_reduce=True explicitly.", world_size,
            str(_SUPPORTED_WORLD_SIZES))
        return
    if not _can_p2p(rank, world_size):
        logger.warn(
            "Custom allreduce is disabled because your platform lacks GPU P2P"
            " capability. To silence this warning, specify "
            "disable_custom_all_reduce=True explicitly.")
        return
    _CA_HANDLE = CustomAllreduce(rank, world_size)
def begin_capture() -> None:
    global _IS_CAPTURING
    _IS_CAPTURING = True


def end_capture() -> None:
    global _IS_CAPTURING
    _IS_CAPTURING = False


def is_capturing() -> bool:
    return _IS_CAPTURING and _CA_HANDLE is not None


def get_handle() -> Optional["CustomAllreduce"]:
    return _CA_HANDLE


@contextmanager
def capture():
    try:
        begin_capture()
        yield
    finally:
        end_capture()
        handle = get_handle()
        if handle is not None:
            handle.register_graph_buffers()
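The `capture()` context manager is the glue between custom allreduce and CUDA graphs: while it is active, `custom_all_reduce` (below) either records a registered-buffer allreduce into the graph or, during warm-up, only mimics the out-of-place allocation pattern, and the graph buffers are registered once on exit. A rough usage sketch, where `run_model` stands in for the graph-captured forward pass:

# Rough sketch; `run_model` is a placeholder for the model's forward pass.
init_custom_ar()  # once per process, after the distributed groups exist
with capture():
    run_model(warmup_inputs)  # warm-up pass, stream is not capturing yet
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        run_model(capture_inputs)  # allreduces are recorded into the graph
# On exit, register_graph_buffers() has registered the graph's IPC buffers.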
def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]:
    ca_handle = get_handle()
    # When custom allreduce is disabled, this will be None.
    if ca_handle is None:
        return
    if is_capturing():
        if torch.cuda.is_current_stream_capturing():
            if ca_handle.should_custom_ar(input):
                return ca_handle.all_reduce_reg(input)
        else:
            if ca_handle.should_custom_ar(input):
                # During warm-up, mimic the allocation pattern
                # since custom allreduce is out-of-place.
                return torch.empty_like(input)
    else:
        # NOTE: outside of the cuda graph context, custom allreduce incurs
        # the cost of a cudaMemcpy, which should be small (<=1% of overall
        # latency) compared to the performance gains of the custom kernels.
        if ca_handle.should_custom_ar(input):
            return ca_handle.all_reduce_unreg(input)
@contextmanager
def _nvml():
    try:
        pynvml.nvmlInit()
        yield
    finally:
        pynvml.nvmlShutdown()


# Query whether the set of GPUs is fully connected by NVLink (1 hop).
@_nvml()
def _is_full_nvlink(rank, world_size):
    handle = pynvml.nvmlDeviceGetHandleByIndex(rank)
    for i in range(world_size):
        if i != rank:
            try:
                link_state = pynvml.nvmlDeviceGetNvLinkState(handle, i)
                if not link_state:
                    return False
            except pynvml.NVMLError as error:
                logger.info(
                    f"NVLink detection failed with message \"{str(error)}\". "
                    "This is normal if your machine has no NVLink equipped.")
                return False
    return True


def _can_p2p(rank: int, world_size: int) -> bool:
    for i in range(world_size):
        if i == rank:
            continue
        if not torch.cuda.can_device_access_peer(rank, i):
            return False
    return True
class CustomAllreduce:

    # max_size: max supported allreduce size
    def __init__(self, rank, world_size, max_size=8192 * 1024) -> None:
        # Buffer memory is owned by this Python class and passed to C++.
        # The metadata is composed of two parts: synchronization metadata
        # (256 bytes) and a temporary buffer for storing intermediate
        # allreduce results.
        self.meta = torch.zeros(custom_ar.meta_size() + max_size,
                                dtype=torch.uint8,
                                device="cuda")
        # This is a pre-registered IPC buffer. In eager mode, input tensors
        # are first copied into this buffer before allreduce is performed.
        self.buffer = torch.empty(max_size, dtype=torch.uint8, device="cuda")
        # This is a buffer for storing the tuples of pointers pointing to
        # IPC buffers from all ranks. Each registered tuple has a size of
        # 8*world_size bytes where world_size is at most 8. Allocating 8MB
        # is enough for 131072 such tuples; the largest models seen so far
        # need fewer than 10000 registered tuples.
        self.rank_data = torch.empty(8 * 1024 * 1024,
                                     dtype=torch.uint8,
                                     device="cuda")
        self.max_size = max_size
        self.world_size = world_size
        handles, offsets = self._get_ipc_meta(self.meta)
        self.full_nvlink = _is_full_nvlink(rank, world_size)
        self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data,
                                             handles, offsets, rank,
                                             self.full_nvlink)
        self.fast_cond = self.full_nvlink or world_size <= 2
        self.register_buffer(self.buffer)

    def _get_ipc_meta(self, inp: torch.Tensor):
        data = inp.untyped_storage()._share_cuda_()
        shard_data = (
            data[1],  # ipc handle to base ptr
            data[3],  # offset of base ptr
        )
        return self._gather_ipc_meta(shard_data)

    def _gather_ipc_meta(self, shard_data):
        all_data = [None] * self.world_size
        dist.all_gather_object(all_data, shard_data)

        handles = []
        offsets = []
        for i in range(len(all_data)):
            handles.append(all_data[i][0])
            offsets.append(all_data[i][1])
        return handles, offsets

    def register_buffer(self, inp: torch.Tensor):
        handles, offsets = self._get_ipc_meta(inp)
        custom_ar.register_buffer(self._ptr, inp, handles, offsets)

    def register_graph_buffers(self):
        handle, offset = custom_ar.get_graph_buffer_ipc_meta(self._ptr)
        handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
        logger.info("Registering %d cuda graph addresses", len(offset))
        custom_ar.register_graph_buffers(self._ptr, handles, offsets)

    def should_custom_ar(self, inp: torch.Tensor):
        return custom_ar.should_custom_ar(inp, self.max_size, self.world_size,
                                          self.full_nvlink)

    # All-reduce, assuming the inp tensor is IPC registered with
    # register_buffer or, in the context of cuda graphs,
    # register_graph_buffers.
    def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
        if out is None:
            out = torch.empty_like(inp)
        custom_ar.all_reduce_reg(self._ptr, inp, out)
        return out

    # All-reduce, assuming the inp tensor is NOT IPC registered.
    def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
        if out is None:
            out = torch.empty_like(inp)
        custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out)
        return out

    def close(self):
        if self._ptr:
            custom_ar.dispose(self._ptr)
            self._ptr = 0

    def __del__(self):
        self.close()
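Putting the module together: `init_custom_ar()` creates the single process-global `CustomAllreduce` handle, and `tensor_model_parallel_all_reduce` in communication_op.py tries `custom_all_reduce` first, falling back to NCCL whenever it returns None (no handle, tensor too large for `max_size`, or unsupported topology). A condensed restatement of that dispatch:

# Condensed restatement of the dispatch implemented in communication_op.py.
def all_reduce_with_fallback(t: torch.Tensor) -> torch.Tensor:
    out = custom_all_reduce(t)  # None unless the custom path applies
    if out is not None:
        return out              # out-of-place result from the custom kernel
    dist.all_reduce(t)          # NCCL fallback, applied in place
    return t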
vllm/model_executor/parallel_utils/parallel_state.py
...
...
@@ -83,6 +83,31 @@ def initialize_model_parallel(
    _PIPELINE_GLOBAL_RANKS = ranks


def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
) -> None:
    """Helper to initialize model parallel groups if they are not initialized,
    or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
    values if the model parallel groups are initialized.
    """
    if not model_parallel_is_initialized():
        initialize_model_parallel(tensor_model_parallel_size,
                                  pipeline_model_parallel_size)
        return

    assert (get_tensor_model_parallel_world_size() ==
            tensor_model_parallel_size), (
                "tensor parallel group already initialized, but of unexpected "
                f"size: {get_tensor_model_parallel_world_size()=} vs. "
                f"{tensor_model_parallel_size=}")
    assert (get_pipeline_model_parallel_world_size() ==
            pipeline_model_parallel_size), (
                "pipeline parallel group already initialized, but of "
                f"unexpected size: {get_pipeline_model_parallel_world_size()=} "
                f"vs. {pipeline_model_parallel_size=}")


def model_parallel_is_initialized():
    """Check if tensor and pipeline parallel groups are initialized."""
    return (_TENSOR_MODEL_PARALLEL_GROUP is not None
...
...
@@ -92,7 +117,7 @@ def model_parallel_is_initialized():
def get_tensor_model_parallel_group():
    """Get the tensor model parallel group the caller rank belongs to."""
    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
-        "tenosr model parallel group is not initialized")
+        "tensor model parallel group is not initialized")
    return _TENSOR_MODEL_PARALLEL_GROUP
...
...
@@ -170,10 +195,14 @@ def get_pipeline_model_parallel_prev_rank():
def destroy_model_parallel():
-    """Set the groups to none."""
+    """Set the groups to none and destroy them."""
    global _TENSOR_MODEL_PARALLEL_GROUP
+    if _TENSOR_MODEL_PARALLEL_GROUP:
+        torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP)
    _TENSOR_MODEL_PARALLEL_GROUP = None
    global _PIPELINE_MODEL_PARALLEL_GROUP
+    if _PIPELINE_MODEL_PARALLEL_GROUP:
+        torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)
    _PIPELINE_MODEL_PARALLEL_GROUP = None
    global _PIPELINE_GLOBAL_RANKS
    _PIPELINE_GLOBAL_RANKS = None
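A minimal lifecycle sketch for these helpers (the parallel sizes are hypothetical): initialize the groups once per worker, and tear them down at shutdown now that `destroy_model_parallel` also destroys the underlying process groups.

# Hypothetical sizes; assumes torch.distributed itself is already initialized.
ensure_model_parallel_initialized(tensor_model_parallel_size=2,
                                  pipeline_model_parallel_size=1)
# ... run the engine ...
destroy_model_parallel()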
vllm/model_executor/weight_utils.py
"""Utilities for downloading and initializing model weights."""
import
filelock
import
glob
import
fnmatch
import
json
import
os
from
collections
import
defaultdict
from
typing
import
Any
,
Iterator
,
List
,
Optional
,
Tuple
from
huggingface_hub
import
snapshot_download
from
huggingface_hub
import
snapshot_download
,
HfFileSystem
import
numpy
as
np
from
safetensors.torch
import
load_file
,
save_file
,
safe_open
import
torch
...
...
@@ -149,6 +150,18 @@ def prepare_hf_model_weights(
            allow_patterns += ["*.pt"]

    if not is_local:
        # Before downloading, look at what is actually available in the repo.
        fs = HfFileSystem()
        file_list = fs.ls(model_name_or_path, detail=False, revision=revision)

        # Depending on what is available, download only the matching format.
        for pattern in allow_patterns:
            matching = fnmatch.filter(file_list, pattern)
            if len(matching) > 0:
                allow_patterns = [pattern]
                break

    logger.info(f"Using model weights format {allow_patterns}")
    # Use file lock to prevent multiple processes from
    # downloading the same model weights at the same time.
    with get_lock(model_name_or_path, cache_dir):
...
...
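The pattern-selection loop above keeps the first entry of `allow_patterns` that matches anything in the repository listing, so only one weight format is downloaded. A standalone sketch with a made-up file list and pattern order:

import fnmatch

# Made-up repository listing and pattern order; mirrors the loop above.
file_list = ["model-00001-of-00002.safetensors",
             "model-00002-of-00002.safetensors",
             "pytorch_model.bin"]
allow_patterns = ["*.safetensors", "*.bin"]
for pattern in allow_patterns:
    if fnmatch.filter(file_list, pattern):
        allow_patterns = [pattern]
        break
print(allow_patterns)  # ['*.safetensors']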
vllm/outputs.py
...
...
@@ -2,6 +2,7 @@ from typing import List, Optional
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup,
                           SequenceStatus)
+from vllm.lora.request import LoRARequest


class CompletionOutput:
...
...
@@ -16,6 +17,7 @@ class CompletionOutput:
        logprobs: The log probabilities of the top probability words at each
            position if the logprobs are requested.
        finish_reason: The reason why the sequence is finished.
+        lora_request: The LoRA request that was used to generate the output.
    """

    def __init__(
...
...
@@ -26,6 +28,7 @@ class CompletionOutput:
        cumulative_logprob: float,
        logprobs: Optional[SampleLogprobs],
        finish_reason: Optional[str] = None,
+        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.index = index
        self.text = text
...
...
@@ -33,6 +36,7 @@ class CompletionOutput:
        self.cumulative_logprob = cumulative_logprob
        self.logprobs = logprobs
        self.finish_reason = finish_reason
+        self.lora_request = lora_request

    def finished(self) -> bool:
        return self.finish_reason is not None
...
...
@@ -56,6 +60,7 @@ class RequestOutput:
        prompt_logprobs: The log probabilities to return per prompt token.
        outputs: The output sequences of the request.
        finished: Whether the whole request is finished.
+        lora_request: The LoRA request that was used to generate the output.
    """

    def __init__(
...
...
@@ -66,6 +71,7 @@ class RequestOutput:
        prompt_logprobs: Optional[PromptLogprobs],
        outputs: List[CompletionOutput],
        finished: bool,
+        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        self.request_id = request_id
        self.prompt = prompt
...
...
@@ -73,6 +79,7 @@ class RequestOutput:
        self.prompt_logprobs = prompt_logprobs
        self.outputs = outputs
        self.finished = finished
+        self.lora_request = lora_request

    @classmethod
    def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput":
...
...
@@ -108,8 +115,13 @@ class RequestOutput:
        prompt_token_ids = seq_group.prompt_token_ids
        prompt_logprobs = seq_group.prompt_logprobs
        finished = seq_group.is_finished()
-        return cls(seq_group.request_id, prompt, prompt_token_ids,
-                   prompt_logprobs, outputs, finished)
+        return cls(seq_group.request_id,
+                   prompt,
+                   prompt_token_ids,
+                   prompt_logprobs,
+                   outputs,
+                   finished,
+                   lora_request=seq_group.lora_request)

    def __repr__(self) -> str:
        return (f"RequestOutput(request_id={self.request_id}, "
...
...
@@ -117,4 +129,5 @@ class RequestOutput:
                f"prompt_token_ids={self.prompt_token_ids}, "
                f"prompt_logprobs={self.prompt_logprobs}, "
                f"outputs={self.outputs}, "
-                f"finished={self.finished})")
+                f"finished={self.finished}, "
+                f"lora_request={self.lora_request})")
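The outputs change is plain plumbing: the LoRA request that produced a sequence group is now carried on both output objects and surfaced in `__repr__`. A small construction sketch with made-up field values:

# Made-up values; demonstrates only the new lora_request field.
out = CompletionOutput(index=0,
                       text="Hello!",
                       token_ids=[15496, 0],
                       cumulative_logprob=-1.2,
                       logprobs=None,
                       finish_reason="stop",
                       lora_request=None)
assert out.finished()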