Unverified Commit 4c47710b authored by Shinichi Hemmi's avatar Shinichi Hemmi Committed by GitHub
Browse files

[CI/Build] Apply ruff formatter to pass pre-commit (#40078)


Signed-off-by: default avatarHemmi Shinichi <shemmi@preferred.jp>
parent bf9a5ddb
...@@ -6,14 +6,13 @@ import torch ...@@ -6,14 +6,13 @@ import torch
import vllm.config import vllm.config
from tests.compile.backend import TestBackend from tests.compile.backend import TestBackend
from vllm.platforms import current_platform
from vllm.compilation.passes.vllm_inductor_pass import ( from vllm.compilation.passes.vllm_inductor_pass import (
VllmFusionPatternMatcherPass, VllmFusionPatternMatcherPass,
VllmPatternMatcherPass, VllmPatternMatcherPass,
VllmPatternReplacement, VllmPatternReplacement,
) )
from vllm.config import CompilationConfig, CompilationMode, VllmConfig from vllm.config import CompilationConfig, CompilationMode, VllmConfig
from vllm.platforms import current_platform
class ReluToAbsPattern(VllmPatternReplacement): class ReluToAbsPattern(VllmPatternReplacement):
...@@ -58,7 +57,6 @@ class ExpToSqrtPattern(VllmPatternReplacement): ...@@ -58,7 +57,6 @@ class ExpToSqrtPattern(VllmPatternReplacement):
return [self.empty_fp32(4)] return [self.empty_fp32(4)]
class ReluFusionPass(VllmFusionPatternMatcherPass): class ReluFusionPass(VllmFusionPatternMatcherPass):
def __init__(self, config: VllmConfig) -> None: def __init__(self, config: VllmConfig) -> None:
super().__init__(config, "test_relu_fusion") super().__init__(config, "test_relu_fusion")
...@@ -72,13 +70,13 @@ class TwoPatternFusionPass(VllmFusionPatternMatcherPass): ...@@ -72,13 +70,13 @@ class TwoPatternFusionPass(VllmFusionPatternMatcherPass):
self.register(ExpToSqrtPattern()) self.register(ExpToSqrtPattern())
@pytest.fixture @pytest.fixture
def vllm_config(): def vllm_config():
return VllmConfig( return VllmConfig(
compilation_config=CompilationConfig(mode=CompilationMode.VLLM_COMPILE), compilation_config=CompilationConfig(mode=CompilationMode.VLLM_COMPILE),
) )
@pytest.mark.skipif(not current_platform.is_cuda_alike(), reason="Requires CUDA") @pytest.mark.skipif(not current_platform.is_cuda_alike(), reason="Requires CUDA")
def test_register_tracks_patterns(vllm_config): def test_register_tracks_patterns(vllm_config):
"""register() appends each VllmPatternReplacement to _pattern_replacements.""" """register() appends each VllmPatternReplacement to _pattern_replacements."""
...@@ -96,7 +94,7 @@ def test_uuid_stable(vllm_config): ...@@ -96,7 +94,7 @@ def test_uuid_stable(vllm_config):
with vllm.config.set_current_vllm_config(vllm_config): with vllm.config.set_current_vllm_config(vllm_config):
p1 = ReluFusionPass(vllm_config) p1 = ReluFusionPass(vllm_config)
p2 = ReluFusionPass(vllm_config) p2 = ReluFusionPass(vllm_config)
p3= TwoPatternFusionPass(vllm_config) p3 = TwoPatternFusionPass(vllm_config)
assert p1.uuid() == p2.uuid() assert p1.uuid() == p2.uuid()
assert p1.uuid() != p3.uuid() assert p1.uuid() != p3.uuid()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment