Unverified Commit 4c47710b authored by Shinichi Hemmi's avatar Shinichi Hemmi Committed by GitHub
Browse files

[CI/Build] Apply ruff formatter to pass pre-commit (#40078)


Signed-off-by: default avatarHemmi Shinichi <shemmi@preferred.jp>
parent bf9a5ddb
......@@ -6,14 +6,13 @@ import torch
import vllm.config
from tests.compile.backend import TestBackend
from vllm.platforms import current_platform
from vllm.compilation.passes.vllm_inductor_pass import (
VllmFusionPatternMatcherPass,
VllmPatternMatcherPass,
VllmPatternReplacement,
)
from vllm.config import CompilationConfig, CompilationMode, VllmConfig
from vllm.platforms import current_platform
class ReluToAbsPattern(VllmPatternReplacement):
......@@ -58,7 +57,6 @@ class ExpToSqrtPattern(VllmPatternReplacement):
return [self.empty_fp32(4)]
class ReluFusionPass(VllmFusionPatternMatcherPass):
def __init__(self, config: VllmConfig) -> None:
super().__init__(config, "test_relu_fusion")
......@@ -72,13 +70,13 @@ class TwoPatternFusionPass(VllmFusionPatternMatcherPass):
self.register(ExpToSqrtPattern())
@pytest.fixture
def vllm_config():
return VllmConfig(
compilation_config=CompilationConfig(mode=CompilationMode.VLLM_COMPILE),
)
@pytest.mark.skipif(not current_platform.is_cuda_alike(), reason="Requires CUDA")
def test_register_tracks_patterns(vllm_config):
"""register() appends each VllmPatternReplacement to _pattern_replacements."""
......@@ -96,7 +94,7 @@ def test_uuid_stable(vllm_config):
with vllm.config.set_current_vllm_config(vllm_config):
p1 = ReluFusionPass(vllm_config)
p2 = ReluFusionPass(vllm_config)
p3= TwoPatternFusionPass(vllm_config)
p3 = TwoPatternFusionPass(vllm_config)
assert p1.uuid() == p2.uuid()
assert p1.uuid() != p3.uuid()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment