Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
eebad39f
Unverified
Commit
eebad39f
authored
Nov 22, 2024
by
youkaichao
Committed by
GitHub
Nov 22, 2024
Browse files
[torch.compile] support all attention backends (#10558)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
db100c5c
Changes
77
Show whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
40 additions
and
17 deletions
+40
-17
vllm/model_executor/models/starcoder2.py
vllm/model_executor/models/starcoder2.py
+10
-5
vllm/model_executor/models/xverse.py
vllm/model_executor/models/xverse.py
+7
-3
vllm/platforms/cpu.py
vllm/platforms/cpu.py
+1
-0
vllm/platforms/cuda.py
vllm/platforms/cuda.py
+1
-0
vllm/platforms/hpu.py
vllm/platforms/hpu.py
+1
-0
vllm/platforms/interface.py
vllm/platforms/interface.py
+4
-0
vllm/platforms/openvino.py
vllm/platforms/openvino.py
+1
-0
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+1
-0
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+1
-0
vllm/platforms/xpu.py
vllm/platforms/xpu.py
+1
-0
vllm/spec_decode/draft_model_runner.py
vllm/spec_decode/draft_model_runner.py
+2
-1
vllm/utils.py
vllm/utils.py
+2
-1
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+2
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+2
-2
vllm/worker/embedding_model_runner.py
vllm/worker/embedding_model_runner.py
+1
-1
vllm/worker/enc_dec_model_runner.py
vllm/worker/enc_dec_model_runner.py
+1
-1
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+2
-2
No files found.
vllm/model_executor/models/starcoder2.py
View file @
eebad39f
...
@@ -52,7 +52,8 @@ class Starcoder2Attention(nn.Module):
...
@@ -52,7 +52,8 @@ class Starcoder2Attention(nn.Module):
def
__init__
(
self
,
def
__init__
(
self
,
config
:
Starcoder2Config
,
config
:
Starcoder2Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
self
.
config
=
config
self
.
config
=
config
...
@@ -105,7 +106,8 @@ class Starcoder2Attention(nn.Module):
...
@@ -105,7 +106,8 @@ class Starcoder2Attention(nn.Module):
self
.
scaling
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -154,12 +156,14 @@ class Starcoder2DecoderLayer(nn.Module):
...
@@ -154,12 +156,14 @@ class Starcoder2DecoderLayer(nn.Module):
def
__init__
(
self
,
def
__init__
(
self
,
config
:
Starcoder2Config
,
config
:
Starcoder2Config
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
self
.
self_attn
=
Starcoder2Attention
(
config
,
self
.
self_attn
=
Starcoder2Attention
(
config
,
cache_config
,
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
)
self
.
mlp
=
Starcoder2MLP
(
config
,
quant_config
=
quant_config
)
self
.
mlp
=
Starcoder2MLP
(
config
,
quant_config
=
quant_config
)
self
.
input_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
self
.
input_layernorm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
norm_epsilon
)
eps
=
config
.
norm_epsilon
)
...
@@ -213,7 +217,8 @@ class Starcoder2Model(nn.Module):
...
@@ -213,7 +217,8 @@ class Starcoder2Model(nn.Module):
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
config
.
num_hidden_layers
,
lambda
prefix
:
Starcoder2DecoderLayer
(
lambda
prefix
:
Starcoder2DecoderLayer
(
config
,
cache_config
,
quant_config
=
quant_config
),
config
,
cache_config
,
quant_config
=
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
,
prefix
=
f
"
{
prefix
}
.layers"
,
)
)
self
.
norm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
norm_epsilon
)
self
.
norm
=
nn
.
LayerNorm
(
config
.
hidden_size
,
eps
=
config
.
norm_epsilon
)
...
...
vllm/model_executor/models/xverse.py
View file @
eebad39f
...
@@ -93,6 +93,7 @@ class XverseAttention(nn.Module):
...
@@ -93,6 +93,7 @@ class XverseAttention(nn.Module):
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
bias
:
bool
=
False
,
bias
:
bool
=
False
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
...
@@ -138,7 +139,8 @@ class XverseAttention(nn.Module):
...
@@ -138,7 +139,8 @@ class XverseAttention(nn.Module):
self
.
scaling
,
self
.
scaling
,
num_kv_heads
=
self
.
num_kv_heads
,
num_kv_heads
=
self
.
num_kv_heads
,
cache_config
=
cache_config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
)
quant_config
=
quant_config
,
prefix
=
f
"
{
prefix
}
.attn"
)
def
forward
(
def
forward
(
self
,
self
,
...
@@ -162,6 +164,7 @@ class XverseDecoderLayer(nn.Module):
...
@@ -162,6 +164,7 @@ class XverseDecoderLayer(nn.Module):
config
:
PretrainedConfig
,
config
:
PretrainedConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
)
->
None
:
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
config
.
hidden_size
self
.
hidden_size
=
config
.
hidden_size
...
@@ -180,6 +183,7 @@ class XverseDecoderLayer(nn.Module):
...
@@ -180,6 +183,7 @@ class XverseDecoderLayer(nn.Module):
quant_config
=
quant_config
,
quant_config
=
quant_config
,
bias
=
getattr
(
config
,
"bias"
,
False
),
bias
=
getattr
(
config
,
"bias"
,
False
),
cache_config
=
cache_config
,
cache_config
=
cache_config
,
prefix
=
f
"
{
prefix
}
.self_attn"
,
)
)
self
.
mlp
=
XverseMLP
(
self
.
mlp
=
XverseMLP
(
hidden_size
=
self
.
hidden_size
,
hidden_size
=
self
.
hidden_size
,
...
@@ -243,8 +247,8 @@ class XverseModel(nn.Module):
...
@@ -243,8 +247,8 @@ class XverseModel(nn.Module):
)
)
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
self
.
start_layer
,
self
.
end_layer
,
self
.
layers
=
make_layers
(
config
.
num_hidden_layers
,
config
.
num_hidden_layers
,
lambda
prefix
:
XverseDecoderLayer
(
config
,
cache_config
,
lambda
prefix
:
XverseDecoderLayer
(
quant_config
),
config
,
cache_config
,
quant_config
,
prefix
=
prefix
),
prefix
=
f
"
{
prefix
}
.layers"
,
prefix
=
f
"
{
prefix
}
.layers"
,
)
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
norm
=
RMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
...
...
vllm/platforms/cpu.py
View file @
eebad39f
...
@@ -20,6 +20,7 @@ logger = init_logger(__name__)
...
@@ -20,6 +20,7 @@ logger = init_logger(__name__)
class
CpuPlatform
(
Platform
):
class
CpuPlatform
(
Platform
):
_enum
=
PlatformEnum
.
CPU
_enum
=
PlatformEnum
.
CPU
device_type
:
str
=
"cpu"
device_type
:
str
=
"cpu"
dispatch_key
:
str
=
"CPU"
@
classmethod
@
classmethod
def
get_device_name
(
cls
,
device_id
:
int
=
0
)
->
str
:
def
get_device_name
(
cls
,
device_id
:
int
=
0
)
->
str
:
...
...
vllm/platforms/cuda.py
View file @
eebad39f
...
@@ -121,6 +121,7 @@ def device_id_to_physical_device_id(device_id: int) -> int:
...
@@ -121,6 +121,7 @@ def device_id_to_physical_device_id(device_id: int) -> int:
class
CudaPlatform
(
Platform
):
class
CudaPlatform
(
Platform
):
_enum
=
PlatformEnum
.
CUDA
_enum
=
PlatformEnum
.
CUDA
device_type
:
str
=
"cuda"
device_type
:
str
=
"cuda"
dispatch_key
:
str
=
"CUDA"
@
classmethod
@
classmethod
def
get_device_capability
(
cls
,
device_id
:
int
=
0
)
->
DeviceCapability
:
def
get_device_capability
(
cls
,
device_id
:
int
=
0
)
->
DeviceCapability
:
...
...
vllm/platforms/hpu.py
View file @
eebad39f
...
@@ -13,6 +13,7 @@ else:
...
@@ -13,6 +13,7 @@ else:
class
HpuPlatform
(
Platform
):
class
HpuPlatform
(
Platform
):
_enum
=
PlatformEnum
.
HPU
_enum
=
PlatformEnum
.
HPU
device_type
:
str
=
"hpu"
device_type
:
str
=
"hpu"
dispatch_key
:
str
=
"HPU"
@
classmethod
@
classmethod
def
get_default_attn_backend
(
cls
,
selected_backend
:
_Backend
)
->
_Backend
:
def
get_default_attn_backend
(
cls
,
selected_backend
:
_Backend
)
->
_Backend
:
...
...
vllm/platforms/interface.py
View file @
eebad39f
...
@@ -57,6 +57,10 @@ class DeviceCapability(NamedTuple):
...
@@ -57,6 +57,10 @@ class DeviceCapability(NamedTuple):
class
Platform
:
class
Platform
:
_enum
:
PlatformEnum
_enum
:
PlatformEnum
device_type
:
str
device_type
:
str
# available dispatch keys:
# check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
# use "CPU" as a fallback for platforms not registered in PyTorch
dispatch_key
:
str
=
"CPU"
def
is_cuda
(
self
)
->
bool
:
def
is_cuda
(
self
)
->
bool
:
return
self
.
_enum
==
PlatformEnum
.
CUDA
return
self
.
_enum
==
PlatformEnum
.
CUDA
...
...
vllm/platforms/openvino.py
View file @
eebad39f
...
@@ -18,6 +18,7 @@ logger = init_logger(__name__)
...
@@ -18,6 +18,7 @@ logger = init_logger(__name__)
class
OpenVinoPlatform
(
Platform
):
class
OpenVinoPlatform
(
Platform
):
_enum
=
PlatformEnum
.
OPENVINO
_enum
=
PlatformEnum
.
OPENVINO
device_type
:
str
=
"openvino"
device_type
:
str
=
"openvino"
dispatch_key
:
str
=
"CPU"
@
classmethod
@
classmethod
def
get_default_attn_backend
(
cls
,
selected_backend
:
_Backend
)
->
_Backend
:
def
get_default_attn_backend
(
cls
,
selected_backend
:
_Backend
)
->
_Backend
:
...
...
vllm/platforms/rocm.py
View file @
eebad39f
...
@@ -36,6 +36,7 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
...
@@ -36,6 +36,7 @@ if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]:
class
RocmPlatform
(
Platform
):
class
RocmPlatform
(
Platform
):
_enum
=
PlatformEnum
.
ROCM
_enum
=
PlatformEnum
.
ROCM
device_type
:
str
=
"cuda"
device_type
:
str
=
"cuda"
dispatch_key
:
str
=
"CUDA"
@
classmethod
@
classmethod
def
get_default_attn_backend
(
cls
,
selected_backend
:
_Backend
)
->
_Backend
:
def
get_default_attn_backend
(
cls
,
selected_backend
:
_Backend
)
->
_Backend
:
...
...
vllm/platforms/tpu.py
View file @
eebad39f
...
@@ -17,6 +17,7 @@ logger = init_logger(__name__)
...
@@ -17,6 +17,7 @@ logger = init_logger(__name__)
class
TpuPlatform
(
Platform
):
class
TpuPlatform
(
Platform
):
_enum
=
PlatformEnum
.
TPU
_enum
=
PlatformEnum
.
TPU
device_type
:
str
=
"tpu"
device_type
:
str
=
"tpu"
dispatch_key
:
str
=
"XLA"
@
classmethod
@
classmethod
def
get_default_attn_backend
(
cls
,
selected_backend
:
_Backend
)
->
_Backend
:
def
get_default_attn_backend
(
cls
,
selected_backend
:
_Backend
)
->
_Backend
:
...
...
vllm/platforms/xpu.py
View file @
eebad39f
...
@@ -17,6 +17,7 @@ logger = init_logger(__name__)
...
@@ -17,6 +17,7 @@ logger = init_logger(__name__)
class
XPUPlatform
(
Platform
):
class
XPUPlatform
(
Platform
):
_enum
=
PlatformEnum
.
XPU
_enum
=
PlatformEnum
.
XPU
device_type
:
str
=
"xpu"
device_type
:
str
=
"xpu"
dispatch_key
:
str
=
"XPU"
@
classmethod
@
classmethod
def
get_default_attn_backend
(
cls
,
selected_backend
:
_Backend
)
->
_Backend
:
def
get_default_attn_backend
(
cls
,
selected_backend
:
_Backend
)
->
_Backend
:
...
...
vllm/spec_decode/draft_model_runner.py
View file @
eebad39f
...
@@ -273,7 +273,8 @@ class TP1DraftModelRunner(ModelRunner):
...
@@ -273,7 +273,8 @@ class TP1DraftModelRunner(ModelRunner):
if
previous_hidden_states
is
not
None
else
{}
if
previous_hidden_states
is
not
None
else
{}
# Run model
# Run model
with
set_forward_context
(
model_input
.
attn_metadata
):
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
):
hidden_states
=
model_executable
(
hidden_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
positions
=
model_input
.
input_positions
,
...
...
vllm/utils.py
View file @
eebad39f
...
@@ -1573,6 +1573,7 @@ def direct_register_custom_op(
...
@@ -1573,6 +1573,7 @@ def direct_register_custom_op(
mutates_args
:
List
[
str
],
mutates_args
:
List
[
str
],
fake_impl
:
Optional
[
Callable
]
=
None
,
fake_impl
:
Optional
[
Callable
]
=
None
,
target_lib
:
Optional
[
Library
]
=
None
,
target_lib
:
Optional
[
Library
]
=
None
,
dispatch_key
:
str
=
"CUDA"
,
):
):
"""
"""
`torch.library.custom_op` can have significant overhead because it
`torch.library.custom_op` can have significant overhead because it
...
@@ -1601,7 +1602,7 @@ def direct_register_custom_op(
...
@@ -1601,7 +1602,7 @@ def direct_register_custom_op(
schema_str
=
torch
.
_custom_op
.
impl
.
infer_schema
(
op_func
,
mutates_args
)
schema_str
=
torch
.
_custom_op
.
impl
.
infer_schema
(
op_func
,
mutates_args
)
my_lib
=
target_lib
or
vllm_lib
my_lib
=
target_lib
or
vllm_lib
my_lib
.
define
(
op_name
+
schema_str
)
my_lib
.
define
(
op_name
+
schema_str
)
my_lib
.
impl
(
op_name
,
op_func
,
"CUDA"
)
my_lib
.
impl
(
op_name
,
op_func
,
dispatch_key
=
dispatch_key
)
if
fake_impl
is
not
None
:
if
fake_impl
is
not
None
:
my_lib
.
_register_fake
(
op_name
,
fake_impl
)
my_lib
.
_register_fake
(
op_name
,
fake_impl
)
...
...
vllm/v1/attention/backends/flash_attn.py
View file @
eebad39f
...
@@ -173,7 +173,8 @@ def unified_v1_flash_attention(
...
@@ -173,7 +173,8 @@ def unified_v1_flash_attention(
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
alibi_slopes
:
Optional
[
torch
.
Tensor
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
logits_soft_cap
:
Optional
[
float
]
=
None
,
)
->
None
:
)
->
None
:
current_metadata
=
get_forward_context
()
context
=
get_forward_context
()
current_metadata
=
context
.
dynamic_forward_context
if
current_metadata
is
None
:
if
current_metadata
is
None
:
# Profiling run.
# Profiling run.
return
return
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
eebad39f
...
@@ -447,7 +447,7 @@ class GPUModelRunner:
...
@@ -447,7 +447,7 @@ class GPUModelRunner:
# Run the decoder.
# Run the decoder.
# Use persistent buffers for CUDA graphs.
# Use persistent buffers for CUDA graphs.
with
set_forward_context
(
attn_metadata
):
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
):
hidden_states
=
self
.
model
(
hidden_states
=
self
.
model
(
input_ids
=
None
,
input_ids
=
None
,
positions
=
self
.
positions
[:
num_input_tokens
],
positions
=
self
.
positions
[:
num_input_tokens
],
...
@@ -523,7 +523,7 @@ class GPUModelRunner:
...
@@ -523,7 +523,7 @@ class GPUModelRunner:
num_tokens
:
int
,
num_tokens
:
int
,
kv_caches
:
List
[
torch
.
Tensor
],
kv_caches
:
List
[
torch
.
Tensor
],
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
with
set_forward_context
(
None
):
with
set_forward_context
(
None
,
self
.
vllm_config
):
hidden_states
=
model
(
hidden_states
=
model
(
input_ids
=
None
,
input_ids
=
None
,
positions
=
self
.
positions
[:
num_tokens
],
positions
=
self
.
positions
[:
num_tokens
],
...
...
vllm/worker/embedding_model_runner.py
View file @
eebad39f
...
@@ -97,7 +97,7 @@ class EmbeddingModelRunner(
...
@@ -97,7 +97,7 @@ class EmbeddingModelRunner(
model_forward_end
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
model_forward_end
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
model_forward_start
.
record
()
model_forward_start
.
record
()
with
set_forward_context
(
model_input
.
attn_metadata
):
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
):
hidden_or_intermediate_states
=
model_executable
(
hidden_or_intermediate_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
positions
=
model_input
.
input_positions
,
...
...
vllm/worker/enc_dec_model_runner.py
View file @
eebad39f
...
@@ -176,7 +176,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
...
@@ -176,7 +176,7 @@ class EncoderDecoderModelRunner(GPUModelRunnerBase[EncoderDecoderModelInput]):
}
if
self
.
has_inner_state
else
{}
}
if
self
.
has_inner_state
else
{}
multi_modal_kwargs
=
model_input
.
multi_modal_kwargs
or
{}
multi_modal_kwargs
=
model_input
.
multi_modal_kwargs
or
{}
with
set_forward_context
(
model_input
.
attn_metadata
):
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
):
hidden_or_intermediate_states
=
model_executable
(
hidden_or_intermediate_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
positions
=
model_input
.
input_positions
,
...
...
vllm/worker/model_runner.py
View file @
eebad39f
...
@@ -1503,7 +1503,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
...
@@ -1503,7 +1503,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self
.
_update_inputs_to_capture_for_enc_dec_model
(
self
.
_update_inputs_to_capture_for_enc_dec_model
(
capture_inputs
)
capture_inputs
)
with
set_forward_context
(
attn_metadata
):
with
set_forward_context
(
attn_metadata
,
self
.
vllm_config
):
graph_runner
.
capture
(
**
capture_inputs
)
graph_runner
.
capture
(
**
capture_inputs
)
self
.
graph_memory_pool
=
graph_runner
.
graph
.
pool
()
self
.
graph_memory_pool
=
graph_runner
.
graph
.
pool
()
self
.
graph_runners
[
virtual_engine
][
batch_size
]
=
(
self
.
graph_runners
[
virtual_engine
][
batch_size
]
=
(
...
@@ -1649,7 +1649,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
...
@@ -1649,7 +1649,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
model_forward_end
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
model_forward_end
=
torch
.
cuda
.
Event
(
enable_timing
=
True
)
model_forward_start
.
record
()
model_forward_start
.
record
()
with
set_forward_context
(
model_input
.
attn_metadata
):
with
set_forward_context
(
model_input
.
attn_metadata
,
self
.
vllm_config
):
hidden_or_intermediate_states
=
model_executable
(
hidden_or_intermediate_states
=
model_executable
(
input_ids
=
model_input
.
input_tokens
,
input_ids
=
model_input
.
input_tokens
,
positions
=
model_input
.
input_positions
,
positions
=
model_input
.
input_positions
,
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment