Unverified Commit d5e0fca2 authored by Luka Govedič's avatar Luka Govedič Committed by GitHub
Browse files

[torch.compile] Cleanup compilation tests and custom passes, add debug utils,...


[torch.compile] Cleanup compilation tests and custom passes, add debug utils, fix DCE bug (#23091), fix test (#24376), and prep for custom op matching (#24604) (#24542)
Signed-off-by: default avatarLuka Govedič <lgovedic@redhat.com>
Signed-off-by: default avatarluka <lgovedic@redhat.com>
Signed-off-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
parent 8d0ee5a5
...@@ -905,10 +905,9 @@ def set_current_vllm_config(vllm_config: VllmConfig, ...@@ -905,10 +905,9 @@ def set_current_vllm_config(vllm_config: VllmConfig,
except Exception: except Exception:
raise raise
else: else:
logger.debug("enabled custom ops: %s", if check_compile:
vllm_config.compilation_config.enabled_custom_ops) vllm_config.compilation_config.custom_op_log_check()
logger.debug("disabled custom ops: %s",
vllm_config.compilation_config.disabled_custom_ops)
if check_compile and \ if check_compile and \
vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \ vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
and compilation_counter.num_models_seen == num_models_seen: and compilation_counter.num_models_seen == num_models_seen:
......
...@@ -487,6 +487,12 @@ class CompilationConfig: ...@@ -487,6 +487,12 @@ class CompilationConfig:
"supported with torch>=2.9.0.dev. Set " "supported with torch>=2.9.0.dev. Set "
"use_inductor_graph_partition=False instead.") "use_inductor_graph_partition=False instead.")
for op in self.custom_ops:
if op[0] not in {'+', '-'} and op not in {'all', 'none'}:
raise ValueError(f"Invalid syntax '{op}' for custom op, "
"must be 'all', 'none', '+op' or '-op' "
"(where 'op' is the registered op name)")
def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]: def init_backend(self, vllm_config: "VllmConfig") -> Union[str, Callable]:
if self.level == CompilationLevel.NO_COMPILATION: if self.level == CompilationLevel.NO_COMPILATION:
raise ValueError("No compilation level is set.") raise ValueError("No compilation level is set.")
...@@ -628,3 +634,41 @@ class CompilationConfig: ...@@ -628,3 +634,41 @@ class CompilationConfig:
return use_fx_graph_piecewise_compilation or \ return use_fx_graph_piecewise_compilation or \
use_inductor_piecewise_compilation use_inductor_piecewise_compilation
def custom_op_log_check(self):
"""
This method logs the enabled/disabled custom ops and checks that the
passed custom_ops field only contains relevant ops.
It is called at the end of set_current_vllm_config,
after the custom ops have been instantiated.
"""
if len(self.enabled_custom_ops) + len(self.disabled_custom_ops) == 0:
logger.debug("No custom ops found in model.")
return
logger.debug("enabled custom ops: %s", self.enabled_custom_ops)
logger.debug("disabled custom ops: %s", self.disabled_custom_ops)
all_ops_in_model = (self.enabled_custom_ops | self.disabled_custom_ops)
for op in self.custom_ops:
if op in {"all", "none"}:
continue
assert op[0] in {'+', '-'}, "Invalid custom op syntax " \
"(should be checked during init)"
# check if op name exists in model
op_name = op[1:]
if op_name not in all_ops_in_model:
from vllm.model_executor.custom_op import CustomOp
# Does op exist at all or is it just not present in this model?
# Note: Only imported op classes appear in the registry.
missing_str = "doesn't exist (or wasn't imported/registered)" \
if op_name not in CustomOp.op_registry \
else "not present in model"
enable_str = "enabling" if op[0] == '+' else "disabling"
logger.warning_once("Op '%s' %s, %s with '%s' has no effect",
op_name, missing_str, enable_str, op)
...@@ -190,6 +190,7 @@ if TYPE_CHECKING: ...@@ -190,6 +190,7 @@ if TYPE_CHECKING:
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER" VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER"
VLLM_PATTERN_MATCH_DEBUG: Optional[str] = None
def get_default_cache_root(): def get_default_cache_root():
...@@ -442,6 +443,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -442,6 +443,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_STANDALONE_COMPILE": "VLLM_USE_STANDALONE_COMPILE":
lambda: os.environ.get("VLLM_USE_STANDALONE_COMPILE", "0") == "1", lambda: os.environ.get("VLLM_USE_STANDALONE_COMPILE", "0") == "1",
# Debug pattern matching inside custom passes.
# Should be set to the fx.Node name (e.g. 'getitem_34' or 'scaled_mm_3').
"VLLM_PATTERN_MATCH_DEBUG":
lambda: os.environ.get("VLLM_PATTERN_MATCH_DEBUG", None),
# local rank of the process in the distributed setting, used to determine # local rank of the process in the distributed setting, used to determine
# the GPU device id # the GPU device id
"LOCAL_RANK": "LOCAL_RANK":
......
...@@ -3413,3 +3413,16 @@ def length_from_prompt_token_ids_or_embeds( ...@@ -3413,3 +3413,16 @@ def length_from_prompt_token_ids_or_embeds(
f" prompt_token_ids={prompt_token_len}" f" prompt_token_ids={prompt_token_len}"
f" prompt_embeds={prompt_embeds_len}") f" prompt_embeds={prompt_embeds_len}")
return prompt_token_len return prompt_token_len
@contextlib.contextmanager
def set_env_var(key, value):
old = os.environ.get(key)
os.environ[key] = value
try:
yield
finally:
if old is None:
del os.environ[key]
else:
os.environ[key] = old
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment