norm / vllm / Commits

Commit 27feead2 (Unverified)
Authored Nov 29, 2023 by Woosuk Kwon; committed by GitHub on Nov 29, 2023.

Refactor Worker & InputMetadata (#1843)

Parent: c7821956

Changes: 27 files in the commit; this page shows 20 changed files with 221 additions and 163 deletions (+221, -163).
  vllm/config.py                               +6   -0
  vllm/engine/arg_utils.py                     +4   -3
  vllm/engine/llm_engine.py                    +0   -2
  vllm/model_executor/__init__.py              +2   -0
  vllm/model_executor/input_metadata.py        +16  -65
  vllm/model_executor/layers/attention.py      +3   -11
  vllm/model_executor/layers/sampler.py        +60  -55
  vllm/model_executor/models/aquila.py         +10  -2
  vllm/model_executor/models/baichuan.py       +10  -2
  vllm/model_executor/models/bloom.py          +10  -2
  vllm/model_executor/models/chatglm.py        +10  -2
  vllm/model_executor/models/falcon.py         +10  -3
  vllm/model_executor/models/gpt2.py           +10  -2
  vllm/model_executor/models/gpt_bigcode.py    +10  -2
  vllm/model_executor/models/gpt_j.py          +10  -2
  vllm/model_executor/models/gpt_neox.py       +10  -2
  vllm/model_executor/models/internlm.py       +10  -2
  vllm/model_executor/models/llama.py          +10  -2
  vllm/model_executor/models/mistral.py        +10  -2
  vllm/model_executor/models/mpt.py            +10  -2
vllm/config.py  (+6, -0)

@@ -161,6 +161,12 @@ class ModelConfig:
                 "must be divisible by pipeline parallel size "
                 f"({pipeline_parallel_size}).")
 
+    def get_sliding_window(self) -> Optional[int]:
+        return getattr(self.hf_config, "sliding_window", None)
+
+    def get_vocab_size(self) -> int:
+        return self.hf_config.vocab_size
+
     def get_hidden_size(self) -> int:
         return self.hf_config.hidden_size
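The new accessors matter in the next file, where arg_utils.py stops reaching into the HF config directly. A self-contained sketch of the before/after read (the SimpleNamespace stand-in is illustrative, not vLLM's real hf_config):

from types import SimpleNamespace
from typing import Optional

# Stand-in for a HF config that lacks a sliding_window attribute
# (hypothetical object; the real code reads ModelConfig.hf_config).
hf_config = SimpleNamespace(vocab_size=32000, hidden_size=4096)

# Before this commit, call sites used getattr on the HF config directly:
sliding_window: Optional[int] = getattr(hf_config, "sliding_window", None)

# After it, the same read is wrapped by ModelConfig.get_sliding_window()
# (see the diff above), so every caller gets the identical fallback.
assert sliding_window is None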
vllm/engine/arg_utils.py  (+4, -3)

@@ -201,9 +201,10 @@ class EngineArgs:
             self.dtype, self.seed, self.revision,
             self.tokenizer_revision, self.max_model_len,
             self.quantization)
-        cache_config = CacheConfig(
-            self.block_size, self.gpu_memory_utilization, self.swap_space,
-            getattr(model_config.hf_config, 'sliding_window', None))
+        cache_config = CacheConfig(self.block_size,
+                                   self.gpu_memory_utilization,
+                                   self.swap_space,
+                                   model_config.get_sliding_window())
         parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                          self.tensor_parallel_size,
                                          self.worker_use_ray,
vllm/engine/llm_engine.py  (+0, -2)

@@ -88,8 +88,6 @@ class LLMEngine:
         self.model_config = model_config
         self.cache_config = cache_config
-        assert self.cache_config.sliding_window == getattr(
-            self.model_config.hf_config, "sliding_window", None)
         self.parallel_config = parallel_config
         self.scheduler_config = scheduler_config
         self.log_stats = log_stats
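The deleted assertion is redundant after the arg_utils.py change above: CacheConfig now receives its sliding window straight from model_config.get_sliding_window(), so the equality holds by construction. A minimal sketch of that reasoning, using simplified stand-in classes rather than the real vLLM configs:

from typing import Optional

class ModelConfigStub:
    # Simplified stand-in for vllm.config.ModelConfig.
    def __init__(self, sliding_window: Optional[int]) -> None:
        self._sliding_window = sliding_window

    def get_sliding_window(self) -> Optional[int]:
        return self._sliding_window

class CacheConfigStub:
    # Simplified stand-in for vllm.config.CacheConfig.
    def __init__(self, sliding_window: Optional[int]) -> None:
        self.sliding_window = sliding_window

model_config = ModelConfigStub(sliding_window=4096)
# As in the new arg_utils.py, the cache config is fed from the accessor,
# so the invariant the old assert restated cannot be violated.
cache_config = CacheConfigStub(model_config.get_sliding_window())
assert cache_config.sliding_window == model_config.get_sliding_window()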
vllm/model_executor/__init__.py  (+2, -0)

 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.utils import set_random_seed
 
 __all__ = [
     "InputMetadata",
     "get_model",
+    "SamplingMetadata",
     "set_random_seed",
 ]
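With the re-export in place, both metadata classes are importable from the package root (assuming a checkout at this commit):

from vllm.model_executor import InputMetadata, SamplingMetadata

print(InputMetadata.__name__, SamplingMetadata.__name__)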
vllm/model_executor/input_metadata.py  (+16, -65)

-from typing import Dict, List, Optional, Tuple
+from typing import List, Optional
 
 import torch
-from xformers.ops import AttentionBias
-
-from vllm.sampling_params import SamplingParams, SamplingType
-from vllm.sequence import SequenceData
 
 
 class InputMetadata:
-    """Metadata for input sequences. Used for PagedAttention.
+    """Metadata for input sequences. Used in PagedAttention.
 
     Args:
-        seq_groups: List of (seq_ids, sampling_params).
-        seq_data: Seq_id -> SequenceData.
         prompt_lens: Lengths of prompts.
         slot_mapping: The address to write the new KV to of each token.
-        context_lens: the length of attention context for each generation token.
         max_context_len: The maximum context length.
+        context_lens: the length of attention context for each sequence.
         block_tables: The block tables. (Seq id -> list of physical block)
     """
 
     def __init__(
         self,
-        seq_groups: List[Tuple[List[int], SamplingParams]],
-        seq_data: Dict[int, SequenceData],
         prompt_lens: List[int],
         slot_mapping: torch.Tensor,
-        context_lens: torch.Tensor,
-        max_context_len: int,
-        block_tables: torch.Tensor,
-        selected_token_indices: torch.Tensor,
-        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
-        sliding_window: Optional[int] = None,
+        max_context_len: Optional[int],
+        context_lens: Optional[torch.Tensor],
+        block_tables: Optional[torch.Tensor],
     ) -> None:
-        self.seq_groups = seq_groups
-        self.seq_data = seq_data
         self.prompt_lens = prompt_lens
+        self.max_context_len = max_context_len
         self.slot_mapping = slot_mapping
         self.context_lens = context_lens
-        self.max_context_len = max_context_len
         self.block_tables = block_tables
-        self.selected_token_indices = selected_token_indices
-        self.categorized_sample_indices = categorized_sample_indices
-
-        self.max_prompt_len = max(prompt_lens) if prompt_lens else 0
-        self.to_cache = None
-        if sliding_window is not None:
-            # We need to keep the positions of sliding windows within
-            # the key / value tables, this is helpful to know which
-            # elements we need to cache.
-            to_cache, start_idx = [], 0
-            for prompt_len in self.prompt_lens:
-                to_cache.extend(
-                    range(
-                        start_idx + max(0, prompt_len - sliding_window),
-                        start_idx + prompt_len,
-                    ))
-                start_idx += self.max_prompt_len
-            to_cache.extend(range(start_idx, slot_mapping.shape[0]))
-            self.to_cache = torch.tensor(to_cache,
-                                         dtype=torch.int32,
-                                         device=self.slot_mapping.device)
-
-        self.num_prompts = len(prompt_lens)
-        self.num_prompt_tokens = self.num_prompts * self.max_prompt_len
-        self.num_generation_tokens = context_lens.shape[0]
-        if block_tables.numel() > 0:
-            self.max_num_blocks_per_seq = block_tables.shape[1]
-        else:
-            self.max_num_blocks_per_seq = 0
-        assert block_tables.shape[0] == self.num_generation_tokens
+        self.is_prompt = len(prompt_lens) > 0
 
         # Set during the execution of the first attention op.
-        self.attn_bias: Optional[AttentionBias] = None
+        # FIXME(woosuk): This is a hack.
+        self.attn_bias = None
 
     def __repr__(self) -> str:
-        # Print only useful metadata.
-        return (
-            f'InputMetadata('
-            f'num_prompt_tokens={self.num_prompt_tokens}, '
-            f'num_prompts={self.num_prompts}, '
-            f'prompt_lens={self.prompt_lens}, '
-            f'num_generation_tokens={self.num_generation_tokens}, '
-            f'context_lens={self.context_lens}, '
-            f'max_context_len={self.max_context_len}), '
-            f'max_num_blocks_per_seq={self.max_num_blocks_per_seq}, '
-            f'block_tables={self.block_tables}, '
-            f'selected_token_indices={self.selected_token_indices}, '
-            f'categorized_sample_indices={self.categorized_sample_indices}, '
-            f'slot_mapping={self.slot_mapping})')
+        return ("InputMetadata("
+                f"prompt_lens={self.prompt_lens}, "
+                f"max_context_len={self.max_context_len}, "
+                f"slot_mapping={self.slot_mapping}, "
+                f"context_lens={self.context_lens}, "
+                f"block_tables={self.block_tables})")
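A short sketch of constructing the slimmed-down class for a prompt (prefill) step and a decode step; the tensor shapes, dtypes, and values here are illustrative, not taken from this diff:

import torch

from vllm.model_executor.input_metadata import InputMetadata

# Prefill step: context/block fields may be absent, which is why the new
# signature makes them Optional.
prefill_meta = InputMetadata(
    prompt_lens=[5, 3],            # two prompts in the batch
    slot_mapping=torch.arange(8),  # one cache slot per prompt token
    max_context_len=None,
    context_lens=None,
    block_tables=None,
)
assert prefill_meta.is_prompt  # derived from len(prompt_lens) > 0

# Decode step: no prompt lengths, but per-sequence context lengths and
# block tables are present.
decode_meta = InputMetadata(
    prompt_lens=[],
    slot_mapping=torch.tensor([40, 41]),
    max_context_len=6,
    context_lens=torch.tensor([6, 5], dtype=torch.int32),
    block_tables=torch.zeros((2, 1), dtype=torch.int32),
)
assert not decode_meta.is_prompt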
vllm/model_executor/layers/attention.py  (+3, -11)

@@ -101,23 +101,15 @@ class PagedAttention(nn.Module):
         # vectors will not be cached. This happens during the initial memory
         # profiling run.
         if key_cache is not None and value_cache is not None:
-            key_to_cache = key
-            value_to_cache = value
-            if input_metadata.to_cache is not None:
-                key_to_cache = key_to_cache[input_metadata.to_cache]
-                value_to_cache = value_to_cache[input_metadata.to_cache]
-                slot_mapping = slot_mapping[input_metadata.to_cache]
             cache_ops.reshape_and_cache(
-                key_to_cache,
-                value_to_cache,
+                key,
+                value,
                 key_cache,
                 value_cache,
                 slot_mapping,
             )
 
-        is_prompt = len(input_metadata.prompt_lens) > 0
-        if is_prompt:
+        if input_metadata.is_prompt:
             # Prompt run.
             if self.num_kv_heads != self.num_heads:
                 # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
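Two effects are visible here: the prompt/decode test is read from input_metadata.is_prompt instead of being recomputed, and the sliding-window to_cache gather disappears from the hot path (the Worker changes that absorb it are among the files not shown on this page). A toy reconstruction of what the deleted gather used to do, with made-up sizes:

import torch

# With a sliding window of 2 and a single prompt of length 4, the old
# InputMetadata built to_cache so that only the last 2 prompt positions
# were written into the KV cache.
sliding_window = 2
prompt_len = 4
to_cache = torch.arange(max(0, prompt_len - sliding_window), prompt_len)

key = torch.randn(prompt_len, 8)  # [num_tokens, head_size], toy sizes
key_to_cache = key[to_cache]      # the gather removed from attention.py
assert key_to_cache.shape[0] == sliding_window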
vllm/model_executor/layers/sampler.py  (+60, -55)

@@ -4,9 +4,9 @@ from typing import Dict, List, Optional, Tuple
 import torch
 import torch.nn as nn
 
-from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.parallel_utils.communication_op import (
     tensor_model_parallel_all_gather)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
                            SequenceData, SequenceGroupOutput, SequenceOutput)

@@ -37,29 +37,30 @@ class Sampler(nn.Module):
         self,
         embedding: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        sampling_metadata: SamplingMetadata,
         embedding_bias: Optional[torch.Tensor] = None,
     ) -> SamplerOutput:
         # Get the hidden states that we use for sampling.
-        hidden_states = _prune_hidden_states(hidden_states, input_metadata)
+        hidden_states = _prune_hidden_states(hidden_states, sampling_metadata)
 
         # Get the logits for the next tokens.
         logits = _get_logits(hidden_states, embedding, embedding_bias,
                              self.vocab_size)
 
         # Apply logits processors (if any).
-        logits = _apply_logits_processors(logits, input_metadata)
+        logits = _apply_logits_processors(logits, sampling_metadata)
         # Apply presence and frequency penalties.
         presence_penalties, frequency_penalties, repetition_penalties = (
-            _get_penalties(input_metadata))
+            _get_penalties(sampling_metadata))
         assert len(presence_penalties) == logits.shape[0]
         assert len(frequency_penalties) == logits.shape[0]
         assert len(repetition_penalties) == logits.shape[0]
-        logits = _apply_penalties(logits, input_metadata, presence_penalties,
-                                  frequency_penalties, repetition_penalties)
+        logits = _apply_penalties(logits, sampling_metadata,
+                                  presence_penalties, frequency_penalties,
+                                  repetition_penalties)
 
         # Apply temperature scaling.
-        temperatures = _get_temperatures(input_metadata)
+        temperatures = _get_temperatures(sampling_metadata)
         assert len(temperatures) == logits.shape[0]
         if any(t != 1.0 for t in temperatures):
             t = torch.tensor(temperatures,

@@ -70,7 +71,7 @@ class Sampler(nn.Module):
         # Apply top-p and top-k truncation.
         top_ps, top_ks, min_ps = _get_top_p_top_k_min_p(
-            input_metadata, self.vocab_size)
+            sampling_metadata, self.vocab_size)
         assert len(top_ps) == len(top_ks) == logits.shape[0]
         do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps)
         do_top_k = any(k != self.vocab_size for k in top_ks)

@@ -89,11 +90,11 @@ class Sampler(nn.Module):
         logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
 
         # Sample the next tokens.
-        sample_results = _sample(probs, logprobs, input_metadata)
+        sample_results = _sample(probs, logprobs, sampling_metadata)
         # Get the logprobs query results.
         prompt_logprobs, sample_logprobs = _get_logprobs(
-            logprobs, input_metadata, sample_results)
-        return _build_sampler_output(sample_results, input_metadata,
+            logprobs, sampling_metadata, sample_results)
+        return _build_sampler_output(sample_results, sampling_metadata,
                                      prompt_logprobs, sample_logprobs)

@@ -112,29 +113,30 @@ def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
 def _prune_hidden_states(
     hidden_states: torch.Tensor,
-    input_metadata: InputMetadata,
+    sampling_metadata: SamplingMetadata,
 ) -> torch.Tensor:
     hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-    return hidden_states.index_select(0, input_metadata.selected_token_indices)
+    return hidden_states.index_select(0,
+                                      sampling_metadata.selected_token_indices)
 
 
 def _get_penalties(
-    input_metadata: InputMetadata
+    sampling_metadata: SamplingMetadata
 ) -> Tuple[List[float], List[float], List[float]]:
     # Collect the presence and frequency penalties.
     presence_penalties: List[float] = []
     frequency_penalties: List[float] = []
     repetition_penalties: List[float] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for i, seq_group in enumerate(sampling_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         p = sampling_params.presence_penalty
         f = sampling_params.frequency_penalty
         r = sampling_params.repetition_penalty
-        if (i < input_metadata.num_prompts
+        if (i < sampling_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
             # NOTE: We do not apply presence and frequency penalties for the
             # prompt token positions where we don't sample new tokens.
-            prompt_len = input_metadata.prompt_lens[i]
+            prompt_len = sampling_metadata.prompt_lens[i]
             presence_penalties += [0] * (prompt_len - 1)
             frequency_penalties += [0] * (prompt_len - 1)
             repetition_penalties += [1] * (prompt_len - 1)

@@ -145,21 +147,21 @@ def _get_penalties(
 def _get_prompt_and_output_tokens(
-    input_metadata: InputMetadata
+    sampling_metadata: SamplingMetadata,
 ) -> Tuple[List[List[int]], List[List[int]]]:
     prompt_tokens: List[List[int]] = []
     output_tokens: List[List[int]] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for i, seq_group in enumerate(sampling_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
-        if (i < input_metadata.num_prompts
+        if (i < sampling_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
             # NOTE: prompt token positions do not need output tokens to
             # compute penalties.
-            prompt_len = input_metadata.prompt_lens[i]
+            prompt_len = sampling_metadata.prompt_lens[i]
             prompt_tokens.extend([] for _ in range(prompt_len - 1))
             output_tokens.extend([] for _ in range(prompt_len - 1))
         for seq_id in seq_ids:
-            seq_data = input_metadata.seq_data[seq_id]
+            seq_data = sampling_metadata.seq_data[seq_id]
             prompt_tokens.append(seq_data.prompt_token_ids)
             output_tokens.append(seq_data.output_token_ids)
     return prompt_tokens, output_tokens

@@ -191,17 +193,19 @@ def _get_bin_counts_and_mask(
     return bin_counts, mask
 
 
-def _apply_logits_processors(logits: torch.Tensor,
-                             input_metadata: InputMetadata) -> torch.Tensor:
+def _apply_logits_processors(
+    logits: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> torch.Tensor:
     logits_row_idx = 0
     found_logits_processors = False
-    for seq_ids, sampling_params in input_metadata.seq_groups:
+    for seq_ids, sampling_params in sampling_metadata.seq_groups:
         logits_processors = sampling_params.logits_processors
         if logits_processors:
             found_logits_processors = True
             for seq_id in seq_ids:
                 logits_row = logits[logits_row_idx]
-                token_ids = input_metadata.seq_data[seq_id].output_token_ids
+                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
                 for logits_processor in logits_processors:
                     logits_row = logits_processor(token_ids, logits_row)
                 logits[logits_row_idx] = logits_row

@@ -215,7 +219,7 @@ def _apply_logits_processors(logits: torch.Tensor,
 def _apply_penalties(
     logits: torch.Tensor,
-    input_metadata: InputMetadata,
+    sampling_metadata: SamplingMetadata,
     presence_penalties: List[float],
     frequency_penalties: List[float],
     repetition_penalties: List[float],

@@ -234,7 +238,7 @@ def _apply_penalties(
         return logits
 
     prompt_tokens, output_tokens = (
-        _get_prompt_and_output_tokens(input_metadata))
+        _get_prompt_and_output_tokens(sampling_metadata))
     assert len(prompt_tokens) == logits.shape[0]
     assert len(output_tokens) == logits.shape[0]

@@ -265,10 +269,10 @@ def _apply_penalties(
     return logits
 
 
-def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
+def _get_temperatures(sampling_metadata: SamplingMetadata) -> List[float]:
     # Collect the temperatures for the logits.
     temperatures: List[float] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for i, seq_group in enumerate(sampling_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         temperature = sampling_params.temperature
         if temperature < _SAMPLING_EPS:

@@ -276,22 +280,22 @@ def _get_temperatures(input_metadata: InputMetadata) -> List[float]:
             # (i.e., greedy sampling or beam search).
             # Set the temperature to 1 to avoid division by zero.
             temperature = 1.0
-        if (i < input_metadata.num_prompts
+        if (i < sampling_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
-            prompt_len = input_metadata.prompt_lens[i]
+            prompt_len = sampling_metadata.prompt_lens[i]
             temperatures += [temperature] * (prompt_len - 1)
         temperatures += [temperature] * len(seq_ids)
     return temperatures
 
 
 def _get_top_p_top_k_min_p(
-    input_metadata: InputMetadata,
+    sampling_metadata: SamplingMetadata,
     vocab_size: int,
 ) -> Tuple[List[float], List[int], List[float]]:
     top_ps: List[float] = []
     top_ks: List[int] = []
     min_ps: List[float] = []
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    for i, seq_group in enumerate(sampling_metadata.seq_groups):
         seq_ids, sampling_params = seq_group
         top_p = sampling_params.top_p
         min_p = sampling_params.min_p

@@ -299,9 +303,9 @@ def _get_top_p_top_k_min_p(
         top_k = min(sampling_params.top_k, vocab_size)
         # k=-1 means no truncation.
         top_k = vocab_size if top_k == -1 else top_k
-        if (i < input_metadata.num_prompts
+        if (i < sampling_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
-            prompt_len = input_metadata.prompt_lens[i]
+            prompt_len = sampling_metadata.prompt_lens[i]
             top_ps += [top_p] * (prompt_len - 1)
             top_ks += [top_k] * (prompt_len - 1)
             min_ps += [min_p] * (prompt_len - 1)

@@ -471,11 +475,11 @@ def _beam_search_sample(
 def _sample(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
-    input_metadata: InputMetadata,
+    sampling_metadata: SamplingMetadata,
 ) -> List[Tuple[List[int], List[int]]]:
     categorized_seq_group_ids = {t: [] for t in SamplingType}
-    categorized_sample_indices = input_metadata.categorized_sample_indices
-    for i, seq_group in enumerate(input_metadata.seq_groups):
+    categorized_sample_indices = sampling_metadata.categorized_sample_indices
+    for i, seq_group in enumerate(sampling_metadata.seq_groups):
         _, sampling_params = seq_group
         sampling_type = sampling_params.sampling_type
         categorized_seq_group_ids[sampling_type].append(i)

@@ -483,8 +487,8 @@ def _sample(
     sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
     for sampling_type in SamplingType:
         seq_group_ids = categorized_seq_group_ids[sampling_type]
-        seq_groups = [input_metadata.seq_groups[i] for i in seq_group_ids]
-        is_prompts = [i < input_metadata.num_prompts for i in seq_group_ids]
+        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
+        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
         sample_indices = categorized_sample_indices[sampling_type]
         num_tokens = len(sample_indices)
         if num_tokens == 0:

@@ -499,21 +503,22 @@ def _sample(
         elif sampling_type == SamplingType.BEAM:
             category_logprobs = logprobs[sample_indices]
             sample_results = _beam_search_sample(seq_groups, is_prompts,
-                                                 input_metadata.seq_data,
+                                                 sampling_metadata.seq_data,
                                                  category_logprobs)
         else:
             raise ValueError(f"Unsupported sampling type: {sampling_type}")
         sample_results_dict.update(zip(seq_group_ids, sample_results))
 
     sample_results = [
-        sample_results_dict[i] for i in range(len(input_metadata.seq_groups))
+        sample_results_dict[i]
+        for i in range(len(sampling_metadata.seq_groups))
     ]
     return sample_results
 
 
 def _get_logprobs(
     logprobs: torch.Tensor,
-    input_metadata: InputMetadata,
+    sampling_metadata: SamplingMetadata,
     sample_results: List[Tuple[List[int], List[int]]],
 ) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[
         int, float]]]]:

@@ -523,16 +528,16 @@ def _get_logprobs(
     largest_num_logprobs = 0
     sample_idx = 0
     for i, (seq_group, sample_result) in enumerate(
-            zip(input_metadata.seq_groups, sample_results)):
+            zip(sampling_metadata.seq_groups, sample_results)):
         seq_ids, sampling_params = seq_group
         next_token_ids, parent_ids = sample_result
         num_parent_seqs = len(seq_ids)
-        if (i < input_metadata.num_prompts
+        if (i < sampling_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
             largest_num_logprobs = max(largest_num_logprobs,
                                        sampling_params.prompt_logprobs)
-            prompt_len = input_metadata.prompt_lens[i]
-            prompt_tokens = input_metadata.seq_data[
+            prompt_len = sampling_metadata.prompt_lens[i]
+            prompt_tokens = sampling_metadata.seq_data[
                 seq_ids[0]].prompt_token_ids
             batched_logprobs_query_seq_indices.extend(
                 sample_idx + j for j in range(prompt_len - 1))

@@ -570,16 +575,16 @@ def _get_logprobs(
     sample_idx = 0
     query_result_idx = 0
     for i, (seq_group, sample_result) in enumerate(
-            zip(input_metadata.seq_groups, sample_results)):
+            zip(sampling_metadata.seq_groups, sample_results)):
         seq_ids, sampling_params = seq_group
         next_token_ids, parent_ids = sample_result
 
         # Prompt logprobs
-        if (i < input_metadata.num_prompts
+        if (i < sampling_metadata.num_prompts
                 and sampling_params.prompt_logprobs is not None):
             num_logprobs = sampling_params.prompt_logprobs
-            prompt_len = input_metadata.prompt_lens[i]
-            prompt_tokens = input_metadata.seq_data[
+            prompt_len = sampling_metadata.prompt_lens[i]
+            prompt_tokens = sampling_metadata.seq_data[
                 seq_ids[0]].prompt_token_ids
             group_prompt_logprobs: PromptLogprobs = [None]
             for token_id in prompt_tokens[1:]:

@@ -625,13 +630,13 @@ def _get_logprobs(
 def _build_sampler_output(
     sample_results: List[Tuple[List[int], List[int]]],
-    input_metadata: InputMetadata,
+    sampling_metadata: SamplingMetadata,
     prompt_logprobs: List[Optional[PromptLogprobs]],
     sample_logprobs: List[SampleLogprobs],
 ) -> SamplerOutput:
     sampler_output = []
     for (seq_group, sample_result, group_prompt_logprobs,
-         group_sample_logprobs) in zip(input_metadata.seq_groups,
+         group_sample_logprobs) in zip(sampling_metadata.seq_groups,
                                        sample_results, prompt_logprobs,
                                        sample_logprobs):
         seq_ids, _ = seq_group
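sampling_metadata.py itself is one of the seven changed files not shown on this page, so its definition is not visible here. From the sampler's reads above, though, it must carry at least the sampling-oriented fields that InputMetadata lost. A hedged stand-in reconstructed purely from that usage (the field set and the num_prompts derivation are inferences, not the real class):

from typing import Dict, List, Tuple

import torch

class SamplingMetadataSketch:
    """Inferred stand-in for vllm.model_executor.sampling_metadata.
    The fields are exactly those the sampler dereferences in this diff."""

    def __init__(
        self,
        seq_groups: List[Tuple[List[int], "SamplingParams"]],
        seq_data: Dict[int, "SequenceData"],
        prompt_lens: List[int],
        selected_token_indices: torch.Tensor,
        categorized_sample_indices: Dict["SamplingType", torch.Tensor],
    ) -> None:
        self.seq_groups = seq_groups
        self.seq_data = seq_data
        self.prompt_lens = prompt_lens
        self.selected_token_indices = selected_token_indices
        self.categorized_sample_indices = categorized_sample_indices
        # The sampler also reads num_prompts; deriving it the way the old
        # InputMetadata did is a guess consistent with the code above.
        self.num_prompts = len(prompt_lens)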
vllm/model_executor/models/aquila.py  (+10, -2)

@@ -39,6 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput

@@ -296,11 +297,18 @@ class AquilaForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
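Every model file that follows repeats this same mechanical split: forward() now stops at the hidden states, and a new sample() method runs the Sampler under a SamplingMetadata. A runnable toy with the same two-step shape (the tiny module below is illustrative; real vLLM models also take kv_caches, input_metadata, and cache_events):

import torch

class TinyCausalLM(torch.nn.Module):
    """Toy model mirroring the new forward()/sample() split."""

    def __init__(self, vocab_size: int = 16, hidden_size: int = 8) -> None:
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden_size)
        self.lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Step 1 (the new forward): stop at the hidden states.
        return self.embed(input_ids)

    def sample(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Step 2 (the new sample): vLLM feeds SamplingMetadata here;
        # this toy just greedy-decodes the last position.
        logits = self.lm_head(hidden_states)
        return logits[:, -1, :].argmax(dim=-1)

model = TinyCausalLM()
hidden = model(torch.tensor([[1, 2, 3]]))  # forward returns hidden states
next_token = model.sample(hidden)          # sampling is a separate call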
vllm/model_executor/models/baichuan.py  (+10, -2)

@@ -38,6 +38,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput

@@ -311,11 +312,18 @@ class BaiChuanBaseForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
vllm/model_executor/models/bloom.py  (+10, -2)

@@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput

@@ -288,11 +289,18 @@ class BloomForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
vllm/model_executor/models/chatglm.py  (+10, -2)

@@ -22,6 +22,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput

@@ -350,11 +351,18 @@ class ChatGLMForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
vllm/model_executor/models/falcon.py  (+10, -3)

@@ -41,6 +41,7 @@ from vllm.model_executor.parallel_utils.communication_op import (
     tensor_model_parallel_all_reduce)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput

@@ -389,7 +390,7 @@ class FalconForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.transformer(
             input_ids,
             positions,

@@ -397,9 +398,15 @@ class FalconForCausalLM(nn.Module):
             input_metadata,
             cache_events,
         )
-        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
+        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
vllm/model_executor/models/gpt2.py  (+10, -2)

@@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput

@@ -232,11 +233,18 @@ class GPT2LMHeadModel(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
vllm/model_executor/models/gpt_bigcode.py  (+10, -2)

@@ -36,6 +36,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput

@@ -251,11 +252,18 @@ class GPTBigCodeForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
vllm/model_executor/models/gpt_j.py  (+10, -2)

@@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput

@@ -238,11 +239,18 @@ class GPTJForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata, self.lm_head.bias)
+                                   sampling_metadata, self.lm_head.bias)
         return next_tokens
 
     def load_weights(self,
vllm/model_executor/models/gpt_neox.py  (+10, -2)

@@ -35,6 +35,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput

@@ -251,11 +252,18 @@ class GPTNeoXForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.gpt_neox(input_ids, positions, kv_caches,
                                       input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.embed_out.weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
vllm/model_executor/models/internlm.py  (+10, -2)

@@ -19,6 +19,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput

@@ -250,11 +251,18 @@ class InternLMForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
vllm/model_executor/models/llama.py  (+10, -2)

@@ -41,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput

@@ -289,11 +290,18 @@ class LlamaForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
vllm/model_executor/models/mistral.py  (+10, -2)

@@ -41,6 +41,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput

@@ -285,11 +286,18 @@ class MistralForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, kv_caches,
                                    input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,
vllm/model_executor/models/mpt.py  (+10, -2)

@@ -18,6 +18,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput

@@ -256,11 +257,18 @@ class MPTForCausalLM(nn.Module):
         kv_caches: List[KVCache],
         input_metadata: InputMetadata,
         cache_events: Optional[List[torch.cuda.Event]],
-    ) -> SamplerOutput:
+    ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, kv_caches,
                                          input_metadata, cache_events)
+        return hidden_states
+
+    def sample(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> SamplerOutput:
         next_tokens = self.sampler(self.lm_head_weight, hidden_states,
-                                   input_metadata)
+                                   sampling_metadata)
         return next_tokens
 
     def load_weights(self,