Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b91660dd
Unverified
Commit
b91660dd
authored
Feb 28, 2025
by
Kacper Pietkun
Committed by
GitHub
Feb 28, 2025
Browse files
[Hardware][Intel-Gaudi] Regional compilation support (#13213)
parent
76c89fca
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
37 additions
and
6 deletions
+37
-6
vllm/worker/hpu_model_runner.py
vllm/worker/hpu_model_runner.py
+37
-6
No files found.
vllm/worker/hpu_model_runner.py
View file @
b91660dd
...
@@ -39,7 +39,10 @@ from vllm.lora.layers import LoRAMapping
...
@@ -39,7 +39,10 @@ from vllm.lora.layers import LoRAMapping
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.worker_manager
import
LRUCacheWorkerLoRAManager
from
vllm.lora.worker_manager
import
LRUCacheWorkerLoRAManager
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor
import
SamplingMetadata
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.sampler
import
SamplerOutput
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
from
vllm.multimodal
import
(
MULTIMODAL_REGISTRY
,
BatchedTensorInputs
,
MultiModalKwargs
)
MultiModalKwargs
)
...
@@ -311,10 +314,38 @@ class HpuModelAdapter:
...
@@ -311,10 +314,38 @@ class HpuModelAdapter:
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
self
.
block_size
=
vllm_config
.
cache_config
.
block_size
self
.
dtype
=
vllm_config
.
model_config
.
dtype
self
.
dtype
=
vllm_config
.
model_config
.
dtype
enforce_eager
=
vllm_config
.
model_config
.
enforce_eager
enforce_eager
=
vllm_config
.
model_config
.
enforce_eager
if
not
htorch
.
utils
.
internal
.
is_lazy
()
and
not
enforce_eager
:
if
not
htorch
.
utils
.
internal
.
is_lazy
()
and
not
enforce_eager
:
self
.
model
=
torch
.
compile
(
self
.
model
,
if
os
.
getenv
(
'VLLM_REGIONAL_COMPILATION'
,
backend
=
'hpu_backend'
,
'true'
).
lower
()
==
'true'
:
dynamic
=
False
)
self
.
regional_compilation_layers_list
=
[
RMSNorm
,
VocabParallelEmbedding
]
self
.
_regional_compilation
(
self
.
model
)
else
:
self
.
model
=
torch
.
compile
(
self
.
model
,
backend
=
'hpu_backend'
,
dynamic
=
False
)
def _regional_compilation(self, module, parent_module=None,
                          module_name=None):
    """Recursively pick submodules of ``module`` to compile as regions.

    Three cases are handled, in this order:

    * ``module`` is a ``torch.nn.ModuleList``: each of its direct
      children is passed to ``_compile_region`` with the list itself as
      the owner, and the walk does not descend any further.
    * ``module`` is an instance of one of the classes in
      ``self.regional_compilation_layers_list``: the module itself is
      passed to ``_compile_region`` together with the parent and the
      attribute name it was reached through.
    * otherwise: the walk recurses into every direct child.

    Args:
        module: The module currently being inspected.
        parent_module: The module that owns ``module``; ``None`` for the
            initial call.
        module_name: The name under which ``module`` is registered on
            ``parent_module``; ``None`` for the initial call.
    """
    if isinstance(module, torch.nn.ModuleList):
        for entry_name, entry in module.named_children():
            self._compile_region(module, entry_name, entry)
        return

    # isinstance accepts a tuple of classes, which is the same check as
    # testing each listed layer type in turn.
    region_types = tuple(self.regional_compilation_layers_list)
    if isinstance(module, region_types):
        self._compile_region(parent_module, module_name, module)
        return

    for child_name, child in module.named_children():
        self._regional_compilation(child, module, child_name)
def _compile_region(self, model, name, module):
    """Compile ``module`` and store the result on ``model`` as ``name``.

    Args:
        model: The object whose attribute ``name`` is overwritten.
        name: The attribute name on ``model`` to assign to.
        module: The module handed to ``torch.compile``.
    """
    compiled_region = torch.compile(
        module,
        backend='hpu_backend',
        dynamic=False,
    )
    # Rebind on the owner so later lookups of ``name`` see the compiled
    # object instead of the original module.
    setattr(model, name, compiled_region)
def
_set_attn_bias
(
self
,
attn_metadata
,
batch_size
,
seq_len
,
device
,
def
_set_attn_bias
(
self
,
attn_metadata
,
batch_size
,
seq_len
,
device
,
dtype
):
dtype
):
...
@@ -1575,9 +1606,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
...
@@ -1575,9 +1606,9 @@ class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]):
list
(
sorted
(
self
.
bucketing_global_state
.
decode_buckets
)))
list
(
sorted
(
self
.
bucketing_global_state
.
decode_buckets
)))
if
not
htorch
.
utils
.
internal
.
is_lazy
()
and
not
self
.
enforce_eager
:
if
not
htorch
.
utils
.
internal
.
is_lazy
()
and
not
self
.
enforce_eager
:
cache_size_limit
=
len
(
cache_size_limit
=
1
+
3
*
(
self
.
bucketing_global_state
.
prompt_buckets
)
+
len
(
len
(
self
.
bucketing_global_state
.
prompt_buckets
)
+
self
.
bucketing_global_state
.
decode_buckets
)
+
1
len
(
self
.
bucketing_global_state
.
decode_buckets
)
)
torch
.
_dynamo
.
config
.
cache_size_limit
=
max
(
torch
.
_dynamo
.
config
.
cache_size_limit
=
max
(
cache_size_limit
,
torch
.
_dynamo
.
config
.
cache_size_limit
)
cache_size_limit
,
torch
.
_dynamo
.
config
.
cache_size_limit
)
# Multiply by 8 to follow the original default ratio between
# Multiply by 8 to follow the original default ratio between
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment