Unverified Commit e6ae4b1b authored by Zhengxu Chen's avatar Zhengxu Chen Committed by GitHub
Browse files

[compile] Enable mega aot artifact for torch 2.12+. (#37198)


Signed-off-by: zhxchen17 <zhxchen17@fb.com>
parent 2dccb38f
...@@ -307,13 +307,6 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -307,13 +307,6 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
num_submods = len(submod_names) num_submods = len(submod_names)
num_artifacts = standalone_compile_artifacts.num_artifacts() num_artifacts = standalone_compile_artifacts.num_artifacts()
logger.info(
"reconstructing serializable fn from standalone compile "
"artifacts. num_artifacts=%d num_submods=%d",
num_artifacts,
num_submods,
)
with functorch_ctx: with functorch_ctx:
fn = reconstruct_serializable_fn_from_mega_artifact( fn = reconstruct_serializable_fn_from_mega_artifact(
state=state, state=state,
...@@ -324,7 +317,10 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -324,7 +317,10 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
) )
logger.info( logger.info(
"reconstructed serializable fn from standalone compile artifacts" "reconstructed serializable fn from standalone compile "
"artifacts. num_artifacts=%d num_submods=%d",
num_artifacts,
num_submods,
) )
return fn return fn
......
...@@ -296,6 +296,16 @@ def use_aot_compile() -> bool: ...@@ -296,6 +296,16 @@ def use_aot_compile() -> bool:
) )
def use_mega_aot_artifact() -> bool:
    """Return whether the mega AOT artifact path is enabled.

    An explicitly set ``VLLM_USE_MEGA_AOT_ARTIFACT`` environment variable
    always decides the result: the feature is on only when the value is
    exactly ``"1"``. When the variable is unset, the feature defaults to on
    if torch is at least ``2.12.0.dev`` and ``use_aot_compile()`` is true,
    and off otherwise.

    Returns:
        True if the mega AOT artifact should be used, False otherwise.
    """
    explicit = os.environ.get("VLLM_USE_MEGA_AOT_ARTIFACT")
    if explicit is not None:
        # The user has already decided; skip the torch version probe and
        # the AOT-compile check, which only feed the default.
        return explicit == "1"

    # Function-local import, as before; only needed to compute the default.
    from vllm.utils.torch_utils import is_torch_equal_or_newer

    return bool(is_torch_equal_or_newer("2.12.0.dev") and use_aot_compile())
def env_with_choices( def env_with_choices(
env_name: str, env_name: str,
default: str | None, default: str | None,
...@@ -616,10 +626,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -616,10 +626,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Enable loading compiled models directly from cached standalone compile artifacts # Enable loading compiled models directly from cached standalone compile artifacts
# without re-splitting graph modules. This reduces overhead during model # without re-splitting graph modules. This reduces overhead during model
# loading by using reconstruct_serializable_fn_from_mega_artifact. # loading by using reconstruct_serializable_fn_from_mega_artifact.
"VLLM_USE_MEGA_AOT_ARTIFACT": lambda: os.environ.get( "VLLM_USE_MEGA_AOT_ARTIFACT": use_mega_aot_artifact,
"VLLM_USE_MEGA_AOT_ARTIFACT", "0"
)
== "1",
# local rank of the process in the distributed setting, used to determine # local rank of the process in the distributed setting, used to determine
# the GPU device id # the GPU device id
"LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")), "LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment