Unverified Commit dba4d9de authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[v1][bugfix] fix cudagraph with inplace buffer assignment (#11596)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
parent 32b4c63f
......@@ -28,11 +28,12 @@ class TorchCompileWrapperWithCustomDispatcher:
compiled_callable: Optional[Callable] = None,
compilation_level: int = 0):
vllm_config = get_current_vllm_config()
self.vllm_config = vllm_config
if compiled_callable is None:
# default compilation settings
# compiling the forward method
vllm_config = get_current_vllm_config()
backend = vllm_config.compilation_config.init_backend(vllm_config)
compiled_callable = torch.compile(
......@@ -82,6 +83,13 @@ class TorchCompileWrapperWithCustomDispatcher:
self.compiled_codes.append(new_code)
if self.vllm_config.compilation_config.use_cudagraph and \
"update" in new_code.co_names:
import depyf
src = depyf.decompile(new_code)
msg = "Assigning / modifying buffers of nn.Module during forward pass is not allowed when using cudagraph inside the compiler because it will cause silent errors. Please use eager mode or fix the code. The following code contains clues about which buffer is being modified (please search for the usage of the function `update`):\n" + src # noqa
raise RuntimeError(msg)
@contextmanager
def dispatch_to_code(self, index: int):
"""Context manager to dispatch to the compiled code.
......
......@@ -541,19 +541,12 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
short_cache = self._compute_cos_sin_cache(
original_max_position_embeddings, short_factor, short_mscale)
short_cache = short_cache.to(dtype)
self.register_buffer("short_cos_sin_cache",
short_cache,
persistent=False)
long_cache = self._compute_cos_sin_cache(max_position_embeddings,
long_factor, long_mscale)
long_cache = long_cache.to(dtype)
self.register_buffer("long_cos_sin_cache",
long_cache,
persistent=False)
long_short_cache = torch.cat(
[self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
long_short_cache = torch.cat([short_cache, long_cache], dim=0)
self.register_buffer("long_short_cos_sin_cache",
long_short_cache,
persistent=False)
......@@ -593,8 +586,6 @@ class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
torch.full_like(positions, k)).long()
idx = (torch.add(positions, long_prompt_offset)
if long_prompt_offset is not None else positions)
self.long_short_cos_sin_cache: torch.Tensor = (
self.long_short_cos_sin_cache.to(idx.device))
idx = torch.add(idx, offsets) if offsets is not None else idx
cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment