Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f1c0fc39
Unverified
Commit
f1c0fc39
authored
Mar 21, 2024
by
Roy
Committed by
GitHub
Mar 20, 2024
Browse files
Migrate `logits` computation and gather to `model_runner` (#3233)
parent
6e435de7
Changes
35
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
406 additions
and
243 deletions
+406
-243
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+3
-0
tests/lora/conftest.py
tests/lora/conftest.py
+5
-2
tests/lora/test_layers.py
tests/lora/test_layers.py
+36
-30
tests/samplers/test_sampler.py
tests/samplers/test_sampler.py
+20
-75
tests/test_logits_processor.py
tests/test_logits_processor.py
+94
-0
vllm/lora/layers.py
vllm/lora/layers.py
+12
-8
vllm/lora/models.py
vllm/lora/models.py
+8
-5
vllm/model_executor/layers/logits_processor.py
vllm/model_executor/layers/logits_processor.py
+106
-0
vllm/model_executor/layers/sampler.py
vllm/model_executor/layers/sampler.py
+1
-80
vllm/model_executor/models/baichuan.py
vllm/model_executor/models/baichuan.py
+11
-4
vllm/model_executor/models/bloom.py
vllm/model_executor/models/bloom.py
+11
-4
vllm/model_executor/models/chatglm.py
vllm/model_executor/models/chatglm.py
+11
-4
vllm/model_executor/models/deepseek.py
vllm/model_executor/models/deepseek.py
+11
-4
vllm/model_executor/models/falcon.py
vllm/model_executor/models/falcon.py
+11
-4
vllm/model_executor/models/gemma.py
vllm/model_executor/models/gemma.py
+11
-4
vllm/model_executor/models/gpt2.py
vllm/model_executor/models/gpt2.py
+11
-3
vllm/model_executor/models/gpt_bigcode.py
vllm/model_executor/models/gpt_bigcode.py
+11
-4
vllm/model_executor/models/gpt_j.py
vllm/model_executor/models/gpt_j.py
+11
-4
vllm/model_executor/models/gpt_neox.py
vllm/model_executor/models/gpt_neox.py
+11
-4
vllm/model_executor/models/internlm2.py
vllm/model_executor/models/internlm2.py
+11
-4
No files found.
.buildkite/test-pipeline.yaml
View file @
f1c0fc39
...
@@ -49,6 +49,9 @@ steps:
...
@@ -49,6 +49,9 @@ steps:
-
label
:
Samplers Test
-
label
:
Samplers Test
command
:
pytest -v -s samplers
command
:
pytest -v -s samplers
-
label
:
LogitsProcessor Test
command
:
pytest -v -s test_logits_processor.py
-
label
:
Worker Test
-
label
:
Worker Test
command
:
pytest -v -s worker
command
:
pytest -v -s worker
...
...
tests/lora/conftest.py
View file @
f1c0fc39
...
@@ -13,6 +13,7 @@ from huggingface_hub import snapshot_download
...
@@ -13,6 +13,7 @@ from huggingface_hub import snapshot_download
import
vllm
import
vllm
from
vllm.config
import
LoRAConfig
from
vllm.config
import
LoRAConfig
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
...
@@ -85,7 +86,8 @@ def dummy_model() -> nn.Module:
...
@@ -85,7 +86,8 @@ def dummy_model() -> nn.Module:
(
"outact"
,
nn
.
Sigmoid
()),
(
"outact"
,
nn
.
Sigmoid
()),
# Special handling for lm_head & sampler
# Special handling for lm_head & sampler
(
"lm_head"
,
ParallelLMHead
(
512
,
10
)),
(
"lm_head"
,
ParallelLMHead
(
512
,
10
)),
(
"sampler"
,
Sampler
(
512
))
(
"logits_processor"
,
LogitsProcessor
(
512
)),
(
"sampler"
,
Sampler
())
]))
]))
model
.
config
=
MagicMock
()
model
.
config
=
MagicMock
()
return
model
return
model
...
@@ -110,7 +112,8 @@ def dummy_model_gate_up() -> nn.Module:
...
@@ -110,7 +112,8 @@ def dummy_model_gate_up() -> nn.Module:
(
"outact"
,
nn
.
Sigmoid
()),
(
"outact"
,
nn
.
Sigmoid
()),
# Special handling for lm_head & sampler
# Special handling for lm_head & sampler
(
"lm_head"
,
ParallelLMHead
(
512
,
10
)),
(
"lm_head"
,
ParallelLMHead
(
512
,
10
)),
(
"sampler"
,
Sampler
(
512
))
(
"logits_processor"
,
LogitsProcessor
(
512
)),
(
"sampler"
,
Sampler
())
]))
]))
model
.
config
=
MagicMock
()
model
.
config
=
MagicMock
()
return
model
return
model
...
...
tests/lora/test_layers.py
View file @
f1c0fc39
...
@@ -13,14 +13,14 @@ from vllm.lora.layers import (
...
@@ -13,14 +13,14 @@ from vllm.lora.layers import (
QKVParallelLinearWithLora
,
QKVParallelLinearWithLora
,
VocabParallelEmbeddingWithLoRA
,
VocabParallelEmbeddingWithLoRA
,
RowParallelLinearWithLoRA
,
RowParallelLinearWithLoRA
,
Sample
rWithLoRA
,
LogitsProcesso
rWithLoRA
,
LoRAMapping
,
LoRAMapping
,
BaseLayerWithLoRA
,
BaseLayerWithLoRA
,
)
)
from
vllm.lora.models
import
(
LoRALayerWeights
,
convert_mapping
,
from
vllm.lora.models
import
(
LoRALayerWeights
,
convert_mapping
,
PackedLoRALayerWeights
)
PackedLoRALayerWeights
)
from
vllm.config
import
LoRAConfig
from
vllm.config
import
LoRAConfig
from
vllm.model_executor.layers.
sampler
import
Sample
r
from
vllm.model_executor.layers.
logits_processor
import
LogitsProcesso
r
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
RowParallelLinear
,
RowParallelLinear
,
...
@@ -394,7 +394,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
...
@@ -394,7 +394,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device) -> None:
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"num_loras"
,
[
1
,
2
,
4
,
8
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_lm_head_
sample
r
(
dist_init
,
num_loras
,
device
)
->
None
:
def
test_lm_head_
logits_processo
r
(
dist_init
,
num_loras
,
device
)
->
None
:
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
max_loras
=
8
max_loras
=
8
...
@@ -402,28 +402,29 @@ def test_lm_head_sampler(dist_init, num_loras, device) -> None:
...
@@ -402,28 +402,29 @@ def test_lm_head_sampler(dist_init, num_loras, device) -> None:
max_lora_rank
=
8
,
max_lora_rank
=
8
,
lora_dtype
=
torch
.
float16
)
lora_dtype
=
torch
.
float16
)
def
c
re
a
te
_random_sampler_layer
():
def
_p
rete
st
():
linear
=
ParallelLMHead
(
32000
+
lora_config
.
lora_extra_vocab_size
,
linear
=
ParallelLMHead
(
32000
+
lora_config
.
lora_extra_vocab_size
,
1024
,
32000
)
1024
,
32000
)
linear
.
weight
.
data
=
torch
.
rand_like
(
linear
.
weight
.
data
)
linear
.
weight
.
data
=
torch
.
rand_like
(
linear
.
weight
.
data
)
linear
.
weight
.
data
[:,
32000
:]
=
0
linear
.
weight
.
data
[:,
32000
:]
=
0
sampler
=
Sampler
(
32000
+
lora_config
.
lora_extra_vocab_size
,
32000
)
logits_processor
=
LogitsProcessor
(
lora_sampler
=
SamplerWithLoRA
(
sampler
,
1024
,
linear
.
weight
.
dtype
,
32000
+
lora_config
.
lora_extra_vocab_size
,
32000
)
linear
.
weight
.
device
)
lora_logits_processor
=
LogitsProcessorWithLoRA
(
lora_sampler
.
create_lora_weights
(
max_loras
,
lora_config
)
logits_processor
,
1024
,
linear
.
weight
.
dtype
,
linear
.
weight
.
device
)
lora_logits_processor
.
create_lora_weights
(
max_loras
,
lora_config
)
return
linear
,
sampler
,
lora_sample
r
return
linear
,
logits_processor
,
lora_logits_processo
r
for
i
in
range
(
10
):
for
i
in
range
(
10
):
set_random_seed
(
i
)
set_random_seed
(
i
)
id_to_index
=
get_random_id_to_index
(
num_loras
,
max_loras
)
id_to_index
=
get_random_id_to_index
(
num_loras
,
max_loras
)
linear
,
sampler
,
lora_sample
r
=
c
re
a
te
_random_sampler_layer
()
linear
,
logits_processor
,
lora_logits_processo
r
=
_p
rete
st
()
# NOTE: all the generated loras share the same embeddings tensor.
# NOTE: all the generated loras share the same embeddings tensor.
lora_dict
,
_
=
populate_loras
(
lora_dict
,
_
=
populate_loras
(
id_to_index
,
id_to_index
,
layer
=
lora_
sample
r
,
layer
=
lora_
logits_processo
r
,
layer_weights
=
linear
.
weight
,
layer_weights
=
linear
.
weight
,
generate_embeddings_tensor
=
1024
,
generate_embeddings_tensor
=
1024
,
)
)
...
@@ -447,34 +448,37 @@ def test_lm_head_sampler(dist_init, num_loras, device) -> None:
...
@@ -447,34 +448,37 @@ def test_lm_head_sampler(dist_init, num_loras, device) -> None:
32000
,
32000
,
lora_config
.
lora_extra_vocab_size
,
lora_config
.
lora_extra_vocab_size
,
)
)
lora_
sample
r
.
set_mapping
(
*
mapping_info
,
)
lora_
logits_processo
r
.
set_mapping
(
*
mapping_info
,
)
lora_result
=
lora_sampler
.
_get_logits
(
hidden_states
=
torch
.
cat
(
inputs
),
lora_result
=
lora_logits_processor
.
_get_logits
(
embedding
=
linear
.
weight
,
hidden_states
=
torch
.
cat
(
inputs
),
embedding_bias
=
None
)
embedding
=
linear
.
weight
,
embedding_bias
=
None
)
original_weight
=
linear
.
weight
.
clone
()
original_weight
=
linear
.
weight
.
clone
()
linear
.
weight
[
sampler
.
org_vocab_size
:
sampler
.
org_vocab_size
+
linear
.
weight
[
logits_processor
.
org_vocab_size
:
logits_processor
.
org_vocab_size
+
embeddings_tensor_len
]
=
embeddings_tensor
embeddings_tensor_len
]
=
embeddings_tensor
sampler
.
org_vocab_size
=
32000
+
lora_config
.
lora_extra_vocab_size
logits_processor
.
org_vocab_size
=
(
32000
+
lora_config
.
lora_extra_vocab_size
)
expected_results
=
[]
expected_results
=
[]
for
input_
,
lora_id
in
zip
(
inputs
,
prompt_mapping
):
for
input_
,
lora_id
in
zip
(
inputs
,
prompt_mapping
):
lora
=
lora_dict
[
lora_id
]
lora
=
lora_dict
[
lora_id
]
result
=
sample
r
.
_get_logits
(
hidden_states
=
input_
,
result
=
logits_processo
r
.
_get_logits
(
hidden_states
=
input_
,
embedding
=
linear
.
weight
,
embedding
=
linear
.
weight
,
embedding_bias
=
None
)
embedding_bias
=
None
)
result
[:,
32000
+
embeddings_tensor_len
:]
=
float
(
"-inf"
)
result
[:,
32000
+
embeddings_tensor_len
:]
=
float
(
"-inf"
)
result
+=
input_
@
lora
.
lora_a
@
lora
.
lora_b
*
lora
.
scaling
result
+=
input_
@
lora
.
lora_a
@
lora
.
lora_b
*
lora
.
scaling
expected_results
.
append
(
result
)
expected_results
.
append
(
result
)
expected_result
=
torch
.
cat
(
expected_results
)
expected_result
=
torch
.
cat
(
expected_results
)
sample
r
.
org_vocab_size
=
32000
logits_processo
r
.
org_vocab_size
=
32000
# Check that resetting the lora weights succeeds
# Check that resetting the lora weights succeeds
for
slot_idx
in
range
(
max_loras
):
for
slot_idx
in
range
(
max_loras
):
lora_
sample
r
.
reset_lora
(
slot_idx
)
lora_
logits_processo
r
.
reset_lora
(
slot_idx
)
inputs
,
index_mapping
,
prompt_mapping
=
create_random_inputs
(
inputs
,
index_mapping
,
prompt_mapping
=
create_random_inputs
(
active_lora_ids
=
[
0
],
active_lora_ids
=
[
0
],
...
@@ -488,14 +492,16 @@ def test_lm_head_sampler(dist_init, num_loras, device) -> None:
...
@@ -488,14 +492,16 @@ def test_lm_head_sampler(dist_init, num_loras, device) -> None:
mapping_info
=
convert_mapping
(
lora_mapping
,
id_to_index
,
max_loras
,
mapping_info
=
convert_mapping
(
lora_mapping
,
id_to_index
,
max_loras
,
32000
,
32000
,
lora_config
.
lora_extra_vocab_size
)
lora_config
.
lora_extra_vocab_size
)
lora_sampler
.
set_mapping
(
*
mapping_info
,
)
lora_logits_processor
.
set_mapping
(
*
mapping_info
,
)
lora_result
=
lora_sampler
.
_get_logits
(
hidden_states
=
torch
.
cat
(
inputs
),
lora_result
=
lora_logits_processor
.
_get_logits
(
embedding
=
original_weight
,
hidden_states
=
torch
.
cat
(
inputs
),
embedding_bias
=
None
)[:,
:
32000
]
embedding
=
original_weight
,
expected_result
=
sampler
.
_get_logits
(
hidden_states
=
torch
.
cat
(
inputs
),
embedding_bias
=
None
)[:,
:
32000
]
embedding
=
original_weight
,
expected_result
=
logits_processor
.
_get_logits
(
embedding_bias
=
None
)
hidden_states
=
torch
.
cat
(
inputs
),
embedding
=
original_weight
,
embedding_bias
=
None
)
rtol
,
atol
=
TOLERANCES
[
lora_result
.
dtype
]
rtol
,
atol
=
TOLERANCES
[
lora_result
.
dtype
]
assert
torch
.
allclose
(
lora_result
,
assert
torch
.
allclose
(
lora_result
,
...
...
tests/samplers/test_sampler.py
View file @
f1c0fc39
...
@@ -15,17 +15,12 @@ from vllm.worker.model_runner import ModelRunner
...
@@ -15,17 +15,12 @@ from vllm.worker.model_runner import ModelRunner
class
MockLogitsSampler
(
Sampler
):
class
MockLogitsSampler
(
Sampler
):
def
__init__
(
self
,
vocab_size
:
int
,
fake_logits
:
torch
.
Tensor
):
def
__init__
(
self
,
fake_logits
:
torch
.
Tensor
):
super
().
__init__
(
vocab_size
=
vocab_size
)
super
().
__init__
()
self
.
fake_logits
=
fake_logits
self
.
fake_logits
=
fake_logits
def
forward
(
self
,
*
args
,
**
kwargs
):
def
forward
(
self
,
*
args
,
**
kwargs
):
with
patch
(
return
super
().
forward
(
*
args
,
**
kwargs
)
"vllm.model_executor.layers.sampler._prune_hidden_states"
,
lambda
x
,
y
:
x
),
patch
(
"vllm.model_executor.layers.sampler.Sampler._get_logits"
,
lambda
*
args
,
**
kwargs
:
self
.
fake_logits
):
return
super
().
forward
(
*
args
,
**
kwargs
)
def
_prepare_test
(
def
_prepare_test
(
...
@@ -36,7 +31,7 @@ def _prepare_test(
...
@@ -36,7 +31,7 @@ def _prepare_test(
fake_logits
=
torch
.
full
((
batch_size
,
vocab_size
),
fake_logits
=
torch
.
full
((
batch_size
,
vocab_size
),
1e-2
,
1e-2
,
dtype
=
input_tensor
.
dtype
)
dtype
=
input_tensor
.
dtype
)
sampler
=
MockLogitsSampler
(
32000
,
fake_logits
)
sampler
=
MockLogitsSampler
(
fake_logits
)
model_runner
=
ModelRunner
(
None
,
None
,
None
,
None
,
None
)
model_runner
=
ModelRunner
(
None
,
None
,
None
,
None
,
None
)
return
input_tensor
,
fake_logits
,
sampler
,
model_runner
return
input_tensor
,
fake_logits
,
sampler
,
model_runner
...
@@ -70,9 +65,7 @@ def _do_sample(
...
@@ -70,9 +65,7 @@ def _do_sample(
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
prompt_lens
,
prompt_lens
,
subquery_lens
=
prompt_lens
)
subquery_lens
=
prompt_lens
)
return
sampler
(
embedding
=
None
,
return
sampler
(
logits
=
input_tensor
,
sampling_metadata
=
sampling_metadata
)
hidden_states
=
input_tensor
,
sampling_metadata
=
sampling_metadata
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
...
@@ -85,8 +78,8 @@ def test_sampler_all_greedy(seed: int, device: str):
...
@@ -85,8 +78,8 @@ def test_sampler_all_greedy(seed: int, device: str):
batch_size
)
batch_size
)
sampling_params
=
SamplingParams
(
temperature
=
0
)
sampling_params
=
SamplingParams
(
temperature
=
0
)
sampler_output
=
_do_sample
(
batch_size
,
input_tensor
,
sampl
er
,
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
model_runn
er
,
model_runner
,
sampling_params
)
sampling_params
)
expected
=
torch
.
argmax
(
fake_logits
,
dim
=-
1
)
expected
=
torch
.
argmax
(
fake_logits
,
dim
=-
1
)
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
nth_output
in
sequence_output
.
samples
:
for
nth_output
in
sequence_output
.
samples
:
...
@@ -111,8 +104,8 @@ def test_sampler_all_random(seed: int, device: str):
...
@@ -111,8 +104,8 @@ def test_sampler_all_random(seed: int, device: str):
temperature
=
1.0
,
temperature
=
1.0
,
n
=
random
.
randint
(
1
,
10
),
n
=
random
.
randint
(
1
,
10
),
)
)
sampler_output
=
_do_sample
(
batch_size
,
input_tensor
,
sampl
er
,
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
model_runn
er
,
model_runner
,
sampling_params
)
sampling_params
)
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
nth_output
in
sequence_output
.
samples
:
for
nth_output
in
sequence_output
.
samples
:
...
@@ -127,8 +120,7 @@ def test_sampler_all_random_seed(seed: int, device: str):
...
@@ -127,8 +120,7 @@ def test_sampler_all_random_seed(seed: int, device: str):
set_random_seed
(
seed
)
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
batch_size
=
random
.
randint
(
1
,
256
)
batch_size
=
random
.
randint
(
1
,
256
)
input_tensor
,
fake_logits
,
sampler
,
model_runner
=
_prepare_test
(
_
,
fake_logits
,
sampler
,
model_runner
=
_prepare_test
(
batch_size
)
batch_size
)
for
i
in
range
(
batch_size
):
for
i
in
range
(
batch_size
):
fake_logits
[
i
,
i
]
=
1e2
fake_logits
[
i
,
i
]
=
1e2
...
@@ -138,8 +130,8 @@ def test_sampler_all_random_seed(seed: int, device: str):
...
@@ -138,8 +130,8 @@ def test_sampler_all_random_seed(seed: int, device: str):
n
=
random
.
randint
(
1
,
10
),
n
=
random
.
randint
(
1
,
10
),
seed
=
random
.
randint
(
0
,
10000
),
seed
=
random
.
randint
(
0
,
10000
),
)
)
sampler_output
=
_do_sample
(
batch_size
,
input_tensor
,
sampl
er
,
sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
model_runn
er
,
model_runner
,
sampling_params
)
sampling_params
)
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
i
,
sequence_output
in
enumerate
(
sampler_output
):
for
nth_output
in
sequence_output
.
samples
:
for
nth_output
in
sequence_output
.
samples
:
...
@@ -154,18 +146,17 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
...
@@ -154,18 +146,17 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str):
set_random_seed
(
seed
)
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
batch_size
=
random
.
randint
(
1
,
256
)
batch_size
=
random
.
randint
(
1
,
256
)
input_tensor
,
fake_logits
,
sampler
,
model_runner
=
_prepare_test
(
_
,
fake_logits
,
sampler
,
model_runner
=
_prepare_test
(
batch_size
)
batch_size
)
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
temperature
=
1.0
,
temperature
=
1.0
,
n
=
random
.
randint
(
1
,
10
),
n
=
random
.
randint
(
1
,
10
),
seed
=
random
.
randint
(
0
,
10000
),
seed
=
random
.
randint
(
0
,
10000
),
)
)
first_sampler_output
=
_do_sample
(
batch_size
,
input_tensor
,
sampler
,
first_sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
model_runner
,
sampling_params
)
model_runner
,
sampling_params
)
second_sampler_output
=
_do_sample
(
batch_size
,
input_tensor
,
sampler
,
second_sampler_output
=
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
model_runner
,
sampling_params
)
model_runner
,
sampling_params
)
assert
first_sampler_output
==
second_sampler_output
assert
first_sampler_output
==
second_sampler_output
...
@@ -179,15 +170,14 @@ def test_sampler_all_beam(seed: int, device: str):
...
@@ -179,15 +170,14 @@ def test_sampler_all_beam(seed: int, device: str):
set_random_seed
(
seed
)
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
torch
.
set_default_device
(
device
)
batch_size
=
random
.
randint
(
1
,
256
)
batch_size
=
random
.
randint
(
1
,
256
)
input_tensor
,
_
,
sampler
,
model_runner
=
_prepare_test
(
batch_size
)
_
,
fake_logits
,
sampler
,
model_runner
=
_prepare_test
(
batch_size
)
sampling_params
=
SamplingParams
(
sampling_params
=
SamplingParams
(
temperature
=
0
,
temperature
=
0
,
best_of
=
2
,
best_of
=
2
,
use_beam_search
=
True
,
use_beam_search
=
True
,
)
)
_do_sample
(
batch_size
,
input_tensor
,
sampler
,
model_runner
,
_do_sample
(
batch_size
,
fake_logits
,
sampler
,
model_runner
,
sampling_params
)
sampling_params
)
# no assertion here as I am not sure how to determine whether
# no assertion here as I am not sure how to determine whether
# the outputs are expected - in other words, this just tests
# the outputs are expected - in other words, this just tests
# whether there are no exceptions in the sampler
# whether there are no exceptions in the sampler
...
@@ -246,8 +236,7 @@ def test_sampler_mixed(seed: int, device: str):
...
@@ -246,8 +236,7 @@ def test_sampler_mixed(seed: int, device: str):
def
test_sampling
(
model_runner
:
ModelRunner
):
def
test_sampling
(
model_runner
:
ModelRunner
):
sampling_metadata
=
model_runner
.
_prepare_sample
(
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
prompt_lens
,
subquery_lens
=
prompt_lens
)
seq_group_metadata_list
,
prompt_lens
,
subquery_lens
=
prompt_lens
)
sampler_output
=
sampler
(
embedding
=
None
,
sampler_output
=
sampler
(
logits
=
fake_logits
,
hidden_states
=
input_tensor
,
sampling_metadata
=
sampling_metadata
)
sampling_metadata
=
sampling_metadata
)
for
i
,
(
sequence_output
,
metadata
)
in
enumerate
(
for
i
,
(
sequence_output
,
metadata
)
in
enumerate
(
...
@@ -294,48 +283,6 @@ def test_sampler_mixed(seed: int, device: str):
...
@@ -294,48 +283,6 @@ def test_sampler_mixed(seed: int, device: str):
del
model_runner
del
model_runner
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_sampler_logits_processors
(
seed
:
int
,
device
:
str
):
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
batch_size
=
random
.
randint
(
1
,
256
)
input_tensor
,
_
,
sampler
,
model_runner
=
_prepare_test
(
batch_size
)
# This sample logits processor gives maximum score to the i-th token,
# where i is the length of the input sequence.
# We therefore expect the output token sequence to be [0, 1, 2, ...]
def
pick_ith
(
token_ids
,
logits
):
logits
[
len
(
token_ids
)]
=
torch
.
finfo
(
logits
.
dtype
).
max
return
logits
seq_group_metadata_list
=
[]
prompt_lens
=
[]
for
i
in
range
(
batch_size
):
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
([
1
,
2
,
3
])},
sampling_params
=
SamplingParams
(
temperature
=
0
,
logits_processors
=
[
pick_ith
]),
block_tables
=
{
0
:
[
1
]},
))
prompt_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
prompt_lens
,
subquery_lens
=
prompt_lens
)
sampler_output
=
sampler
(
embedding
=
None
,
hidden_states
=
input_tensor
,
sampling_metadata
=
sampling_metadata
)
for
_
,
sequence_output
in
enumerate
(
sampler_output
):
for
idx
,
nth_output
in
enumerate
(
sequence_output
.
samples
):
assert
nth_output
.
output_token
==
idx
del
model_runner
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_sampler_top_k_top_p
(
seed
:
int
,
device
:
str
):
def
test_sampler_top_k_top_p
(
seed
:
int
,
device
:
str
):
...
@@ -352,7 +299,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
...
@@ -352,7 +299,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
size
=
(
batch_size
,
vocab_size
),
size
=
(
batch_size
,
vocab_size
),
device
=
input_tensor
.
device
,
device
=
input_tensor
.
device
,
dtype
=
input_tensor
.
dtype
)
dtype
=
input_tensor
.
dtype
)
sampler
=
MockLogitsSampler
(
32000
,
fake_logits
)
sampler
=
MockLogitsSampler
(
fake_logits
)
model_runner
=
ModelRunner
(
None
,
None
,
None
,
None
,
None
)
model_runner
=
ModelRunner
(
None
,
None
,
None
,
None
,
None
)
generation_model
=
GenerationMixin
()
generation_model
=
GenerationMixin
()
...
@@ -391,9 +338,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
...
@@ -391,9 +338,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
return
[[
prob
.
topk
(
1
,
dim
=-
1
).
indices
.
tolist
(),
[
0
]]
for
prob
in
probs
]
return
[[
prob
.
topk
(
1
,
dim
=-
1
).
indices
.
tolist
(),
[
0
]]
for
prob
in
probs
]
with
patch
(
"vllm.model_executor.layers.sampler._sample"
,
mock_sample
):
with
patch
(
"vllm.model_executor.layers.sampler._sample"
,
mock_sample
):
sampler
(
embedding
=
None
,
sampler
(
logits
=
fake_logits
,
sampling_metadata
=
sampling_metadata
)
hidden_states
=
input_tensor
,
sampling_metadata
=
sampling_metadata
)
hf_probs
=
warpers
(
torch
.
zeros_like
(
fake_logits
),
fake_logits
.
clone
())
hf_probs
=
warpers
(
torch
.
zeros_like
(
fake_logits
),
fake_logits
.
clone
())
hf_probs
=
torch
.
softmax
(
hf_probs
,
dim
=-
1
,
dtype
=
torch
.
float
)
hf_probs
=
torch
.
softmax
(
hf_probs
,
dim
=-
1
,
dtype
=
torch
.
float
)
assert
torch
.
allclose
(
hf_probs
,
sample_probs
,
atol
=
1e-5
)
assert
torch
.
allclose
(
hf_probs
,
sample_probs
,
atol
=
1e-5
)
...
...
tests/test_logits_processor.py
0 → 100644
View file @
f1c0fc39
import
random
from
typing
import
Tuple
from
unittest.mock
import
patch
import
pytest
import
torch
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.utils
import
set_random_seed
from
vllm.sequence
import
SamplingParams
,
SequenceData
,
SequenceGroupMetadata
from
vllm.worker.model_runner
import
ModelRunner
class
MockLogitsProcessor
(
LogitsProcessor
):
def
__init__
(
self
,
vocab_size
:
int
,
scale
:
float
,
fake_logits
:
torch
.
Tensor
):
super
().
__init__
(
vocab_size
=
vocab_size
,
scale
=
scale
)
self
.
fake_logits
=
fake_logits
.
clone
()
def
forward
(
self
,
*
args
,
**
kwargs
):
with
patch
(
"vllm.model_executor.layers.logits_processor._prune_hidden_states"
,
lambda
x
,
y
:
x
),
patch
(
"vllm.model_executor.layers.logits_processor.LogitsProcessor._get_logits"
,
lambda
*
args
,
**
kwargs
:
self
.
fake_logits
):
return
super
().
forward
(
*
args
,
**
kwargs
)
def
_prepare_test
(
batch_size
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
MockLogitsProcessor
,
ModelRunner
]:
vocab_size
=
32000
input_tensor
=
torch
.
rand
((
batch_size
,
1024
),
dtype
=
torch
.
float16
)
fake_logits
=
torch
.
full
((
batch_size
,
vocab_size
),
1e-2
,
dtype
=
input_tensor
.
dtype
)
logits_processor
=
MockLogitsProcessor
(
32000
,
0.5
,
fake_logits
)
model_runner
=
ModelRunner
(
None
,
None
,
None
,
None
,
None
)
return
input_tensor
,
fake_logits
,
logits_processor
,
model_runner
RANDOM_SEEDS
=
list
(
range
(
128
))
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
]
@
pytest
.
mark
.
parametrize
(
"seed"
,
RANDOM_SEEDS
)
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_logits_processors
(
seed
:
int
,
device
:
str
):
set_random_seed
(
seed
)
torch
.
set_default_device
(
device
)
batch_size
=
random
.
randint
(
1
,
256
)
input_tensor
,
fake_logits
,
logits_processor
,
model_runner
=
_prepare_test
(
batch_size
)
# This sample logits processor gives infinite score to the i-th token,
# where i is the length of the input sequence.
# We therefore expect the output token sequence to be [0, 1, 2, ...]
def
pick_ith
(
token_ids
,
logits
):
logits
[
len
(
token_ids
)]
=
float
(
"inf"
)
return
logits
seq_group_metadata_list
=
[]
prompt_lens
=
[]
for
i
in
range
(
batch_size
):
seq_group_metadata_list
.
append
(
SequenceGroupMetadata
(
request_id
=
f
"test_
{
i
}
"
,
is_prompt
=
True
,
seq_data
=
{
0
:
SequenceData
([
1
,
2
,
3
])},
sampling_params
=
SamplingParams
(
temperature
=
0
,
logits_processors
=
[
pick_ith
]),
block_tables
=
{
0
:
[
1
]},
))
prompt_lens
.
append
(
seq_group_metadata_list
[
-
1
].
seq_data
[
0
].
get_len
())
sampling_metadata
=
model_runner
.
_prepare_sample
(
seq_group_metadata_list
,
prompt_lens
,
subquery_lens
=
prompt_lens
)
logits_processor_output
=
logits_processor
(
embedding
=
None
,
hidden_states
=
input_tensor
,
sampling_metadata
=
sampling_metadata
)
assert
torch
.
isinf
(
logits_processor_output
[:,
0
]).
all
()
fake_logits
*=
logits_processor
.
scale
assert
torch
.
allclose
(
logits_processor_output
[:,
1
],
fake_logits
[:,
1
],
1e-4
)
del
model_runner
vllm/lora/layers.py
View file @
f1c0fc39
...
@@ -10,7 +10,6 @@ from transformers import PretrainedConfig
...
@@ -10,7 +10,6 @@ from transformers import PretrainedConfig
from
vllm.config
import
LoRAConfig
from
vllm.config
import
LoRAConfig
from
vllm.lora.punica
import
add_lora
,
add_lora_slice
,
bgmv
from
vllm.lora.punica
import
add_lora
,
add_lora_slice
,
bgmv
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.parallel_utils.communication_op
import
(
from
vllm.model_executor.parallel_utils.communication_op
import
(
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
,
tensor_model_parallel_all_reduce
,
...
@@ -20,6 +19,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -20,6 +19,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear
,
RowParallelLinear
,
QKVParallelLinear
,
QKVParallelLinear
,
MergedColumnParallelLinear
)
MergedColumnParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
VocabParallelEmbedding
,
ParallelLMHead
)
from
vllm.model_executor.parallel_utils.parallel_state
import
(
from
vllm.model_executor.parallel_utils.parallel_state
import
(
...
@@ -783,11 +783,11 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -783,11 +783,11 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
return
self
.
base_layer
.
weight
return
self
.
base_layer
.
weight
class
Sample
rWithLoRA
(
BaseLayerWithLoRA
):
class
LogitsProcesso
rWithLoRA
(
BaseLayerWithLoRA
):
def
__init__
(
def
__init__
(
self
,
self
,
base_layer
:
Sample
r
,
base_layer
:
LogitsProcesso
r
,
hidden_size
:
int
,
hidden_size
:
int
,
dtype
:
torch
.
dtype
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
,
device
:
torch
.
device
,
...
@@ -806,6 +806,10 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
...
@@ -806,6 +806,10 @@ class SamplerWithLoRA(BaseLayerWithLoRA):
def
vocab_size
(
self
):
def
vocab_size
(
self
):
return
self
.
base_layer
.
vocab_size
return
self
.
base_layer
.
vocab_size
@
property
def
scale
(
self
):
return
self
.
base_layer
.
scale
@
property
@
property
def
org_vocab_size
(
self
):
def
org_vocab_size
(
self
):
return
self
.
base_layer
.
org_vocab_size
return
self
.
base_layer
.
org_vocab_size
...
@@ -968,14 +972,14 @@ def from_layer(
...
@@ -968,14 +972,14 @@ def from_layer(
return
layer
return
layer
def
from_layer_
sample
r
(
def
from_layer_
logits_processo
r
(
layer
:
Sample
r
,
layer
:
LogitsProcesso
r
,
lm_head
:
ParallelLMHead
,
lm_head
:
ParallelLMHead
,
max_loras
:
int
,
max_loras
:
int
,
lora_config
:
LoRAConfig
,
lora_config
:
LoRAConfig
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
,
)
->
Sample
rWithLoRA
:
)
->
LogitsProcesso
rWithLoRA
:
ret
=
Sample
rWithLoRA
(
layer
,
lm_head
.
embedding_dim
,
lm_head
.
weight
.
dtype
,
ret
=
LogitsProcesso
rWithLoRA
(
layer
,
lm_head
.
embedding_dim
,
lm_head
.
weight
.
device
)
lm_head
.
weight
.
dtype
,
lm_head
.
weight
.
device
)
ret
.
create_lora_weights
(
max_loras
,
lora_config
,
model_config
)
ret
.
create_lora_weights
(
max_loras
,
lora_config
,
model_config
)
return
ret
return
ret
vllm/lora/models.py
View file @
f1c0fc39
...
@@ -14,7 +14,7 @@ from vllm.config import LoRAConfig
...
@@ -14,7 +14,7 @@ from vllm.config import LoRAConfig
from
vllm.utils
import
LRUCache
,
in_wsl
from
vllm.utils
import
LRUCache
,
in_wsl
from
vllm.lora.layers
import
(
BaseLayerWithLoRA
,
LoRAMapping
,
from_layer
,
from
vllm.lora.layers
import
(
BaseLayerWithLoRA
,
LoRAMapping
,
from_layer
,
from_layer_
sample
r
)
from_layer_
logits_processo
r
)
from
vllm.lora.lora
import
LoRALayerWeights
,
PackedLoRALayerWeights
from
vllm.lora.lora
import
LoRALayerWeights
,
PackedLoRALayerWeights
from
vllm.lora.utils
import
parse_fine_tuned_lora_name
,
replace_submodule
from
vllm.lora.utils
import
parse_fine_tuned_lora_name
,
replace_submodule
...
@@ -421,11 +421,14 @@ class LoRAModelManager:
...
@@ -421,11 +421,14 @@ class LoRAModelManager:
self
.
model
.
config
))
self
.
model
.
config
))
# (yard1): TODO make this more robust
# (yard1): TODO make this more robust
if
"lm_head"
in
module_name
:
if
"lm_head"
in
module_name
:
sampler_module
=
self
.
model
.
get_submodule
(
"sampler"
)
logits_processor_module
=
self
.
model
.
get_submodule
(
"logits_processor"
)
new_module
=
replace_submodule
(
new_module
=
replace_submodule
(
self
.
model
,
"sampler"
,
self
.
model
,
"logits_processor"
,
from_layer_sampler
(
sampler_module
,
module
,
self
.
lora_slots
,
from_layer_logits_processor
(
logits_processor_module
,
self
.
lora_config
,
self
.
model
.
config
))
module
,
self
.
lora_slots
,
self
.
lora_config
,
self
.
model
.
config
))
self
.
register_module
(
module_name
,
new_module
)
self
.
register_module
(
module_name
,
new_module
)
self
.
_register_packed_modules
(
module_name
)
self
.
_register_packed_modules
(
module_name
)
new_module
.
set_mapping
(
self
.
base_indices
,
self
.
sampler_indices
,
new_module
.
set_mapping
(
self
.
base_indices
,
self
.
sampler_indices
,
...
...
vllm/model_executor/layers/logits_processor.py
0 → 100644
View file @
f1c0fc39
"""A layer that compute logits from hidden_stats."""
from
typing
import
Optional
import
torch
import
torch.nn
as
nn
from
vllm.utils
import
is_neuron
from
vllm.model_executor.parallel_utils.communication_op
import
(
tensor_model_parallel_gather
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
class
LogitsProcessor
(
nn
.
Module
):
"""Process logits and apply logits processors from sampling metadata.
This layer does the following:
1. Gather logits from model hidden_states.
2. Scale logits if needed.
3. Apply logits processors (if any).
"""
def
__init__
(
self
,
vocab_size
:
int
,
org_vocab_size
:
Optional
[
int
]
=
None
,
scale
:
Optional
[
float
]
=
1.0
)
->
None
:
"""
Args:
scale: A scaling factor to apply to the logits.
"""
super
().
__init__
()
self
.
scale
=
scale
self
.
vocab_size
=
vocab_size
# Transformers-neuronx generate outputs as logits directly.
self
.
logits_as_hidden_states
=
is_neuron
()
# original vocabulary size (without LoRA).
self
.
org_vocab_size
=
org_vocab_size
or
vocab_size
def
forward
(
self
,
embedding
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
embedding_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
self
.
logits_as_hidden_states
:
logits
=
hidden_states
else
:
hidden_states
=
_prune_hidden_states
(
hidden_states
,
sampling_metadata
)
# Get the logits for the next tokens.
logits
=
self
.
_get_logits
(
hidden_states
,
embedding
,
embedding_bias
)
if
logits
is
not
None
:
logits
*=
self
.
scale
# Apply logits processors (if any).
logits
=
_apply_logits_processors
(
logits
,
sampling_metadata
)
return
logits
def
_get_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
embedding
:
torch
.
Tensor
,
embedding_bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
# Get the logits for the next tokens.
logits
=
torch
.
matmul
(
hidden_states
,
embedding
.
t
())
if
embedding_bias
is
not
None
:
logits
+=
embedding_bias
logits
=
tensor_model_parallel_gather
(
logits
)
# Remove paddings in vocab (if any).
if
logits
is
not
None
:
logits
=
logits
[:,
:
self
.
org_vocab_size
]
return
logits
def
_prune_hidden_states
(
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
hidden_states
=
hidden_states
.
view
(
-
1
,
hidden_states
.
shape
[
-
1
])
return
hidden_states
.
index_select
(
0
,
sampling_metadata
.
selected_token_indices
)
def
_apply_logits_processors
(
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
logits_row_idx
=
0
found_logits_processors
=
False
for
seq_ids
,
sampling_params
in
sampling_metadata
.
seq_groups
:
logits_processors
=
sampling_params
.
logits_processors
if
logits_processors
:
found_logits_processors
=
True
for
seq_id
in
seq_ids
:
logits_row
=
logits
[
logits_row_idx
]
token_ids
=
sampling_metadata
.
seq_data
[
seq_id
].
output_token_ids
for
logits_processor
in
logits_processors
:
logits_row
=
logits_processor
(
token_ids
,
logits_row
)
logits
[
logits_row_idx
]
=
logits_row
logits_row_idx
+=
1
else
:
logits_row_idx
+=
len
(
seq_ids
)
if
found_logits_processors
:
assert
logits_row_idx
==
logits
.
shape
[
0
]
return
logits
vllm/model_executor/layers/sampler.py
View file @
f1c0fc39
...
@@ -4,8 +4,6 @@ from typing import Dict, List, Optional, Tuple
...
@@ -4,8 +4,6 @@ from typing import Dict, List, Optional, Tuple
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.model_executor.parallel_utils.communication_op
import
(
tensor_model_parallel_gather
)
from
vllm.model_executor.sampling_metadata
import
(
SamplingMetadata
,
from
vllm.model_executor.sampling_metadata
import
(
SamplingMetadata
,
SamplingTensors
)
SamplingTensors
)
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
...
@@ -13,7 +11,6 @@ from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
...
@@ -13,7 +11,6 @@ from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
SamplerOutput
,
SequenceData
,
SequenceGroupOutput
,
SamplerOutput
,
SequenceData
,
SequenceGroupOutput
,
SequenceOutput
)
SequenceOutput
)
from
vllm.model_executor.layers.ops.sample
import
(
sample
as
sample_triton
)
from
vllm.model_executor.layers.ops.sample
import
(
sample
as
sample_triton
)
from
vllm.utils
import
is_neuron
class
Sampler
(
nn
.
Module
):
class
Sampler
(
nn
.
Module
):
...
@@ -31,58 +28,14 @@ class Sampler(nn.Module):
...
@@ -31,58 +28,14 @@ class Sampler(nn.Module):
parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
"""
"""
def
__init__
(
self
,
vocab_size
:
int
,
org_vocab_size
:
Optional
[
int
]
=
None
)
->
None
:
super
().
__init__
()
self
.
vocab_size
=
vocab_size
# Transformers-neuronx generate outputs as logits directly.
self
.
logits_as_hidden_states
=
is_neuron
()
# original vocabulary size (without LoRA).
self
.
org_vocab_size
=
org_vocab_size
or
vocab_size
def
_get_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
embedding
:
torch
.
Tensor
,
embedding_bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
# Get the logits for the next tokens.
logits
=
torch
.
matmul
(
hidden_states
,
embedding
.
t
())
if
embedding_bias
is
not
None
:
logits
+=
embedding_bias
logits
=
tensor_model_parallel_gather
(
logits
)
# Remove paddings in vocab (if any).
if
logits
is
not
None
:
logits
=
logits
[:,
:
self
.
org_vocab_size
]
return
logits
def
forward
(
def
forward
(
self
,
self
,
embedding
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
embedding_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
# Get the hidden states that we use for sampling.
if
self
.
logits_as_hidden_states
:
logits
=
hidden_states
else
:
hidden_states
=
_prune_hidden_states
(
hidden_states
,
sampling_metadata
)
# Get the logits for the next tokens.
logits
=
self
.
_get_logits
(
hidden_states
,
embedding
,
embedding_bias
)
# Only perform sampling in the driver worker.
# Note: `_get_logits` is still distributed across TP workers because
# the `embedding` weight is distributed across TP workers.
# TODO(zhuohan): Change the get_logits part to a separate stage.
if
not
sampling_metadata
.
perform_sampling
:
return
None
assert
logits
is
not
None
assert
logits
is
not
None
_
,
vocab_size
=
logits
.
shape
_
,
vocab_size
=
logits
.
shape
# Apply logits processors (if any).
logits
=
_apply_logits_processors
(
logits
,
sampling_metadata
)
# Prepare sampling tensors with pinned memory to avoid blocking.
# Prepare sampling tensors with pinned memory to avoid blocking.
(
sampling_tensors
,
do_penalties
,
do_top_p_top_k
,
(
sampling_tensors
,
do_penalties
,
do_top_p_top_k
,
do_min_p
)
=
SamplingTensors
.
from_sampling_metadata
(
do_min_p
)
=
SamplingTensors
.
from_sampling_metadata
(
...
@@ -124,14 +77,6 @@ class Sampler(nn.Module):
...
@@ -124,14 +77,6 @@ class Sampler(nn.Module):
prompt_logprobs
,
sample_logprobs
)
prompt_logprobs
,
sample_logprobs
)
def
_prune_hidden_states
(
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
return
hidden_states
.
index_select
(
0
,
sampling_metadata
.
selected_token_indices
)
def
_get_bin_counts_and_mask
(
def
_get_bin_counts_and_mask
(
tokens
:
torch
.
Tensor
,
tokens
:
torch
.
Tensor
,
vocab_size
:
int
,
vocab_size
:
int
,
...
@@ -149,30 +94,6 @@ def _get_bin_counts_and_mask(
...
@@ -149,30 +94,6 @@ def _get_bin_counts_and_mask(
return
bin_counts
,
mask
return
bin_counts
,
mask
def
_apply_logits_processors
(
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
logits_row_idx
=
0
found_logits_processors
=
False
for
seq_ids
,
sampling_params
in
sampling_metadata
.
seq_groups
:
logits_processors
=
sampling_params
.
logits_processors
if
logits_processors
:
found_logits_processors
=
True
for
seq_id
in
seq_ids
:
logits_row
=
logits
[
logits_row_idx
]
token_ids
=
sampling_metadata
.
seq_data
[
seq_id
].
output_token_ids
for
logits_processor
in
logits_processors
:
logits_row
=
logits_processor
(
token_ids
,
logits_row
)
logits
[
logits_row_idx
]
=
logits_row
logits_row_idx
+=
1
else
:
logits_row_idx
+=
len
(
seq_ids
)
if
found_logits_processors
:
assert
logits_row_idx
==
logits
.
shape
[
0
]
return
logits
def
_apply_penalties
(
logits
:
torch
.
Tensor
,
prompt_tokens_tensor
:
torch
.
Tensor
,
def
_apply_penalties
(
logits
:
torch
.
Tensor
,
prompt_tokens_tensor
:
torch
.
Tensor
,
output_tokens_tensor
:
torch
.
Tensor
,
output_tokens_tensor
:
torch
.
Tensor
,
presence_penalties
:
torch
.
Tensor
,
presence_penalties
:
torch
.
Tensor
,
...
...
vllm/model_executor/models/baichuan.py
View file @
f1c0fc39
...
@@ -34,6 +34,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
...
@@ -34,6 +34,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
VocabParallelEmbedding
,
ParallelLMHead
)
...
@@ -295,7 +296,8 @@ class BaiChuanBaseForCausalLM(nn.Module):
...
@@ -295,7 +296,8 @@ class BaiChuanBaseForCausalLM(nn.Module):
self
.
linear_method
=
linear_method
self
.
linear_method
=
linear_method
self
.
model
=
BaiChuanModel
(
config
,
position_embedding
,
linear_method
)
self
.
model
=
BaiChuanModel
(
config
,
position_embedding
,
linear_method
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -308,13 +310,18 @@ class BaiChuanBaseForCausalLM(nn.Module):
...
@@ -308,13 +310,18 @@ class BaiChuanBaseForCausalLM(nn.Module):
input_metadata
)
input_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head
.
weight
,
hidden_states
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/bloom.py
View file @
f1c0fc39
...
@@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase
,
LinearMethodBase
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
...
@@ -273,7 +274,8 @@ class BloomForCausalLM(nn.Module):
...
@@ -273,7 +274,8 @@ class BloomForCausalLM(nn.Module):
self
.
linear_method
=
linear_method
self
.
linear_method
=
linear_method
self
.
transformer
=
BloomModel
(
config
,
linear_method
)
self
.
transformer
=
BloomModel
(
config
,
linear_method
)
self
.
lm_head_weight
=
self
.
transformer
.
word_embeddings
.
weight
self
.
lm_head_weight
=
self
.
transformer
.
word_embeddings
.
weight
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -286,13 +288,18 @@ class BloomForCausalLM(nn.Module):
...
@@ -286,13 +288,18 @@ class BloomForCausalLM(nn.Module):
input_metadata
)
input_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head_weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head_weight
,
hidden_states
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/chatglm.py
View file @
f1c0fc39
...
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
...
@@ -17,6 +17,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
VocabParallelEmbedding
,
ParallelLMHead
)
...
@@ -332,7 +333,8 @@ class ChatGLMForCausalLM(nn.Module):
...
@@ -332,7 +333,8 @@ class ChatGLMForCausalLM(nn.Module):
self
.
linear_method
=
linear_method
self
.
linear_method
=
linear_method
self
.
transformer
=
ChatGLMModel
(
config
,
linear_method
)
self
.
transformer
=
ChatGLMModel
(
config
,
linear_method
)
self
.
lm_head_weight
=
self
.
transformer
.
output_layer
.
weight
self
.
lm_head_weight
=
self
.
transformer
.
output_layer
.
weight
self
.
sampler
=
Sampler
(
config
.
padded_vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
padded_vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -345,13 +347,18 @@ class ChatGLMForCausalLM(nn.Module):
...
@@ -345,13 +347,18 @@ class ChatGLMForCausalLM(nn.Module):
input_metadata
)
input_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head_weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head_weight
,
hidden_states
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/deepseek.py
View file @
f1c0fc39
...
@@ -38,6 +38,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
...
@@ -38,6 +38,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
VocabParallelEmbedding
,
ParallelLMHead
)
...
@@ -372,7 +373,8 @@ class DeepseekForCausalLM(nn.Module):
...
@@ -372,7 +373,8 @@ class DeepseekForCausalLM(nn.Module):
self
.
linear_method
=
linear_method
self
.
linear_method
=
linear_method
self
.
model
=
DeepseekModel
(
config
,
linear_method
)
self
.
model
=
DeepseekModel
(
config
,
linear_method
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
lm_head
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -385,13 +387,18 @@ class DeepseekForCausalLM(nn.Module):
...
@@ -385,13 +387,18 @@ class DeepseekForCausalLM(nn.Module):
input_metadata
)
input_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
Optional
[
torch
.
Tensor
],
logit
s
:
Optional
[
torch
.
Tensor
],
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head
.
weight
,
hidden_states
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/falcon.py
View file @
f1c0fc39
...
@@ -34,6 +34,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -34,6 +34,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
VocabParallelEmbedding
,
ParallelLMHead
)
...
@@ -373,7 +374,8 @@ class FalconForCausalLM(nn.Module):
...
@@ -373,7 +374,8 @@ class FalconForCausalLM(nn.Module):
config
.
vocab_size
,
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
)
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -390,13 +392,18 @@ class FalconForCausalLM(nn.Module):
...
@@ -390,13 +392,18 @@ class FalconForCausalLM(nn.Module):
)
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head
.
weight
,
hidden_states
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/gemma.py
View file @
f1c0fc39
...
@@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
...
@@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
...
@@ -281,7 +282,8 @@ class GemmaForCausalLM(nn.Module):
...
@@ -281,7 +282,8 @@ class GemmaForCausalLM(nn.Module):
self
.
config
=
config
self
.
config
=
config
self
.
linear_method
=
linear_method
self
.
linear_method
=
linear_method
self
.
model
=
GemmaModel
(
config
,
linear_method
)
self
.
model
=
GemmaModel
(
config
,
linear_method
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
@
torch
.
no_grad
()
@
torch
.
no_grad
()
def
forward
(
def
forward
(
...
@@ -295,13 +297,18 @@ class GemmaForCausalLM(nn.Module):
...
@@ -295,13 +297,18 @@ class GemmaForCausalLM(nn.Module):
input_metadata
)
input_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
model
.
embed_tokens
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
model
.
embed_tokens
.
weight
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
hidden_states
,
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/gpt2.py
View file @
f1c0fc39
...
@@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase
,
LinearMethodBase
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
...
@@ -216,7 +217,8 @@ class GPT2LMHeadModel(nn.Module):
...
@@ -216,7 +217,8 @@ class GPT2LMHeadModel(nn.Module):
self
.
linear_method
=
linear_method
self
.
linear_method
=
linear_method
self
.
transformer
=
GPT2Model
(
config
,
linear_method
)
self
.
transformer
=
GPT2Model
(
config
,
linear_method
)
self
.
lm_head_weight
=
self
.
transformer
.
wte
.
weight
self
.
lm_head_weight
=
self
.
transformer
.
wte
.
weight
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -229,12 +231,18 @@ class GPT2LMHeadModel(nn.Module):
...
@@ -229,12 +231,18 @@ class GPT2LMHeadModel(nn.Module):
input_metadata
)
input_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head_weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head_weight
,
hidden_state
s
,
next_tokens
=
self
.
sampler
(
self
.
lm_head_weight
,
logit
s
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
...
...
vllm/model_executor/models/gpt_bigcode.py
View file @
f1c0fc39
...
@@ -31,6 +31,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -31,6 +31,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearMethodBase
,
LinearMethodBase
,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
...
@@ -237,7 +238,8 @@ class GPTBigCodeForCausalLM(nn.Module):
...
@@ -237,7 +238,8 @@ class GPTBigCodeForCausalLM(nn.Module):
self
.
linear_method
=
linear_method
self
.
linear_method
=
linear_method
self
.
transformer
=
GPTBigCodeModel
(
config
,
linear_method
)
self
.
transformer
=
GPTBigCodeModel
(
config
,
linear_method
)
self
.
lm_head_weight
=
self
.
transformer
.
wte
.
weight
self
.
lm_head_weight
=
self
.
transformer
.
wte
.
weight
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -250,13 +252,18 @@ class GPTBigCodeForCausalLM(nn.Module):
...
@@ -250,13 +252,18 @@ class GPTBigCodeForCausalLM(nn.Module):
input_metadata
)
input_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head_weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head_weight
,
hidden_states
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/gpt_j.py
View file @
f1c0fc39
...
@@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
VocabParallelEmbedding
,
ParallelLMHead
)
...
@@ -224,7 +225,8 @@ class GPTJForCausalLM(nn.Module):
...
@@ -224,7 +225,8 @@ class GPTJForCausalLM(nn.Module):
config
.
n_embd
,
config
.
n_embd
,
bias
=
True
,
bias
=
True
,
)
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -237,13 +239,18 @@ class GPTJForCausalLM(nn.Module):
...
@@ -237,13 +239,18 @@ class GPTJForCausalLM(nn.Module):
input_metadata
)
input_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
lm_head
.
weight
,
hidden_states
,
sampling_metadata
,
self
.
lm_head
.
bias
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
lm_head
.
weight
,
hidden_states
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
,
self
.
lm_head
.
bias
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/gpt_neox.py
View file @
f1c0fc39
...
@@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -30,6 +30,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
VocabParallelEmbedding
,
ParallelLMHead
)
...
@@ -238,7 +239,8 @@ class GPTNeoXForCausalLM(nn.Module):
...
@@ -238,7 +239,8 @@ class GPTNeoXForCausalLM(nn.Module):
config
.
vocab_size
,
config
.
vocab_size
,
config
.
hidden_size
,
config
.
hidden_size
,
)
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -251,13 +253,18 @@ class GPTNeoXForCausalLM(nn.Module):
...
@@ -251,13 +253,18 @@ class GPTNeoXForCausalLM(nn.Module):
input_metadata
)
input_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
embed_out
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
embed_out
.
weight
,
hidden_states
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
vllm/model_executor/models/internlm2.py
View file @
f1c0fc39
...
@@ -14,6 +14,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
...
@@ -14,6 +14,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase,
QKVParallelLinear
,
QKVParallelLinear
,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.rotary_embedding
import
get_rope
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
,
ParallelLMHead
)
VocabParallelEmbedding
,
ParallelLMHead
)
...
@@ -250,7 +251,8 @@ class InternLM2ForCausalLM(nn.Module):
...
@@ -250,7 +251,8 @@ class InternLM2ForCausalLM(nn.Module):
self
.
linear_method
=
linear_method
self
.
linear_method
=
linear_method
self
.
model
=
InternLM2Model
(
config
,
linear_method
)
self
.
model
=
InternLM2Model
(
config
,
linear_method
)
self
.
output
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
output
=
ParallelLMHead
(
config
.
vocab_size
,
config
.
hidden_size
)
self
.
sampler
=
Sampler
(
config
.
vocab_size
)
self
.
logits_processor
=
LogitsProcessor
(
config
.
vocab_size
)
self
.
sampler
=
Sampler
()
def
forward
(
def
forward
(
self
,
self
,
...
@@ -263,13 +265,18 @@ class InternLM2ForCausalLM(nn.Module):
...
@@ -263,13 +265,18 @@ class InternLM2ForCausalLM(nn.Module):
input_metadata
)
input_metadata
)
return
hidden_states
return
hidden_states
def
compute_logits
(
self
,
hidden_states
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
)
->
torch
.
Tensor
:
logits
=
self
.
logits_processor
(
self
.
output
.
weight
,
hidden_states
,
sampling_metadata
)
return
logits
def
sample
(
def
sample
(
self
,
self
,
hidden_state
s
:
torch
.
Tensor
,
logit
s
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
Optional
[
SamplerOutput
]:
)
->
Optional
[
SamplerOutput
]:
next_tokens
=
self
.
sampler
(
self
.
output
.
weight
,
hidden_states
,
next_tokens
=
self
.
sampler
(
logits
,
sampling_metadata
)
sampling_metadata
)
return
next_tokens
return
next_tokens
def
load_weights
(
self
,
def
load_weights
(
self
,
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment