Unverified Commit 5569f521 authored by Richard Zou's avatar Richard Zou Committed by GitHub
Browse files

[torch.compile] Stop lazily compiling (#35472)


Signed-off-by: default avatarRichard Zou <zou3519@gmail.com>
parent 138d891d
...@@ -73,6 +73,7 @@ def test_compile_ranges(use_fresh_inductor_cache): ...@@ -73,6 +73,7 @@ def test_compile_ranges(use_fresh_inductor_cache):
Range(start=16, end=16), Range(start=16, end=16),
Range(start=9, end=32), Range(start=9, end=32),
Range(start=64, end=64), Range(start=64, end=64),
Range(start=128, end=128),
Range(start=33, end=8192), Range(start=33, end=8192),
] ]
) )
...@@ -95,16 +96,16 @@ def test_compile_ranges(use_fresh_inductor_cache): ...@@ -95,16 +96,16 @@ def test_compile_ranges(use_fresh_inductor_cache):
with set_current_vllm_config(vllm_config): with set_current_vllm_config(vllm_config):
model = TestModel(vllm_config=vllm_config, prefix="").eval() model = TestModel(vllm_config=vllm_config, prefix="").eval()
# Number of compilations: 3 for each compile range + 2 compile sizes # Number of compilations: 3 compile ranges + 3 compile sizes
batch_sizes = [1, 4, 16, 24, 48, 64, 8192] batch_sizes = [1, 4, 16, 24, 48, 64, 8192]
with compilation_counter.expect( with compilation_counter.expect(
num_graphs_seen=1, num_graphs_seen=1,
num_piecewise_graphs_seen=1, num_piecewise_graphs_seen=1,
num_backend_compilations=5, num_backend_compilations=6,
): ):
run_model(vllm_config, model, batch_sizes) run_model(vllm_config, model, batch_sizes)
assert post_grad_range_checker.num_calls == 5 assert post_grad_range_checker.num_calls == 6
def test_compile_config_get_compile_ranges(): def test_compile_config_get_compile_ranges():
......
...@@ -109,9 +109,9 @@ def test_vllm_structured_logging_artifacts(use_fresh_inductor_cache): ...@@ -109,9 +109,9 @@ def test_vllm_structured_logging_artifacts(use_fresh_inductor_cache):
f"got {len(vllm_piecewise_split_graph)}" f"got {len(vllm_piecewise_split_graph)}"
) )
compile_start_artifacts = capture.get("artifact", "vllm_piecewise_compile_start") compile_start_artifacts = capture.get("artifact", "vllm_piecewise_compile_start")
assert len(compile_start_artifacts) == 2, ( assert len(compile_start_artifacts) == 4, (
"Expected 2 vllm_piecewise_compile_start " "Expected 4 vllm_piecewise_compile_start "
"(one for dynamic ranges, one for compile size), " "(2 subgraphs x 2 ranges each: dynamic + compile size), "
f"got {len(compile_start_artifacts)}" f"got {len(compile_start_artifacts)}"
) )
submod_dumps = capture.get("graph_dump", r"vllm_submod_.*") submod_dumps = capture.get("graph_dump", r"vllm_submod_.*")
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast import ast
import contextvars
import dataclasses import dataclasses
import hashlib import hashlib
import json import json
...@@ -18,7 +17,7 @@ from typing import Any ...@@ -18,7 +17,7 @@ from typing import Any
import torch import torch
import torch.fx as fx import torch.fx as fx
from torch._dispatch.python import enable_python_dispatcher from torch._dynamo.utils import dynamo_timed
from torch._logging._internal import trace_structured from torch._logging._internal import trace_structured
import vllm.envs as envs import vllm.envs as envs
...@@ -510,9 +509,9 @@ def wrap_with_cudagraph_if_needed( ...@@ -510,9 +509,9 @@ def wrap_with_cudagraph_if_needed(
class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc] class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
"""Code adapted from `torch.fx.passes.shape_prop.ShapeProp`. """Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
It runs the given graph with fake inputs, and compile some It runs the given split graph interpreter, and for each submodule in
submodules specified by `compile_submod_names` with the given `compile_submod_names`, creates a PiecewiseBackend and compiles all
compilation configs. ranges up front.
NOTE: the order in `compile_submod_names` matters, because NOTE: the order in `compile_submod_names` matters, because
it will be used to determine the order of the compiled piecewise it will be used to determine the order of the compiled piecewise
...@@ -540,9 +539,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc] ...@@ -540,9 +539,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
vllm_backend: "VllmBackend", vllm_backend: "VllmBackend",
) -> None: ) -> None:
super().__init__(module) super().__init__(module)
from torch._guards import detect_fake_mode
self.fake_mode = detect_fake_mode()
self.compile_submod_names = compile_submod_names self.compile_submod_names = compile_submod_names
self.compilation_config = vllm_config.compilation_config self.compilation_config = vllm_config.compilation_config
self.vllm_config = vllm_config self.vllm_config = vllm_config
...@@ -552,13 +548,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc] ...@@ -552,13 +548,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
@instrument(span_name="Inductor compilation") @instrument(span_name="Inductor compilation")
def run(self, *args: Any) -> Any: def run(self, *args: Any) -> Any:
# maybe instead just assert inputs are fake? return super().run(*args)
fake_args = [
self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
for t in args
]
with self.fake_mode, enable_python_dispatcher():
return super().run(*fake_args)
def call_module( def call_module(
self, self,
...@@ -614,21 +604,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc] ...@@ -614,21 +604,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
model_tag: str = "backbone" model_tag: str = "backbone"
model_is_encoder: bool = False model_is_encoder: bool = False
_on_compilation_complete_callback: contextvars.ContextVar[Callable[[], None] | None] = (
contextvars.ContextVar("on_compilation_complete_callback", default=None)
)
@contextmanager
def set_on_compilation_complete(
callback: Callable[[], None],
) -> Generator[None, None, None]:
token = _on_compilation_complete_callback.set(callback)
try:
yield
finally:
_on_compilation_complete_callback.reset(token)
@contextmanager @contextmanager
def set_model_tag(tag: str, is_encoder: bool = False) -> Generator[None, None, None]: def set_model_tag(tag: str, is_encoder: bool = False) -> Generator[None, None, None]:
...@@ -846,6 +821,7 @@ class VllmBackend: ...@@ -846,6 +821,7 @@ class VllmBackend:
), ),
) )
@dynamo_timed("vllm_backend")
def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any: def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any:
from .caching import ( from .caching import (
VllmSerializableFunction, VllmSerializableFunction,
...@@ -1036,11 +1012,24 @@ class VllmBackend: ...@@ -1036,11 +1012,24 @@ class VllmBackend:
] ]
# propagate the split graph to the piecewise backend, # propagate the split graph to the piecewise backend,
# compile submodules with symbolic shapes # compile submodules with symbolic shapes, and compile all ranges
# up front so that compilation is complete before the callable
# is returned.
PiecewiseCompileInterpreter( PiecewiseCompileInterpreter(
self.split_gm, submod_names_to_compile, self.vllm_config, self self.split_gm, submod_names_to_compile, self.vllm_config, self
).run(*fake_args) ).run(*fake_args)
# All compilation is done. Save the cache.
time_before_saving = time.perf_counter()
self.compiler_manager.save_to_file()
elapsed = time.perf_counter() - time_before_saving
if elapsed > 1:
logger.info_once(
"Saved compiler manager cache in %.2f seconds.",
elapsed,
scope="local",
)
from torch._guards import detect_fake_mode from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode() fake_mode = detect_fake_mode()
......
...@@ -313,30 +313,26 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] ...@@ -313,30 +313,26 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
return fn return fn
# Fall back to standard VllmBackend # Fall back to standard VllmBackend.
# Use a lazy closure: the backend needs traced_files for cache
# dir computation, but those are only populated after
# _verify_source_unchanged runs in decorators.py (which happens
# after deserialization completes).
from vllm.compilation.backends import VllmBackend from vllm.compilation.backends import VllmBackend
is_encoder = state.get("is_encoder", False) is_encoder = state.get("is_encoder", False)
vllm_backend: VllmBackend = VllmBackend( vllm_config = get_current_vllm_config()
get_current_vllm_config(), state["prefix"], is_encoder compile_inputs = list(state["example_inputs"])
)
def optimized_call(*example_inputs: Any) -> Any: def optimized_call(*example_inputs: Any) -> Any:
""" vllm_backend: VllmBackend = VllmBackend(
On the first run of the optimized call, we rerun the compiler vllm_config, state["prefix"], is_encoder
backend which should result in a cache hit. After the backend )
call returns, we just do a one-time replacement of the optimized
call with the compiled function, so that subsequent calls are on
the AOT compiled path.
"""
compile_inputs = [
inp if inp is not None else example_inputs[i]
for i, inp in enumerate(fn.example_inputs)
]
with tracing(TracingContext(fake_mode)): with tracing(TracingContext(fake_mode)):
fn.optimized_call = vllm_backend( fn.optimized_call = vllm_backend(
state["graph_module"], compile_inputs state["graph_module"], compile_inputs
).optimized_call ).optimized_call
fn.vllm_backend = vllm_backend
return fn.optimized_call(*example_inputs) return fn.optimized_call(*example_inputs)
fn = cls(**state, optimized_call=optimized_call) fn = cls(**state, optimized_call=optimized_call)
......
...@@ -466,8 +466,12 @@ def _support_torch_compile( ...@@ -466,8 +466,12 @@ def _support_torch_compile(
"Directly load AOT compilation from path %s", aot_compilation_path "Directly load AOT compilation from path %s", aot_compilation_path
) )
# Apply partition wrapper context for proper CUDA graph capture # Apply partition wrapper context for proper CUDA graph capture
from .monitor import end_monitoring_torch_compile
with maybe_use_cudagraph_partition_wrapper(self.vllm_config): with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
return self.aot_compiled_fn(self, *args, **kwargs) output = self.aot_compiled_fn(self, *args, **kwargs)
end_monitoring_torch_compile(self.vllm_config)
return output
if self.compiled: if self.compiled:
assert ( assert (
...@@ -552,18 +556,19 @@ def _support_torch_compile( ...@@ -552,18 +556,19 @@ def _support_torch_compile(
logger.warning("Detected eager backend, disabling AOT compile.") logger.warning("Detected eager backend, disabling AOT compile.")
use_aot_compile = False use_aot_compile = False
if use_aot_compile: if use_aot_compile:
from vllm.compilation.backends import set_on_compilation_complete
# store the path for saving after warmup # store the path for saving after warmup
self._aot_compilation_path = aot_compilation_path self._aot_compilation_path = aot_compilation_path
self._aot_cache_dir = cache_dir self._aot_cache_dir = cache_dir
# set callback in context so it's available when compilation completes
with set_on_compilation_complete(self.save_aot_compiled_function):
self.aot_compiled_fn = self.aot_compile(*args, **kwargs) self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
# All compilation is done at this point, save the AOT artifact.
self.save_aot_compiled_function()
output = self.aot_compiled_fn(self, *args, **kwargs) output = self.aot_compiled_fn(self, *args, **kwargs)
else: else:
output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type] output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type]
from .monitor import end_monitoring_torch_compile
end_monitoring_torch_compile(self.vllm_config)
self.compiled = True self.compiled = True
return output return output
......
...@@ -33,7 +33,7 @@ def end_monitoring_torch_compile(vllm_config: VllmConfig) -> None: ...@@ -33,7 +33,7 @@ def end_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
total_compile_time: float = time.perf_counter() - torch_compile_start_time total_compile_time: float = time.perf_counter() - torch_compile_start_time
if compilation_config.mode == CompilationMode.VLLM_COMPILE: if compilation_config.mode == CompilationMode.VLLM_COMPILE:
logger.info_once( logger.info_once(
"torch.compile takes %.2f s in total", "torch.compile and initial profiling run took %.2f s in total",
total_compile_time, total_compile_time,
scope="local", scope="local",
) )
......
...@@ -5,7 +5,6 @@ import dataclasses ...@@ -5,7 +5,6 @@ import dataclasses
import io import io
import json import json
import pickle import pickle
import time
from collections.abc import Callable from collections.abc import Callable
from pickle import Pickler from pickle import Pickler
from typing import Any from typing import Any
...@@ -16,7 +15,6 @@ from torch._inductor.runtime.triton_heuristics import CachingAutotuner ...@@ -16,7 +15,6 @@ from torch._inductor.runtime.triton_heuristics import CachingAutotuner
from torch._logging._internal import trace_structured from torch._logging._internal import trace_structured
from vllm.compilation.backends import VllmBackend from vllm.compilation.backends import VllmBackend
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.config.utils import Range from vllm.config.utils import Range
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -24,6 +22,55 @@ from vllm.logger import init_logger ...@@ -24,6 +22,55 @@ from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
def get_fake_args_from_graph(graph: fx.GraphModule) -> list[Any]:
"""Get fake args directly from graph placeholder nodes."""
fake_args = []
for node in graph.graph.nodes:
if node.op == "placeholder":
fake_args.append(node.meta["example_value"])
else:
break
return fake_args
def create_concrete_args(graph: fx.GraphModule, size: int) -> list[Any]:
"""Create example inputs with symbolic dims replaced by a concrete size.
Used for single-size eager compilation where we need concrete-shaped
inputs but don't have real runtime tensors yet.
"""
from torch._prims_common import compute_required_storage_length
from torch.fx.experimental.symbolic_shapes import is_symbolic
def concretize(sym_val: Any) -> int:
"""Replace all symbolic variables in a SymInt expression with size."""
if not is_symbolic(sym_val):
return int(sym_val)
expr = sym_val.node.expr
return int(expr.subs({s: size for s in expr.free_symbols}))
args: list[Any] = []
for node in graph.graph.nodes:
if node.op != "placeholder":
break
val = node.meta["example_value"]
if isinstance(val, torch.SymInt):
args.append(concretize(val))
elif isinstance(val, torch.Tensor):
new_shape = tuple(concretize(d) for d in val.shape)
new_strides = tuple(concretize(s) for s in val.stride())
new_storage_offset = concretize(val.storage_offset())
needed_size = compute_required_storage_length(
new_shape, new_strides, new_storage_offset
)
t = torch.empty(needed_size, dtype=val.dtype, device=val.device)
t = t.as_strided(new_shape, new_strides, new_storage_offset)
args.append(t)
else:
args.append(val)
return args
@dataclasses.dataclass @dataclasses.dataclass
class RangeEntry: class RangeEntry:
compile_range: Range compile_range: Range
...@@ -109,10 +156,6 @@ class PiecewiseBackend: ...@@ -109,10 +156,6 @@ class PiecewiseBackend:
# the entries for ranges that we need to either # the entries for ranges that we need to either
self.range_entries: dict[Range, RangeEntry] = {} self.range_entries: dict[Range, RangeEntry] = {}
# to_be_compiled_ranges tracks the remaining ranges to compile,
# and updates during the compilation process, so we need to copy it
self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)
# We only keep compilation management inside this class directly. # We only keep compilation management inside this class directly.
if self.compile_sizes is not None: if self.compile_sizes is not None:
for size in self.compile_sizes: for size in self.compile_sizes:
...@@ -129,7 +172,6 @@ class PiecewiseBackend: ...@@ -129,7 +172,6 @@ class PiecewiseBackend:
self.range_entries[range] = RangeEntry( self.range_entries[range] = RangeEntry(
compile_range=range, compile_range=range,
) )
self.to_be_compiled_ranges.add(range)
for range in self.compile_ranges: for range in self.compile_ranges:
self.range_entries[range] = RangeEntry( self.range_entries[range] = RangeEntry(
...@@ -139,12 +181,10 @@ class PiecewiseBackend: ...@@ -139,12 +181,10 @@ class PiecewiseBackend:
# Track whether we've logged the graph for this subgraph (only log once) # Track whether we've logged the graph for this subgraph (only log once)
self._graph_logged = False self._graph_logged = False
# get the on_compilation_complete callback from context... if self.graph is not None:
# PiecewiseBackend is created during the first call, self.compile_all_ranges()
# which is when the context is set (see compilation/decorators.py) else:
from vllm.compilation.backends import _on_compilation_complete_callback self.load_all_ranges()
self.on_compilation_complete = _on_compilation_complete_callback.get()
def get_compiled_graph_wrapper( def get_compiled_graph_wrapper(
self, compiled_graph: Callable[..., Any] self, compiled_graph: Callable[..., Any]
...@@ -161,25 +201,6 @@ class PiecewiseBackend: ...@@ -161,25 +201,6 @@ class PiecewiseBackend:
return compiled_graph_wrapper return compiled_graph_wrapper
def check_for_ending_compilation(self) -> None:
if self.is_last_graph and not self.to_be_compiled_ranges:
# no specific sizes to compile
# save the hash of the inductor graph for the next run
time_before_saving = time.perf_counter()
self.vllm_backend.compiler_manager.save_to_file()
elapsed = time.perf_counter() - time_before_saving
if elapsed > 1:
logger.info_once(
"Saved compiler manager cache in %.2f seconds.",
elapsed,
scope="local",
)
end_monitoring_torch_compile(self.vllm_config)
# Call the completion callback (e.g., to save AOT compiled function)
if self.on_compilation_complete is not None:
self.on_compilation_complete()
def to_bytes(self) -> dict[str, bytes]: def to_bytes(self) -> dict[str, bytes]:
class StandaloneCompiledArtifactsPickler(Pickler): class StandaloneCompiledArtifactsPickler(Pickler):
def reducer_override(self, obj: object) -> Any: def reducer_override(self, obj: object) -> Any:
...@@ -216,27 +237,54 @@ class PiecewiseBackend: ...@@ -216,27 +237,54 @@ class PiecewiseBackend:
return out return out
def _fakify_args(self, args: tuple[Any, ...]) -> list[Any]: def compile_all_ranges(self) -> None:
# We need to pass fake example_inputs, otherwise torch.compile """Compile all range entries for this piecewise subgraph up front."""
# will fakify the example_inputs potentially causing some non dynamic assert self.graph is not None, (
# dimension to be be duck shaped to other existing shapes that have hints "Cannot compile without a graph. "
# matching their values. "When loading from cache/AOT artifacts, "
# This is problem because it can lead to unintended specializations! "compile_all_ranges should not be called."
# if the new wrongly dynamic dim is specialized )
# it will force specializing the whole shape
# torch.compile probably should not accept for range_entry in self.range_entries.values():
# non fake tensors as example inputs! if range_entry.compiled:
# See issue https://github.com/vllm-project/vllm/issues/27899 continue
fake_example_inputs = []
assert self.graph is not None self._log_compile_start(range_entry.compile_range)
for node in self.graph.graph.nodes:
# All place holders come first if range_entry.compile_range.is_single_size():
if node.op == "placeholder": args_list = create_concrete_args(
fake_example_inputs.append(node.meta["example_value"]) self.graph, range_entry.compile_range.start
)
else: else:
break args_list = get_fake_args_from_graph(self.graph)
assert len(fake_example_inputs) == len(args)
return fake_example_inputs # TODO(https://github.com/vllm-project/vllm/issues/35766)
# Can we remove strict_autograd_cache and
# force_non_lazy_backward_lowering overrides?
# I added them explicitly because this is what they are
# set to before the refactor
# (https://github.com/vllm-project/vllm/pull/35472).
# They affect the aotautograd cache key computation
# but they shouldn't have any effect on the actual
# compilation.
config_patches = dict(
bundled_autograd_cache=True,
strict_autograd_cache=False,
)
if hasattr(torch._functorch.config, "force_non_lazy_backward_lowering"):
config_patches["force_non_lazy_backward_lowering"] = False
with torch._functorch.config.patch(**config_patches):
range_entry.runnable = self.vllm_backend.compiler_manager.compile(
self.graph,
args_list,
self.vllm_backend.inductor_config,
self.compilation_config,
compile_range=range_entry.compile_range,
graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles,
)
range_entry.compiled = True
def _log_compile_start(self, compile_range: Range): def _log_compile_start(self, compile_range: Range):
"""Log compilation event for TORCH_TRACE/tlparse.""" """Log compilation event for TORCH_TRACE/tlparse."""
...@@ -277,44 +325,29 @@ class PiecewiseBackend: ...@@ -277,44 +325,29 @@ class PiecewiseBackend:
payload_fn=lambda: self.graph.print_readable(print_output=False), payload_fn=lambda: self.graph.print_readable(print_output=False),
) )
def _maybe_compile_for_range_entry( def load_all_ranges(self) -> None:
self, range_entry: RangeEntry, args: tuple[Any, ...] """Load all pre-compiled runnables for this piecewise subgraph.
) -> Any:
if not range_entry.compiled:
if self.compiled_runnables is not None:
range_entry.runnable = self.get_compiled_graph_wrapper(
self.compiled_runnables[str(range_entry.compile_range)]
)
else:
self._log_compile_start(range_entry.compile_range)
# args are real arguments Called during warm start to wrap all cached compiled_runnables
# fakify for range, real args for concrete size. into range_entry.runnable up front, analogous to compile_all_ranges()
# For concrete size, we clear the shape env in for the cold start path.
# compiler_manager.compile() so no need to fakify. """
args_list = ( assert self.compiled_runnables is not None, (
self._fakify_args(args) "load_all_ranges should only be called when compiled_runnables "
if not range_entry.compile_range.is_single_size() "is set (warm start / cache loading path)."
else list(args)
) )
for range_entry in self.range_entries.values():
with ( if range_entry.compiled:
torch._functorch.config.patch("bundled_autograd_cache", True), continue
): key = str(range_entry.compile_range)
range_entry.runnable = self.vllm_backend.compiler_manager.compile( assert key in self.compiled_runnables, (
self.graph, f"Missing compiled runnable for range {range_entry.compile_range}. "
args_list, f"Available keys: {list(self.compiled_runnables.keys())}"
self.vllm_backend.inductor_config, )
self.compilation_config, range_entry.runnable = self.get_compiled_graph_wrapper(
compile_range=range_entry.compile_range, self.compiled_runnables[key]
graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles,
) )
range_entry.compiled = True range_entry.compiled = True
self.to_be_compiled_ranges.remove(range_entry.compile_range)
self.check_for_ending_compilation()
def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None: def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None:
# First we try to find the range entry for the concrete compile size # First we try to find the range entry for the concrete compile size
...@@ -338,6 +371,9 @@ class PiecewiseBackend: ...@@ -338,6 +371,9 @@ class PiecewiseBackend:
assert range_entry is not None, ( assert range_entry is not None, (
f"Shape: {runtime_shape} out of considered ranges: {self.compile_ranges}" f"Shape: {runtime_shape} out of considered ranges: {self.compile_ranges}"
) )
assert range_entry.compiled, (
self._maybe_compile_for_range_entry(range_entry, args) "All ranges should be compiled or loaded up front in "
"PiecewiseBackend.__init__. "
f"range_entry={range_entry.compile_range}"
)
return range_entry.runnable(*args) return range_entry.runnable(*args)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment