Commit 7a985548 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.9.0' into v0.9.0-ori

parents 45d3785c dc1440cf
......@@ -7,6 +7,6 @@ triton==3.1.0
pandas
numpy==1.26.4
tabulate
setuptools>=61
setuptools>=77.0.3,<80.0.0
setuptools-scm>=8
vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@f1f6624
......@@ -2,5 +2,8 @@
-r common.txt
# Dependencies for Neuron devices
packaging>=24.2
setuptools>=77.0.3,<80.0.0
torch-neuronx >= 2.5.0
neuronx-cc
neuronx-cc>=2.0.0a0
torchvision # Required for Llama3.2 multimodal image preprocessing
......@@ -8,12 +8,10 @@ pytest-rerunfailures
pytest-shard
pytest-timeout
librosa # required by audio tests in entrypoints/openai
sentence-transformers
numba == 0.61.2; python_version > '3.9'
# testing utils
awscli
boto3
botocore
datasets
......@@ -24,5 +22,20 @@ runai-model-streamer-s3==0.11.0
tensorizer>=2.9.0
lm-eval==0.4.8
buildkite-test-collector==0.1.9
lm-eval[api]==0.4.8 # required for model evaluation test
# required for quantization test
bitsandbytes>=0.45.3
# required for minicpmo_26 test
vector_quantize_pytorch
vocos
# required for Basic Models Test
blobfile # required for kimi-vl test
matplotlib # required for qwen-vl test
# required for Multi-Modal Models Test (Standard)
num2words # required for smolvlm test
pqdm
timm # required for internvl test
......@@ -2,14 +2,14 @@
-r common.txt
--extra-index-url https://download.pytorch.org/whl/rocm6.2.4
torch==2.6.0
torchvision==0.21.0
torchaudio==2.6.0
torch==2.7.0
torchvision==0.22.0
torchaudio==2.7.0
triton==3.2
cmake>=3.26,<4
packaging
setuptools>=61
packaging>=24.2
setuptools>=77.0.3,<80.0.0
setuptools-scm>=8
wheel
jinja2>=3.1.6
......
# Common dependencies
-r common.txt
# entrypoints test
# librosa==0.10.2.post1 # required by audio tests in entrypoints/openai
......@@ -20,4 +22,10 @@ decord==0.6.0
#sentence-transformers # required by entrypoints/openai/test_score.py
sentence-transformers==3.4.1
# Basic Models Test
matplotlib==3.10.3
# Multi-Modal Models Test (Extended) 3
blobfile==3.0.0
......@@ -5,11 +5,10 @@ numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Req
numba == 0.61.2; python_version > '3.9'
# Dependencies for AMD GPUs
awscli
boto3
botocore
datasets
ray >= 2.10.0
ray>=2.10.0,<2.45.0
peft
pytest-asyncio
tensorizer>=2.9.0
......
......@@ -8,7 +8,6 @@ pytest-shard
pytest-timeout
# testing utils
awscli
backoff # required for phi4mm test
blobfile # required for kimi-vl test
einops # required for MPT, qwen-vl and Mamba
......@@ -23,9 +22,9 @@ sentence-transformers # required for embedding tests
soundfile # required for audio tests
jiwer # required for audio tests
timm # required for internvl test
torch==2.6.0
torchaudio==2.6.0
torchvision==0.21.0
torch==2.7.0
torchaudio==2.7.0
torchvision==0.22.0
transformers_stream_generator # required for qwen-vl test
mamba_ssm # required for plamo2 test
matplotlib # required for qwen-vl test
......
# This file was autogenerated by uv via the following command:
# uv pip compile requirements/test.in -o requirements/test.txt
# uv pip compile requirements/test.in -o requirements/test.txt --index-strategy unsafe-best-match --torch-backend cu128
absl-py==2.1.0
# via rouge-score
accelerate==1.0.1
......@@ -37,8 +37,6 @@ attrs==24.2.0
# referencing
audioread==3.0.1
# via librosa
awscli==1.35.23
# via -r requirements/test.in
backoff==2.2.1
# via
# -r requirements/test.in
......@@ -53,7 +51,6 @@ boto3==1.35.57
# via tensorizer
botocore==1.35.57
# via
# awscli
# boto3
# s3transfer
bounded-pool-executor==0.0.3
......@@ -81,7 +78,6 @@ click==8.1.7
# typer
colorama==0.4.6
# via
# awscli
# sacrebleu
# schemathesis
# tqdm-multiprocess
......@@ -115,8 +111,6 @@ dnspython==2.7.0
# via email-validator
docopt==0.6.2
# via num2words
docutils==0.16
# via awscli
einops==0.8.0
# via
# -r requirements/test.in
......@@ -274,7 +268,7 @@ mamba-ssm==2.2.4
# via -r requirements/test.in
markdown-it-py==3.0.0
# via rich
markupsafe==3.0.2
markupsafe==3.0.1
# via
# jinja2
# werkzeug
......@@ -355,45 +349,48 @@ numpy==1.26.4
# transformers
# tritonclient
# vocos
nvidia-cublas-cu12==12.4.5.8
nvidia-cublas-cu12==12.8.3.14
# via
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
# torch
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-cupti-cu12==12.8.57
# via torch
nvidia-cuda-nvrtc-cu12==12.8.61
# via torch
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.8.57
# via torch
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.7.1.26
# via torch
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.3.3.41
# via torch
nvidia-cufft-cu12==11.2.1.3
nvidia-cufile-cu12==1.13.0.11
# via torch
nvidia-curand-cu12==10.3.5.147
nvidia-curand-cu12==10.3.9.55
# via torch
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusolver-cu12==11.7.2.55
# via torch
nvidia-cusparse-cu12==12.3.1.170
nvidia-cusparse-cu12==12.5.7.53
# via
# nvidia-cusolver-cu12
# torch
nvidia-cusparselt-cu12==0.6.2
nvidia-cusparselt-cu12==0.6.3
# via torch
nvidia-nccl-cu12==2.21.5
nvidia-nccl-cu12==2.26.2
# via torch
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvjitlink-cu12==12.8.61
# via
# nvidia-cufft-cu12
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
# torch
nvidia-nvtx-cu12==12.4.127
nvidia-nvtx-cu12==12.8.55
# via torch
opencv-python-headless==4.11.0.86
# via
# -r requirements/test.in
# mistral-common
packaging==24.1
packaging==24.2
# via
# accelerate
# black
......@@ -469,8 +466,6 @@ pyarrow==18.0.0
# via
# datasets
# genai-perf
pyasn1==0.6.1
# via rsa
pybind11==2.13.6
# via lm-eval
pycparser==2.22
......@@ -534,7 +529,6 @@ pytz==2024.2
pyyaml==6.0.2
# via
# accelerate
# awscli
# datamodel-code-generator
# datasets
# genai-perf
......@@ -593,16 +587,12 @@ rpds-py==0.20.1
# via
# jsonschema
# referencing
rsa==4.7.2
# via awscli
runai-model-streamer==0.11.0
# via -r requirements/test.in
runai-model-streamer-s3==0.11.0
# via -r requirements/test.in
s3transfer==0.10.3
# via
# awscli
# boto3
# via boto3
sacrebleu==2.4.3
# via lm-eval
safetensors==0.4.5
......@@ -629,11 +619,12 @@ sentence-transformers==3.2.1
# via -r requirements/test.in
sentencepiece==0.2.0
# via mistral-common
setuptools==75.8.0
setuptools==77.0.3
# via
# mamba-ssm
# pytablewriter
# torch
# triton
shellingham==1.5.4
# via typer
six==1.16.0
......@@ -664,7 +655,7 @@ starlette-testclient==0.4.1
# via schemathesis
statsmodels==0.14.4
# via genai-perf
sympy==1.13.1
sympy==1.13.3
# via
# einx
# torch
......@@ -696,7 +687,7 @@ tomli==2.2.1
# via schemathesis
tomli-w==1.2.0
# via schemathesis
torch==2.6.0
torch==2.7.0+cu128
# via
# -r requirements/test.in
# accelerate
......@@ -714,12 +705,12 @@ torch==2.6.0
# torchvision
# vector-quantize-pytorch
# vocos
torchaudio==2.6.0
torchaudio==2.7.0+cu128
# via
# -r requirements/test.in
# encodec
# vocos
torchvision==0.21.0
torchvision==0.22.0+cu128
# via
# -r requirements/test.in
# timm
......@@ -748,7 +739,7 @@ transformers==4.51.3
# transformers-stream-generator
transformers-stream-generator==0.0.5
# via -r requirements/test.in
triton==3.2.0
triton==3.3.0
# via torch
tritonclient==2.51.0
# via
......
......@@ -3,12 +3,13 @@
# Dependencies for TPU
cmake>=3.26
packaging
packaging>=24.2
setuptools-scm>=8
wheel
jinja2>=3.1.6
ray[default]
ray[data]
setuptools==78.1.0
# Install torch_xla
--pre
......@@ -17,9 +18,9 @@ ray[data]
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch==2.8.0.dev20250408
torchvision==0.22.0.dev20250408
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch==2.8.0.dev20250430
torchvision==0.22.0.dev20250430
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
......@@ -3,14 +3,14 @@
ray>=2.9
cmake>=3.26
packaging
packaging>=24.2
setuptools-scm>=8
setuptools>=75.8.0
setuptools>=77.0.3,<80.0.0
wheel
jinja2>=3.1.6
datasets # for benchmark scripts
torch==2.6.0+xpu
torch==2.7.0+xpu
torchaudio
torchvision
pytorch-triton-xpu
......@@ -18,6 +18,6 @@ pytorch-triton-xpu
# Please refer xpu doc, we need manually install intel-extension-for-pytorch 2.6.10+xpu due to there are some conflict dependencies with torch 2.6.0+xpu
# FIXME: This will be fix in ipex 2.7. just leave this here for awareness.
# intel-extension-for-pytorch==2.6.10+xpu
oneccl_bind_pt==2.6.0+xpu
intel-extension-for-pytorch==2.7.10+xpu
oneccl_bind_pt==2.7.0+xpu
--extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
......@@ -54,7 +54,7 @@ elif (sys.platform.startswith("linux") and torch.version.cuda is None
# fallback to cpu
VLLM_TARGET_DEVICE = "cpu"
MAIN_CUDA_VERSION = "12.4"
MAIN_CUDA_VERSION = "12.8"
def is_sccache_available() -> bool:
......
......@@ -41,7 +41,7 @@ class MockEngine:
self.abort_request_calls = 0
self.request_id = None
# Ugly, remove dependency when possible
self.parallel_config = ParallelConfig(1, 1, False)
self.parallel_config = ParallelConfig()
self.model_config = MockModelConfig()
async def step_async(self, virtual_engine):
......
......@@ -5,11 +5,13 @@ Run `pytest tests/basic_correctness/test_basic_correctness.py`.
"""
import os
import weakref
from unittest.mock import Mock
import pytest
from vllm import LLM
from vllm.platforms import current_platform
from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1
from ..conftest import VllmRunner
from ..models.utils import check_outputs_equal
......@@ -152,9 +154,44 @@ def test_models_distributed(
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
check_outputs_equal(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
check_outputs_equal(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)
def test_failed_model_execution(vllm_runner, monkeypatch) -> None:
from vllm.envs import VLLM_USE_V1
if not VLLM_USE_V1:
pytest.skip("Skipping V0 test, dump input not supported")
# Needed to mock an error in the same process
monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0')
with vllm_runner('facebook/opt-125m', enforce_eager=True) as vllm_model:
if isinstance(vllm_model.model.llm_engine, LLMEngineV1):
v1_test_failed_model_execution(vllm_model)
def v1_test_failed_model_execution(vllm_model):
engine = vllm_model.model.llm_engine
mocked_execute_model = Mock(
side_effect=RuntimeError("Mocked Critical Error"))
engine.engine_core.engine_core.model_executor.execute_model =\
mocked_execute_model
with pytest.raises(RuntimeError) as exc_info:
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
vllm_model.generate_greedy(prompts, 200, use_tqdm=False)
assert isinstance(exc_info.value, RuntimeError)
assert "Mocked Critical Error" in str(exc_info.value)
# SPDX-License-Identifier: Apache-2.0
import contextlib
import os
import pytest
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
MODEL = "Qwen/Qwen2-1.5B-Instruct"
@contextlib.contextmanager
def temporary_environ(env_vars):
"""
Temporarily set environment variables and restore them afterward.
We have to do this vs monkeypatch because monkeypatch doesn't work
with "module" scoped fixtures.
"""
original_env = {k: os.environ.get(k) for k in env_vars}
try:
os.environ.update(env_vars)
yield
finally:
for k, v in original_env.items():
if v is None:
os.environ.pop(k, None)
else:
os.environ[k] = v
@pytest.fixture(scope="module")
def full_cudagraph_llm():
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_FLASH_ATTN_VERSION": "3"
}):
return LLM(model=MODEL,
gpu_memory_utilization=0.2,
compilation_config=CompilationConfig(full_cuda_graph=True))
@pytest.fixture(scope="module")
def piecewise_llm():
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_FLASH_ATTN_VERSION": "3"
}):
return LLM(model=MODEL,
gpu_memory_utilization=0.5,
compilation_config=CompilationConfig())
def generate_text(llm: LLM, batch_size: int, max_tokens: int):
prompts = ["Hi my name is"] * batch_size
sampling_params = SamplingParams(temperature=0.0,
max_tokens=max_tokens,
top_p=0.95)
return llm.generate(prompts, sampling_params)
@pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10),
(16, 10), (25, 10),
(32, 10), (45, 10),
(64, 10), (8, 5),
(8, 20), (8, 200)])
def test_full_cudagraph(batch_size, max_tokens, full_cudagraph_llm,
piecewise_llm):
"""
Load full cudagraph model and piecewise model once, and at the same time to
reuse them across various test cases.
Test various batch sizes and max_tokens to ensure that the full cudagraph
compilation works for padded cases too.
"""
piecewise_responses = generate_text(piecewise_llm,
batch_size=batch_size,
max_tokens=max_tokens)
full_cudagraph_responses = generate_text(full_cudagraph_llm,
batch_size=batch_size,
max_tokens=max_tokens)
# Check that all responses are the same
for i in range(len(piecewise_responses)):
assert piecewise_responses[i].outputs[
0].text == full_cudagraph_responses[i].outputs[0].text
def test_full_cudagraph_with_invalid_backend():
with temporary_environ({
"VLLM_USE_V1": "1",
"VLLM_FLASH_ATTN_VERSION":
"2" #FA2 not supported with full_cuda_graph
}), pytest.raises(RuntimeError):
LLM(model=MODEL,
compilation_config=CompilationConfig(full_cuda_graph=True))
......@@ -103,7 +103,8 @@ def test_compile_correctness(
method = test_setting.method
fullgraph = test_setting.fullgraph
if cuda_device_count_stateless() != pp_size * tp_size:
pytest.skip("Not correct CUDA devices for the test.")
pytest.skip(f"Need exactly {pp_size}*{tp_size} CUDA gpus but got "
f"{cuda_device_count_stateless()}")
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
......
......@@ -9,7 +9,7 @@ import torch
from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel
from vllm.config import CompilationConfig, CompilationLevel, PassConfig
from vllm.platforms import current_platform
from ..utils import create_new_process_for_each_test
......@@ -95,9 +95,6 @@ def test_full_graph(
run_model(optimization_level, model, model_kwargs)
PassConfig = CompilationConfig.PassConfig
# TODO(luka) add other supported compilation config scenarios here
@pytest.mark.parametrize(
"compilation_config, model_info",
......
......@@ -5,19 +5,19 @@ import torch
import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey,
kFp8DynamicTokenSym, kFp8StaticTensorSym)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, VllmConfig
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from .backend import TestBackend
OPS_IN_MODEL = [
torch.ops._C.rotary_embedding.default,
torch.ops._C.fused_add_rms_norm.default,
torch.ops._C.silu_and_mul.default,
]
RMS_OP = torch.ops._C.rms_norm.default
......@@ -29,6 +29,9 @@ RMS_QUANT_OPS = {
],
}
SILU_MUL_OP = torch.ops._C.silu_and_mul.default
SILU_MUL_QUANT_OP = torch.ops._C.silu_and_mul_quant.default
prompts = [
"Hello, my name is",
"The president of the United States is",
......@@ -50,13 +53,14 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
torch.set_default_device("cuda")
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(pass_config= \
CompilationConfig.PassConfig(enable_fusion=do_fusion,
enable_noop=True))
vllm_config.compilation_config = CompilationConfig(
pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True))
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = FusionPass.instance(vllm_config)
act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass]
passes = [noop_pass, fusion_pass, act_quant_fusion_pass
] if do_fusion else [noop_pass]
func_pass = FixFunctionalizationPass(vllm_config)
backend_func = TestBackend(*passes, func_pass)
backend_no_func = TestBackend(*passes)
......@@ -79,6 +83,7 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
model_runner.model = torch.compile(orig_model,
fullgraph=True,
backend=backend_no_func)
gen_no_func = llm.generate(prompts, sampling_params)
for output_func, output_no_func in zip(gen_func, gen_no_func):
......@@ -88,7 +93,12 @@ def test_fix_functionalization(model: str, quant_key: QuantKey,
# and replaced by fused quantized ops in RMS_QUANT_OPS.
rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)]
] if do_fusion else [RMS_OP]
ops = OPS_IN_MODEL + rms_ops
silu_mul_ops = [SILU_MUL_QUANT_OP] if do_fusion and \
quant_key == kFp8StaticTensorSym else [
SILU_MUL_OP
]
ops = OPS_IN_MODEL + rms_ops + silu_mul_ops
for op in ops:
find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
......
......@@ -9,7 +9,8 @@ from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
FusionPass, QuantKey)
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
VllmConfig)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
......@@ -78,8 +79,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
vllm_config.compilation_config.pass_config = \
CompilationConfig.PassConfig(enable_fusion=True,
enable_noop=True)
PassConfig(enable_fusion=True, enable_noop=True)
with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config)
......
......@@ -22,7 +22,7 @@ def test_bad_callable():
pass_manager.configure(config)
with pytest.raises(AssertionError):
pass_manager.add(simple_callable) # noqa, type wrong on purpose
pass_manager.add(simple_callable)
# Pass that inherits from InductorPass
......
......@@ -10,7 +10,7 @@ from vllm.compilation.fx_utils import (find_auto_fn, find_auto_fn_maybe,
find_specified_fn_maybe, is_func)
from vllm.compilation.sequence_parallelism import SequenceParallelismPass
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
VllmConfig)
PassConfig, VllmConfig)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
......@@ -126,9 +126,8 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int,
# configure vllm config for SequenceParallelismPass
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(
pass_config=CompilationConfig.PassConfig(
enable_sequence_parallelism=True, ), )
vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig(
enable_sequence_parallelism=True))
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
# this is a fake model name to construct the model config
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment