Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f1c0fc39
Unverified
Commit
f1c0fc39
authored
Mar 21, 2024
by
Roy
Committed by
GitHub
Mar 20, 2024
Browse files
Migrate `logits` computation and gather to `model_runner` (#3233)
parent
6e435de7
Changes
35
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
171 additions
and
63 deletions
+171
-63
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+14
-4
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+12
-4
vllm/model_executor/models/mixtral_quant.py
vllm/model_executor/models/mixtral_quant.py
+11
-4
vllm/model_executor/models/mpt.py
vllm/model_executor/models/mpt.py
+11
-4
vllm/model_executor/models/neuron/llama.py
vllm/model_executor/models/neuron/llama.py
+11
-4
vllm/model_executor/models/neuron/mistral.py
vllm/model_executor/models/neuron/mistral.py
+11
-4
vllm/model_executor/models/olmo.py
vllm/model_executor/models/olmo.py
+11
-4
vllm/model_executor/models/opt.py
vllm/model_executor/models/opt.py
+11
-4
vllm/model_executor/models/orion.py
vllm/model_executor/models/orion.py
+11
-4
vllm/model_executor/models/phi.py
vllm/model_executor/models/phi.py
+11
-5
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+11
-4
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+15
-9
vllm/model_executor/models/stablelm.py
vllm/model_executor/models/stablelm.py
+11
-4
vllm/model_executor/models/starcoder2.py
vllm/model_executor/models/starcoder2.py
+12
-4
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+8
-1
No files found.
vllm/model_executor/models/llama.py
View file @
f1c0fc39
...
@@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
...
@@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
,
DEFAULT_VOCAB_PADDING_SIZE
)
VocabParallelEmbedding
,
ParallelLMHead
,
DEFAULT_VOCAB_PADDING_SIZE
)
...
@@ -325,7 +326,11 @@ class LlamaForCausalLM(nn.Module):
...
@@ -325,7 +326,11 @@ class LlamaForCausalLM(nn.Module):
# compatibility
# compatibility
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
)
)
self
.
sampler
=
Sampler
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
logit_scale
=
getattr
(
config
,
"logit_scale"
,
1.0
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
,
logit_scale
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -338,13 +343,18 @@ class LlamaForCausalLM(nn.Module):
...
@@ -338,13 +343,18 @@ class LlamaForCausalLM(nn.Module):
input_metadata
)
input_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head
.
weight
,
hidden_states
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/mixtral.py
View file @
f1c0fc39
...
@@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
...
@@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
ReplicatedLinear
,
ReplicatedLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
,
DEFAULT_VOCAB_PADDING_SIZE
)
VocabParallelEmbedding
,
ParallelLMHead
,
DEFAULT_VOCAB_PADDING_SIZE
)
...
@@ -369,7 +370,9 @@ class MixtralForCausalLM(nn.Module):
...
@@ -369,7 +370,9 @@ class MixtralForCausalLM(nn.Module):
# compatibility
# compatibility
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
if
not
lora_config
else
lora_config
.
lora_vocab_padding_size
,
)
)
self
.
sampler
=
Sampler
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -382,13 +385,18 @@ class MixtralForCausalLM(nn.Module):
...
@@ -382,13 +385,18 @@ class MixtralForCausalLM(nn.Module):
input_metadata
)
input_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
Optional
[
torch
.
Tensor
],
logit
s
:
Optional
[
torch
.
Tensor
],
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head
.
weight
,
hidden_states
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/mixtral_quant.py
View file @
f1c0fc39
...
@@ -39,6 +39,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
...
@@ -39,6 +39,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
VocabParallelEmbedding
,
ParallelLMHead
)
...
@@ -344,7 +345,8 @@ class MixtralForCausalLM(nn.Module):
...
@@ -344,7 +345,8 @@ class MixtralForCausalLM(nn.Module):
self
.
linear_method
=
linear_method
self
.
linear_method
=
linear_method
self
.
model
=
MixtralModel
(
config
,
linear_method
)
self
.
model
=
MixtralModel
(
config
,
linear_method
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -357,13 +359,18 @@ class MixtralForCausalLM(nn.Module):
...
@@ -357,13 +359,18 @@ class MixtralForCausalLM(nn.Module):
input_metadata
)
input_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
Optional
[
torch
.
Tensor
],
logit
s
:
Optional
[
torch
.
Tensor
],
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head
.
weight
,
hidden_states
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/mpt.py
View file @
f1c0fc39
...
@@ -13,6 +13,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -13,6 +13,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase
,
LinearMethodBase
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
...
@@ -259,7 +260,8 @@ class MPTForCausalLM(nn.Module):
...
@@ -259,7 +260,8 @@ class MPTForCausalLM(nn.Module):
self
.
transformer
=
MPTModel
(
config
,
linear_method
)
self
.
transformer
=
MPTModel
(
config
,
linear_method
)
self
.
lm_head_weight
=
self
.
transformer
.
wte
.
weight
self
.
lm_head_weight
=
self
.
transformer
.
wte
.
weight
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -272,13 +274,18 @@ class MPTForCausalLM(nn.Module):
...
@@ -272,13 +274,18 @@ class MPTForCausalLM(nn.Module):
input_metadata
)
input_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head_weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head_weight
,
hidden_states
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/neuron/llama.py
View file @
f1c0fc39
...
@@ -7,6 +7,7 @@ from torch import nn
...
@@ -7,6 +7,7 @@ from torch import nn
from
transformers
import
LlamaConfig
from
transformers
import
LlamaConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
...
@@ -25,7 +26,8 @@ class LlamaForCausalLM(nn.Module):
...
@@ -25,7 +26,8 @@ class LlamaForCausalLM(nn.Module):
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
linear_method
=
linear_method
self
.
model
=
None
self
.
model
=
None
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -45,13 +47,18 @@ class LlamaForCausalLM(nn.Module):
...
@@ -45,13 +47,18 @@ class LlamaForCausalLM(nn.Module):
start_ids
=
seq_ids
.
flatten
())
start_ids
=
seq_ids
.
flatten
())
return
logits
return
logits
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
model
.
chkpt_model
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
model
.
chkpt_model
.
lm_head
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
hidden_states
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/neuron/mistral.py
View file @
f1c0fc39
...
@@ -6,6 +6,7 @@ from torch import nn
...
@@ -6,6 +6,7 @@ from torch import nn
from
transformers
import
MistralConfig
from
transformers
import
MistralConfig
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.input_metadata
import
InputMetadata
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
SamplerOutput
from
vllm.sequence
import
SamplerOutput
...
@@ -26,7 +27,8 @@ class MistralForCausalLM(nn.Module):
...
@@ -26,7 +27,8 @@ class MistralForCausalLM(nn.Module):
self
.
linear_method
=
linear_method
self
.
linear_method
=
linear_method
self
.
model
=
None
self
.
model
=
None
self
.
lm_head
=
None
self
.
lm_head
=
None
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -48,13 +50,18 @@ class MistralForCausalLM(nn.Module):
...
@@ -48,13 +50,18 @@ class MistralForCausalLM(nn.Module):
start_ids
=
seq_ids
)
start_ids
=
seq_ids
)
return
logits
return
logits
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
model
.
chkpt_model
.
lm_head
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
model
.
chkpt_model
.
lm_head
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
hidden_states
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/olmo.py
View file @
f1c0fc39
...
@@ -51,6 +51,7 @@ from vllm.model_executor.layers.linear import (
...
@@ -51,6 +51,7 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear
,
RowParallelLinear
,
)
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
...
@@ -336,7 +337,8 @@ class OLMoForCausalLM(nn.Module):
...
@@ -336,7 +337,8 @@ class OLMoForCausalLM(nn.Module):
self
.
lm_head_weight
=
(
self
.
model
.
transformer
.
wte
.
weight
self
.
lm_head_weight
=
(
self
.
model
.
transformer
.
wte
.
weight
if
config
.
weight_tying
else
if
config
.
weight_tying
else
self
.
model
.
transformer
.
ff_out
.
weight
)
self
.
model
.
transformer
.
ff_out
.
weight
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -353,13 +355,18 @@ class OLMoForCausalLM(nn.Module):
...
@@ -353,13 +355,18 @@ class OLMoForCausalLM(nn.Module):
)
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head_weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head_weight
,
hidden_states
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
def
load_weights
(
...
...
vllm/model_executor/models/opt.py
View file @
f1c0fc39
...
@@ -31,6 +31,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -31,6 +31,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear
,
QKVParallelLinear
,
ReplicatedLinear
,
ReplicatedLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
...
@@ -292,7 +293,8 @@ class OPTForCausalLM(nn.Module):
...
@@ -292,7 +293,8 @@ class OPTForCausalLM(nn.Module):
self
.
linear_method
=
linear_method
self
.
linear_method
=
linear_method
self
.
model
=
OPTModel
(
config
,
linear_method
)
self
.
model
=
OPTModel
(
config
,
linear_method
)
self
.
lm_head_weight
=
self
.
model
.
decoder
.
embed_tokens
.
weight
self
.
lm_head_weight
=
self
.
model
.
decoder
.
embed_tokens
.
weight
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -305,13 +307,18 @@ class OPTForCausalLM(nn.Module):
...
@@ -305,13 +307,18 @@ class OPTForCausalLM(nn.Module):
input_metadata
)
input_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head_weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head_weight
,
hidden_states
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/orion.py
View file @
f1c0fc39
...
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
...
@@ -18,6 +18,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
VocabParallelEmbedding
,
ParallelLMHead
)
...
@@ -256,7 +257,8 @@ class OrionForCausalLM(nn.Module):
...
@@ -256,7 +257,8 @@ class OrionForCausalLM(nn.Module):
self
.
linear_method
=
linear_method
self
.
linear_method
=
linear_method
self
.
model
=
OrionModel
(
config
,
linear_method
)
self
.
model
=
OrionModel
(
config
,
linear_method
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -269,13 +271,18 @@ class OrionForCausalLM(nn.Module):
...
@@ -269,13 +271,18 @@ class OrionForCausalLM(nn.Module):
input_metadata
)
input_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head
.
weight
,
hidden_states
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/phi.py
View file @
f1c0fc39
...
@@ -49,6 +49,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -49,6 +49,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
VocabParallelEmbedding
,
ParallelLMHead
)
...
@@ -240,7 +241,8 @@ class PhiForCausalLM(nn.Module):
...
@@ -240,7 +241,8 @@ class PhiForCausalLM(nn.Module):
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
bias
=
True
)
bias
=
True
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -254,14 +256,18 @@ class PhiForCausalLM(nn.Module):
...
@@ -254,14 +256,18 @@ class PhiForCausalLM(nn.Module):
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
,
self
.
lm_head
.
bias
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
head
=
self
.
lm_head
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
next_tokens
=
self
.
sampler
(
head
.
weight
,
hidden_states
,
sampling_metadata
,
head
.
bias
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/qwen.py
View file @
f1c0fc39
...
@@ -19,6 +19,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
...
@@ -19,6 +19,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
VocabParallelEmbedding
,
ParallelLMHead
)
...
@@ -230,7 +231,8 @@ class QWenLMHeadModel(nn.Module):
...
@@ -230,7 +231,8 @@ class QWenLMHeadModel(nn.Module):
self
.
linear_method
=
linear_method
self
.
linear_method
=
linear_method
self
.
transformer
=
QWenModel
(
config
,
linear_method
)
self
.
transformer
=
QWenModel
(
config
,
linear_method
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -243,13 +245,18 @@ class QWenLMHeadModel(nn.Module):
...
@@ -243,13 +245,18 @@ class QWenLMHeadModel(nn.Module):
input_metadata
)
input_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head
.
weight
,
hidden_states
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/qwen2.py
View file @
f1c0fc39
...
@@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
...
@@ -37,6 +37,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
VocabParallelEmbedding
,
ParallelLMHead
)
...
@@ -300,11 +301,15 @@ class Qwen2ForCausalLM(nn.Module):
...
@@ -300,11 +301,15 @@ class Qwen2ForCausalLM(nn.Module):
self
.
linear_method
=
linear_method
self
.
linear_method
=
linear_method
self
.
model
=
Qwen2Model
(
config
,
linear_method
)
self
.
model
=
Qwen2Model
(
config
,
linear_method
)
if
not
config
.
tie_word_embeddings
:
if
config
.
tie_word_embeddings
:
self
.
lm_head_weight
=
self
.
model
.
embed_tokens
.
weight
else
:
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
config
.
hidden_size
)
self
.
lm_head_weight
=
self
.
lm_head
.
weight
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -317,17 +322,18 @@ class Qwen2ForCausalLM(nn.Module):
...
@@ -317,17 +322,18 @@ class Qwen2ForCausalLM(nn.Module):
input_metadata
)
input_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head_weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
if
self
.
config
.
tie_word_embeddings
:
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
lm_head_weight
=
self
.
model
.
embed_tokens
.
weight
else
:
lm_head_weight
=
self
.
lm_head
.
weight
next_tokens
=
self
.
sampler
(
lm_head_weight
,
hidden_states
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/stablelm.py
View file @
f1c0fc39
...
@@ -33,6 +33,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
...
@@ -33,6 +33,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
VocabParallelEmbedding
,
ParallelLMHead
)
...
@@ -238,7 +239,8 @@ class StablelmForCausalLM(nn.Module):
...
@@ -238,7 +239,8 @@ class StablelmForCausalLM(nn.Module):
self
.
linear_method
=
linear_method
self
.
linear_method
=
linear_method
self
.
model
=
StableLMEpochModel
(
config
,
linear_method
)
self
.
model
=
StableLMEpochModel
(
config
,
linear_method
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -251,13 +253,18 @@ class StablelmForCausalLM(nn.Module):
...
@@ -251,13 +253,18 @@ class StablelmForCausalLM(nn.Module):
input_metadata
)
input_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head
.
weight
,
hidden_states
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/starcoder2.py
View file @
f1c0fc39
...
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -32,6 +32,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase
,
LinearMethodBase
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
,
DEFAULT_VOCAB_PADDING_SIZE
)
VocabParallelEmbedding
,
ParallelLMHead
,
DEFAULT_VOCAB_PADDING_SIZE
)
...
@@ -254,7 +255,9 @@ class Starcoder2ForCausalLM(nn.Module):
...
@@ -254,7 +255,9 @@ class Starcoder2ForCausalLM(nn.Module):
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
,
padding_size
=
DEFAULT_VOCAB_PADDING_SIZE
,
)
)
self
.
lm_head_weight
=
self
.
lm_head
.
weight
self
.
lm_head_weight
=
self
.
lm_head
.
weight
self
.
sampler
=
Sampler
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
self
.
unpadded_vocab_size
,
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -267,13 +270,18 @@ class Starcoder2ForCausalLM(nn.Module):
...
@@ -267,13 +270,18 @@ class Starcoder2ForCausalLM(nn.Module):
input_metadata
)
input_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head_weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
Optional
[
torch
.
Tensor
],
logit
s
:
Optional
[
torch
.
Tensor
],
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head_weight
,
hidden_states
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/worker/model_runner.py
View file @
f1c0fc39
...
@@ -613,9 +613,16 @@ class ModelRunner:
...
@@ -613,9 +613,16 @@ class ModelRunner:
input_metadata
=
input_metadata
,
input_metadata
=
input_metadata
,
)
)
# Compute the logits.
logits
=
self
.
model
.
compute_logits
(
hidden_states
,
sampling_metadata
)
# Only perform sampling in the driver worker.
if
not
sampling_metadata
.
perform_sampling
:
return
None
# Sample the next token.
# Sample the next token.
output
=
self
.
model
.
sample
(
output
=
self
.
model
.
sample
(
hidden_states
=
hidden_state
s
,
logits
=
logit
s
,
sampling_metadata
=
sampling_metadata
,
sampling_metadata
=
sampling_metadata
,
)
)
return
output
return
output
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment