"docs/vscode:/vscode.git/clone" did not exist on "02f0c7b220422792f5e53de2a7d51d2d3ff2df28"
Unverified Commit 715681c1 authored by Jee Jee Li's avatar Jee Jee Li Committed by GitHub
Browse files

[LoRA] Support dual CUDA streams-Linear Layer (#35721)


Signed-off-by: default avatarJee Jee Li <pandaleefree@gmail.com>
parent dc02271d
...@@ -43,6 +43,13 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool): ...@@ -43,6 +43,13 @@ def cleanup_fixture(should_do_global_cleanup_after_test: bool):
cleanup_dist_env_and_memory(shutdown_ray=True) cleanup_dist_env_and_memory(shutdown_ray=True)
@pytest.fixture
def maybe_enable_lora_dual_stream(monkeypatch: pytest.MonkeyPatch):
if current_platform.is_cuda():
monkeypatch.setenv("VLLM_LORA_ENABLE_DUAL_STREAM", "1")
yield
@pytest.fixture @pytest.fixture
def dist_init(): def dist_init():
from tests.utils import ensure_current_vllm_config from tests.utils import ensure_current_vllm_config
......
...@@ -521,8 +521,10 @@ def test_linear_replicated( ...@@ -521,8 +521,10 @@ def test_linear_replicated(
punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config) punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
assert check_punica_wrapper(punica_wrapper) assert check_punica_wrapper(punica_wrapper)
def create_random_linear_replicated_layer(): def create_random_linear_replicated_layer(idx: int = 0):
linear = ReplicatedLinear(4096, 4096, bias=False, params_dtype=torch.float16) linear = ReplicatedLinear(
4096, 4096, bias=False, params_dtype=torch.float16, prefix=f"layer_{idx}"
)
linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = ReplicatedLinearWithLoRA(linear) lora_linear = ReplicatedLinearWithLoRA(linear)
...@@ -539,7 +541,7 @@ def test_linear_replicated( ...@@ -539,7 +541,7 @@ def test_linear_replicated(
set_random_seed(i) set_random_seed(i)
id_to_index = get_random_id_to_index(num_loras, max_loras) id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, lora_linear = create_random_linear_replicated_layer() linear, lora_linear = create_random_linear_replicated_layer(i)
assert torch.equal(linear.weight, lora_linear.weight) assert torch.equal(linear.weight, lora_linear.weight)
lora_linear.set_mapping(punica_wrapper) lora_linear.set_mapping(punica_wrapper)
lora_dict, _ = populate_loras( lora_dict, _ = populate_loras(
...@@ -629,10 +631,14 @@ def test_linear_parallel( ...@@ -629,10 +631,14 @@ def test_linear_parallel(
punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config) punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
assert check_punica_wrapper(punica_wrapper) assert check_punica_wrapper(punica_wrapper)
def create_random_linear_parallel_layer(): def create_random_linear_parallel_layer(idx: int = 0):
if orientation == "row": if orientation == "row":
linear = RowParallelLinear( linear = RowParallelLinear(
4096, 4096, bias=False, params_dtype=torch.float16 4096,
4096,
bias=False,
params_dtype=torch.float16,
prefix=f"layer_{idx}",
) )
linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = ( lora_linear = (
...@@ -642,7 +648,11 @@ def test_linear_parallel( ...@@ -642,7 +648,11 @@ def test_linear_parallel(
) )
else: else:
linear = ColumnParallelLinear( linear = ColumnParallelLinear(
4096, 4096, bias=False, params_dtype=torch.float16 4096,
4096,
bias=False,
params_dtype=torch.float16,
prefix=f"layer_{idx}",
) )
linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = ( lora_linear = (
...@@ -664,7 +674,7 @@ def test_linear_parallel( ...@@ -664,7 +674,7 @@ def test_linear_parallel(
set_random_seed(i) set_random_seed(i)
id_to_index = get_random_id_to_index(num_loras, max_loras) id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, lora_linear = create_random_linear_parallel_layer() linear, lora_linear = create_random_linear_parallel_layer(i)
assert torch.equal(linear.weight, lora_linear.weight) assert torch.equal(linear.weight, lora_linear.weight)
lora_linear.set_mapping(punica_wrapper) lora_linear.set_mapping(punica_wrapper)
lora_dict, _ = populate_loras( lora_dict, _ = populate_loras(
...@@ -754,10 +764,14 @@ def test_column_parallel_packed( ...@@ -754,10 +764,14 @@ def test_column_parallel_packed(
punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config) punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
assert check_punica_wrapper(punica_wrapper) assert check_punica_wrapper(punica_wrapper)
def create_column_parallel_packed_layer(): def create_column_parallel_packed_layer(idx: int = 0):
if repeats == 2: if repeats == 2:
linear = MergedColumnParallelLinear( linear = MergedColumnParallelLinear(
4096, [4096] * repeats, bias=False, params_dtype=torch.float16 4096,
[4096] * repeats,
bias=False,
params_dtype=torch.float16,
prefix=f"layer_{idx}",
) )
linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = ( lora_linear = (
...@@ -767,7 +781,12 @@ def test_column_parallel_packed( ...@@ -767,7 +781,12 @@ def test_column_parallel_packed(
) )
elif repeats == 3: elif repeats == 3:
linear = QKVParallelLinear( linear = QKVParallelLinear(
4096, 64, 32, bias=False, params_dtype=torch.float16 4096,
64,
32,
bias=False,
params_dtype=torch.float16,
prefix=f"layer_{idx}",
) )
linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = ( lora_linear = (
...@@ -777,7 +796,12 @@ def test_column_parallel_packed( ...@@ -777,7 +796,12 @@ def test_column_parallel_packed(
) )
else: else:
linear = QKVParallelLinear( linear = QKVParallelLinear(
4096, 64, 32, bias=False, params_dtype=torch.float16 4096,
64,
32,
bias=False,
params_dtype=torch.float16,
prefix=f"layer_{idx}",
) )
linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data = torch.rand_like(linear.weight.data)
lora_linear = ( lora_linear = (
...@@ -810,7 +834,7 @@ def test_column_parallel_packed( ...@@ -810,7 +834,7 @@ def test_column_parallel_packed(
id_to_index = get_random_id_to_index(num_loras, max_loras) id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, lora_linear = create_column_parallel_packed_layer() linear, lora_linear = create_column_parallel_packed_layer(i)
assert torch.equal(linear.weight, lora_linear.weight) assert torch.equal(linear.weight, lora_linear.weight)
lora_linear.set_mapping(punica_wrapper) lora_linear.set_mapping(punica_wrapper)
lora_dict, sublora_dict = populate_loras( lora_dict, sublora_dict = populate_loras(
...@@ -902,10 +926,14 @@ def test_merged_column_parallel_variable_slice( ...@@ -902,10 +926,14 @@ def test_merged_column_parallel_variable_slice(
output_sizes = [1024 + i * 256 for i in range(num_slices)] output_sizes = [1024 + i * 256 for i in range(num_slices)]
total_output = sum(output_sizes) total_output = sum(output_sizes)
def create_layer(): def create_layer(idx: int = 0):
# Create linear layer # Create linear layer
linear = MergedColumnParallelLinear( linear = MergedColumnParallelLinear(
4096, output_sizes, bias=False, params_dtype=torch.float16 4096,
output_sizes,
bias=False,
params_dtype=torch.float16,
prefix=f"layer_{idx}",
) )
linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data = torch.rand_like(linear.weight.data)
...@@ -917,7 +945,7 @@ def test_merged_column_parallel_variable_slice( ...@@ -917,7 +945,7 @@ def test_merged_column_parallel_variable_slice(
for i in range(NUM_RANDOM_SEEDS): for i in range(NUM_RANDOM_SEEDS):
set_random_seed(i) set_random_seed(i)
id_to_index = get_random_id_to_index(num_loras, max_loras) id_to_index = get_random_id_to_index(num_loras, max_loras)
linear, lora_linear = create_layer() linear, lora_linear = create_layer(i)
lora_linear.set_mapping(punica_wrapper) lora_linear.set_mapping(punica_wrapper)
# Populate LoRA weights # Populate LoRA weights
......
...@@ -110,7 +110,7 @@ def generate_and_test( ...@@ -110,7 +110,7 @@ def generate_and_test(
) )
def test_olmoe_lora(olmoe_lora_files): def test_olmoe_lora(olmoe_lora_files, maybe_enable_lora_dual_stream):
# We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
# Otherwise, the lora-test will fail due to CUDA OOM. # Otherwise, the lora-test will fail due to CUDA OOM.
llm = vllm.LLM( llm = vllm.LLM(
...@@ -141,7 +141,9 @@ def test_olmoe_lora_mixed(olmoe_lora_files): ...@@ -141,7 +141,9 @@ def test_olmoe_lora_mixed(olmoe_lora_files):
generate_and_test(llm, olmoe_lora_files, lora_id=[1, None, 3, None]) generate_and_test(llm, olmoe_lora_files, lora_id=[1, None, 3, None])
def test_olmoe_lora_mixed_random(olmoe_lora_files, tmp_path): def test_olmoe_lora_mixed_random(
olmoe_lora_files, tmp_path, maybe_enable_lora_dual_stream
):
# Create a dummy LoRA with random weights based on the real one # Create a dummy LoRA with random weights based on the real one
random_lora_path = tmp_path / "random_lora" random_lora_path = tmp_path / "random_lora"
shutil.copytree(olmoe_lora_files, random_lora_path) shutil.copytree(olmoe_lora_files, random_lora_path)
......
...@@ -312,7 +312,9 @@ def _assert_qwen35_text_vl_and_mixed_lora( ...@@ -312,7 +312,9 @@ def _assert_qwen35_text_vl_and_mixed_lora(
@create_new_process_for_each_test() @create_new_process_for_each_test()
def test_qwen35_text_lora(qwen35_text_lora_files, qwen35_vl_lora_files): def test_qwen35_text_lora(
qwen35_text_lora_files, qwen35_vl_lora_files, maybe_enable_lora_dual_stream
):
llm = vllm.LLM( llm = vllm.LLM(
model=MODEL_PATH, model=MODEL_PATH,
max_model_len=4096, max_model_len=4096,
...@@ -335,7 +337,9 @@ def test_qwen35_text_lora(qwen35_text_lora_files, qwen35_vl_lora_files): ...@@ -335,7 +337,9 @@ def test_qwen35_text_lora(qwen35_text_lora_files, qwen35_vl_lora_files):
@multi_gpu_test(num_gpus=4) @multi_gpu_test(num_gpus=4)
def test_qwen35_text_lora_tp4(qwen35_text_lora_files, qwen35_vl_lora_files): def test_qwen35_text_lora_tp4(
qwen35_text_lora_files, qwen35_vl_lora_files, maybe_enable_lora_dual_stream
):
llm = vllm.LLM( llm = vllm.LLM(
model=MODEL_PATH, model=MODEL_PATH,
max_model_len=4096, max_model_len=4096,
......
...@@ -7,8 +7,10 @@ import torch ...@@ -7,8 +7,10 @@ import torch
from pydantic import ConfigDict, Field, model_validator from pydantic import ConfigDict, Field, model_validator
from typing_extensions import Self from typing_extensions import Self
from vllm import envs
from vllm.config.utils import config from vllm.config.utils import config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils.hashing import safe_hash from vllm.utils.hashing import safe_hash
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -105,7 +107,14 @@ class LoRAConfig: ...@@ -105,7 +107,14 @@ class LoRAConfig:
f"max_cpu_loras ({self.max_cpu_loras}) must be >= " f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
f"max_loras ({self.max_loras})." f"max_loras ({self.max_loras})."
) )
if envs.VLLM_LORA_ENABLE_DUAL_STREAM and not current_platform.is_cuda_alike():
raise ValueError("Dual CUDA streams are only supported on CUDA platforms.")
if envs.VLLM_LORA_ENABLE_DUAL_STREAM and self.fully_sharded_loras:
logger.warning_once(
"fully_sharded_loras isn't compatible with "
"VLLM_LORA_ENABLE_DUAL_STREAM, set VLLM_LORA_ENABLE_DUAL_STREAM=False"
)
envs.VLLM_LORA_ENABLE_DUAL_STREAM = False
return self return self
def verify_with_model_config(self, model_config: ModelConfig): def verify_with_model_config(self, model_config: ModelConfig):
......
...@@ -258,6 +258,7 @@ if TYPE_CHECKING: ...@@ -258,6 +258,7 @@ if TYPE_CHECKING:
VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS: bool = False VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS: bool = False
VLLM_NIXL_EP_MAX_NUM_RANKS: int = 32 VLLM_NIXL_EP_MAX_NUM_RANKS: int = 32
VLLM_XPU_ENABLE_XPU_GRAPH: bool = False VLLM_XPU_ENABLE_XPU_GRAPH: bool = False
VLLM_LORA_ENABLE_DUAL_STREAM: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -496,8 +497,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -496,8 +497,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
# rocm, cpu] # rocm, cpu]
"VLLM_TARGET_DEVICE": lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda").lower(), "VLLM_TARGET_DEVICE": lambda: os.getenv("VLLM_TARGET_DEVICE", "cuda").lower(),
# Main CUDA version of vLLM. This follows PyTorch but can be overridden. # Main CUDA version of vLLM. This follows PyTorch but can be overridden.
"VLLM_MAIN_CUDA_VERSION": lambda: os.getenv("VLLM_MAIN_CUDA_VERSION", "").lower() "VLLM_MAIN_CUDA_VERSION": lambda: (
or "12.9", os.getenv("VLLM_MAIN_CUDA_VERSION", "").lower() or "12.9"
),
# Controls PyTorch float32 matmul precision mode within vLLM workers. # Controls PyTorch float32 matmul precision mode within vLLM workers.
# Valid options mirror torch.set_float32_matmul_precision # Valid options mirror torch.set_float32_matmul_precision
"VLLM_FLOAT32_MATMUL_PRECISION": env_with_choices( "VLLM_FLOAT32_MATMUL_PRECISION": env_with_choices(
...@@ -517,21 +519,19 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -517,21 +519,19 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set, `MAX_JOBS` will be reduced to avoid oversubscribing the CPU. # If set, `MAX_JOBS` will be reduced to avoid oversubscribing the CPU.
"NVCC_THREADS": lambda: os.getenv("NVCC_THREADS", None), "NVCC_THREADS": lambda: os.getenv("NVCC_THREADS", None),
# If set, vllm will use precompiled binaries (*.so) # If set, vllm will use precompiled binaries (*.so)
"VLLM_USE_PRECOMPILED": lambda: os.environ.get("VLLM_USE_PRECOMPILED", "") "VLLM_USE_PRECOMPILED": lambda: (
.strip() os.environ.get("VLLM_USE_PRECOMPILED", "").strip().lower() in ("1", "true")
.lower() or bool(os.environ.get("VLLM_PRECOMPILED_WHEEL_LOCATION"))
in ("1", "true") ),
or bool(os.environ.get("VLLM_PRECOMPILED_WHEEL_LOCATION")),
# If set, skip adding +precompiled suffix to version string # If set, skip adding +precompiled suffix to version string
"VLLM_SKIP_PRECOMPILED_VERSION_SUFFIX": lambda: bool( "VLLM_SKIP_PRECOMPILED_VERSION_SUFFIX": lambda: bool(
int(os.environ.get("VLLM_SKIP_PRECOMPILED_VERSION_SUFFIX", "0")) int(os.environ.get("VLLM_SKIP_PRECOMPILED_VERSION_SUFFIX", "0"))
), ),
# Used to mark that setup.py is running in a Docker build context, # Used to mark that setup.py is running in a Docker build context,
# in order to force the use of precompiled binaries. # in order to force the use of precompiled binaries.
"VLLM_DOCKER_BUILD_CONTEXT": lambda: os.environ.get("VLLM_DOCKER_BUILD_CONTEXT", "") "VLLM_DOCKER_BUILD_CONTEXT": lambda: (
.strip() os.environ.get("VLLM_DOCKER_BUILD_CONTEXT", "").strip().lower() in ("1", "true")
.lower() ),
in ("1", "true"),
# CMake build type # CMake build type
# If not set, defaults to "Debug" or "RelWithDebInfo" # If not set, defaults to "Debug" or "RelWithDebInfo"
# Available options: "Debug", "Release", "RelWithDebInfo" # Available options: "Debug", "Release", "RelWithDebInfo"
...@@ -577,10 +577,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -577,10 +577,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
), ),
# If true, will load models from ModelScope instead of Hugging Face Hub. # If true, will load models from ModelScope instead of Hugging Face Hub.
# note that the value is true or false, not numbers # note that the value is true or false, not numbers
"VLLM_USE_MODELSCOPE": lambda: os.environ.get( "VLLM_USE_MODELSCOPE": lambda: (
"VLLM_USE_MODELSCOPE", "False" os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true"
).lower() ),
== "true",
# Interval in seconds to log a warning message when the ring buffer is full # Interval in seconds to log a warning message when the ring buffer is full
"VLLM_RINGBUFFER_WARNING_INTERVAL": lambda: int( "VLLM_RINGBUFFER_WARNING_INTERVAL": lambda: int(
os.environ.get("VLLM_RINGBUFFER_WARNING_INTERVAL", "60") os.environ.get("VLLM_RINGBUFFER_WARNING_INTERVAL", "60")
...@@ -601,19 +600,17 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -601,19 +600,17 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Feature flag to enable/disable Inductor standalone compile. # Feature flag to enable/disable Inductor standalone compile.
# In torch <= 2.7 we ignore this flag; in torch >= 2.9 this is # In torch <= 2.7 we ignore this flag; in torch >= 2.9 this is
# enabled by default. # enabled by default.
"VLLM_USE_STANDALONE_COMPILE": lambda: os.environ.get( "VLLM_USE_STANDALONE_COMPILE": lambda: (
"VLLM_USE_STANDALONE_COMPILE", "1" os.environ.get("VLLM_USE_STANDALONE_COMPILE", "1") == "1"
) ),
== "1",
# Inductor's pre-grad passes don't do anything for vLLM. # Inductor's pre-grad passes don't do anything for vLLM.
# The pre-grad passes get run even on cache-hit and negatively impact # The pre-grad passes get run even on cache-hit and negatively impact
# vllm cold compile times by O(1s) # vllm cold compile times by O(1s)
# Can remove this after the following issue gets fixed # Can remove this after the following issue gets fixed
# https://github.com/pytorch/pytorch/issues/174502 # https://github.com/pytorch/pytorch/issues/174502
"VLLM_ENABLE_PREGRAD_PASSES": lambda: os.environ.get( "VLLM_ENABLE_PREGRAD_PASSES": lambda: (
"VLLM_ENABLE_PREGRAD_PASSES", "0" os.environ.get("VLLM_ENABLE_PREGRAD_PASSES", "0") == "1"
) ),
== "1",
# Debug pattern matching inside custom passes. # Debug pattern matching inside custom passes.
# Should be set to the fx.Node name (e.g. 'getitem_34' or 'scaled_mm_3'). # Should be set to the fx.Node name (e.g. 'getitem_34' or 'scaled_mm_3').
"VLLM_PATTERN_MATCH_DEBUG": lambda: os.environ.get( "VLLM_PATTERN_MATCH_DEBUG": lambda: os.environ.get(
...@@ -656,10 +653,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -656,10 +653,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
# API key for vLLM API server # API key for vLLM API server
"VLLM_API_KEY": lambda: os.environ.get("VLLM_API_KEY", None), "VLLM_API_KEY": lambda: os.environ.get("VLLM_API_KEY", None),
# Whether to log responses from API Server for debugging # Whether to log responses from API Server for debugging
"VLLM_DEBUG_LOG_API_SERVER_RESPONSE": lambda: os.environ.get( "VLLM_DEBUG_LOG_API_SERVER_RESPONSE": lambda: (
"VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False" os.environ.get("VLLM_DEBUG_LOG_API_SERVER_RESPONSE", "False").lower() == "true"
).lower() ),
== "true",
# S3 access information, used for tensorizer to load model from S3 # S3 access information, used for tensorizer to load model from S3
"S3_ACCESS_KEY_ID": lambda: os.environ.get("S3_ACCESS_KEY_ID", None), "S3_ACCESS_KEY_ID": lambda: os.environ.get("S3_ACCESS_KEY_ID", None),
"S3_SECRET_ACCESS_KEY": lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None), "S3_SECRET_ACCESS_KEY": lambda: os.environ.get("S3_SECRET_ACCESS_KEY", None),
...@@ -670,11 +666,13 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -670,11 +666,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
), ),
"VLLM_NO_USAGE_STATS": lambda: os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1", "VLLM_NO_USAGE_STATS": lambda: os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1",
"VLLM_DO_NOT_TRACK": lambda: ( "VLLM_DO_NOT_TRACK": lambda: (
os.environ.get("VLLM_DO_NOT_TRACK", None) (
or os.environ.get("DO_NOT_TRACK", None) os.environ.get("VLLM_DO_NOT_TRACK", None)
or "0" or os.environ.get("DO_NOT_TRACK", None)
) or "0"
== "1", )
== "1"
),
"VLLM_USAGE_SOURCE": lambda: os.environ.get("VLLM_USAGE_SOURCE", "production"), "VLLM_USAGE_SOURCE": lambda: os.environ.get("VLLM_USAGE_SOURCE", "production"),
# Logging configuration # Logging configuration
# If set to 0, vllm will not configure logging # If set to 0, vllm will not configure logging
...@@ -697,36 +695,40 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -697,36 +695,40 @@ environment_variables: dict[str, Callable[[], Any]] = {
"NO_COLOR": lambda: os.getenv("NO_COLOR", "0") != "0", "NO_COLOR": lambda: os.getenv("NO_COLOR", "0") != "0",
# If set, vllm will log stats at this interval in seconds # If set, vllm will log stats at this interval in seconds
# If not set, vllm will log stats every 10 seconds. # If not set, vllm will log stats every 10 seconds.
"VLLM_LOG_STATS_INTERVAL": lambda: val "VLLM_LOG_STATS_INTERVAL": lambda: (
if (val := float(os.getenv("VLLM_LOG_STATS_INTERVAL", "10."))) > 0.0 val
else 10.0, if (val := float(os.getenv("VLLM_LOG_STATS_INTERVAL", "10."))) > 0.0
else 10.0
),
# Trace function calls # Trace function calls
# If set to 1, vllm will trace function calls # If set to 1, vllm will trace function calls
# Useful for debugging # Useful for debugging
"VLLM_TRACE_FUNCTION": lambda: int(os.getenv("VLLM_TRACE_FUNCTION", "0")), "VLLM_TRACE_FUNCTION": lambda: int(os.getenv("VLLM_TRACE_FUNCTION", "0")),
# If set, vllm will use flashinfer sampler # If set, vllm will use flashinfer sampler
"VLLM_USE_FLASHINFER_SAMPLER": lambda: bool( "VLLM_USE_FLASHINFER_SAMPLER": lambda: (
int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]) bool(int(os.environ["VLLM_USE_FLASHINFER_SAMPLER"]))
) if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ
if "VLLM_USE_FLASHINFER_SAMPLER" in os.environ else None
else None, ),
# Pipeline stage partition strategy # Pipeline stage partition strategy
"VLLM_PP_LAYER_PARTITION": lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None), "VLLM_PP_LAYER_PARTITION": lambda: os.getenv("VLLM_PP_LAYER_PARTITION", None),
# (CPU backend only) CPU key-value cache space. # (CPU backend only) CPU key-value cache space.
# default is None and will be set as 4 GB # default is None and will be set as 4 GB
"VLLM_CPU_KVCACHE_SPACE": lambda: int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0")) "VLLM_CPU_KVCACHE_SPACE": lambda: (
if "VLLM_CPU_KVCACHE_SPACE" in os.environ int(os.getenv("VLLM_CPU_KVCACHE_SPACE", "0"))
else None, if "VLLM_CPU_KVCACHE_SPACE" in os.environ
else None
),
# (CPU backend only) CPU core ids bound by OpenMP threads, e.g., "0-31", # (CPU backend only) CPU core ids bound by OpenMP threads, e.g., "0-31",
# "0,1,2", "0-31,33". CPU cores of different ranks are separated by '|'. # "0,1,2", "0-31,33". CPU cores of different ranks are separated by '|'.
"VLLM_CPU_OMP_THREADS_BIND": lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "auto"), "VLLM_CPU_OMP_THREADS_BIND": lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "auto"),
# (CPU backend only) CPU cores not used by OMP threads . # (CPU backend only) CPU cores not used by OMP threads .
# Those CPU cores will not be used by OMP threads of a rank. # Those CPU cores will not be used by OMP threads of a rank.
"VLLM_CPU_NUM_OF_RESERVED_CPU": lambda: int( "VLLM_CPU_NUM_OF_RESERVED_CPU": lambda: (
os.getenv("VLLM_CPU_NUM_OF_RESERVED_CPU", "0") int(os.getenv("VLLM_CPU_NUM_OF_RESERVED_CPU", "0"))
) if "VLLM_CPU_NUM_OF_RESERVED_CPU" in os.environ
if "VLLM_CPU_NUM_OF_RESERVED_CPU" in os.environ else None
else None, ),
# (CPU backend only) whether to use SGL kernels, optimized for small batch. # (CPU backend only) whether to use SGL kernels, optimized for small batch.
"VLLM_CPU_SGL_KERNEL": lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))), "VLLM_CPU_SGL_KERNEL": lambda: bool(int(os.getenv("VLLM_CPU_SGL_KERNEL", "0"))),
# (CPU backend only) whether to enable attention spilt KV. # (CPU backend only) whether to enable attention spilt KV.
...@@ -920,9 +922,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -920,9 +922,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# a list of plugin names to load, separated by commas. # a list of plugin names to load, separated by commas.
# if this is not set, it means all plugins will be loaded # if this is not set, it means all plugins will be loaded
# if this is set to an empty string, no plugins will be loaded # if this is set to an empty string, no plugins will be loaded
"VLLM_PLUGINS": lambda: None "VLLM_PLUGINS": lambda: (
if "VLLM_PLUGINS" not in os.environ None
else os.environ["VLLM_PLUGINS"].split(","), if "VLLM_PLUGINS" not in os.environ
else os.environ["VLLM_PLUGINS"].split(",")
),
# a local directory to look in for unrecognized LoRA adapters. # a local directory to look in for unrecognized LoRA adapters.
# only works if plugins are enabled and # only works if plugins are enabled and
# VLLM_ALLOW_RUNTIME_LORA_UPDATING is enabled. # VLLM_ALLOW_RUNTIME_LORA_UPDATING is enabled.
...@@ -954,9 +958,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -954,9 +958,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# and performance comparisons. Currently only affects MPLinearKernel # and performance comparisons. Currently only affects MPLinearKernel
# selection # selection
# (kernels: MacheteLinearKernel, MarlinLinearKernel, ExllamaLinearKernel) # (kernels: MacheteLinearKernel, MarlinLinearKernel, ExllamaLinearKernel)
"VLLM_DISABLED_KERNELS": lambda: [] "VLLM_DISABLED_KERNELS": lambda: (
if "VLLM_DISABLED_KERNELS" not in os.environ []
else os.environ["VLLM_DISABLED_KERNELS"].split(","), if "VLLM_DISABLED_KERNELS" not in os.environ
else os.environ["VLLM_DISABLED_KERNELS"].split(",")
),
"VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE": lambda: bool( "VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE": lambda: bool(
int(os.getenv("VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE", "1")) int(os.getenv("VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE", "1"))
), ),
...@@ -1147,10 +1153,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1147,10 +1153,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
int(os.getenv("VLLM_ENABLE_MOE_DP_CHUNK", "1")) int(os.getenv("VLLM_ENABLE_MOE_DP_CHUNK", "1"))
), ),
# Randomize inputs during dummy runs when using Data Parallel # Randomize inputs during dummy runs when using Data Parallel
"VLLM_RANDOMIZE_DP_DUMMY_INPUTS": lambda: os.environ.get( "VLLM_RANDOMIZE_DP_DUMMY_INPUTS": lambda: (
"VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0" os.environ.get("VLLM_RANDOMIZE_DP_DUMMY_INPUTS", "0") == "1"
) ),
== "1",
# Strategy to pack the data parallel ranks for Ray. # Strategy to pack the data parallel ranks for Ray.
# Available options: # Available options:
# - "fill": # - "fill":
...@@ -1190,10 +1195,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1190,10 +1195,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MODEL_REDIRECT_PATH", None "VLLM_MODEL_REDIRECT_PATH", None
), ),
# Whether to use atomicAdd reduce in gptq/awq marlin kernel. # Whether to use atomicAdd reduce in gptq/awq marlin kernel.
"VLLM_MARLIN_USE_ATOMIC_ADD": lambda: os.environ.get( "VLLM_MARLIN_USE_ATOMIC_ADD": lambda: (
"VLLM_MARLIN_USE_ATOMIC_ADD", "0" os.environ.get("VLLM_MARLIN_USE_ATOMIC_ADD", "0") == "1"
) ),
== "1",
# Whether to use marlin kernel in mxfp4 quantization method # Whether to use marlin kernel in mxfp4 quantization method
"VLLM_MXFP4_USE_MARLIN": lambda: maybe_convert_bool( "VLLM_MXFP4_USE_MARLIN": lambda: maybe_convert_bool(
os.environ.get("VLLM_MXFP4_USE_MARLIN", None) os.environ.get("VLLM_MXFP4_USE_MARLIN", None)
...@@ -1211,17 +1215,16 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1211,17 +1215,16 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Whether to turn on the outlines cache for V1 # Whether to turn on the outlines cache for V1
# This cache is unbounded and on disk, so it's not safe to use in # This cache is unbounded and on disk, so it's not safe to use in
# an environment with potentially malicious users. # an environment with potentially malicious users.
"VLLM_V1_USE_OUTLINES_CACHE": lambda: os.environ.get( "VLLM_V1_USE_OUTLINES_CACHE": lambda: (
"VLLM_V1_USE_OUTLINES_CACHE", "0" os.environ.get("VLLM_V1_USE_OUTLINES_CACHE", "0") == "1"
) ),
== "1",
# Gap between padding buckets for the forward pass. So we have # Gap between padding buckets for the forward pass. So we have
# 8, we will run forward pass with [16, 24, 32, ...]. # 8, we will run forward pass with [16, 24, 32, ...].
"VLLM_TPU_BUCKET_PADDING_GAP": lambda: int( "VLLM_TPU_BUCKET_PADDING_GAP": lambda: (
os.environ["VLLM_TPU_BUCKET_PADDING_GAP"] int(os.environ["VLLM_TPU_BUCKET_PADDING_GAP"])
) if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ
if "VLLM_TPU_BUCKET_PADDING_GAP" in os.environ else 0
else 0, ),
"VLLM_TPU_MOST_MODEL_LEN": lambda: maybe_convert_int( "VLLM_TPU_MOST_MODEL_LEN": lambda: maybe_convert_int(
os.environ.get("VLLM_TPU_MOST_MODEL_LEN", None) os.environ.get("VLLM_TPU_MOST_MODEL_LEN", None)
), ),
...@@ -1714,6 +1717,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1714,6 +1717,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_SIMPLE_KV_OFFLOAD": lambda: bool( "VLLM_USE_SIMPLE_KV_OFFLOAD": lambda: bool(
int(os.getenv("VLLM_USE_SIMPLE_KV_OFFLOAD", "0")) int(os.getenv("VLLM_USE_SIMPLE_KV_OFFLOAD", "0"))
), ),
# Whether to enable dual cuda streams for LoRA computation
"VLLM_LORA_ENABLE_DUAL_STREAM": lambda: bool(
int(os.getenv("VLLM_LORA_ENABLE_DUAL_STREAM", "0"))
),
} }
......
...@@ -5,8 +5,15 @@ ...@@ -5,8 +5,15 @@
import torch import torch
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm import envs
from vllm.config import get_current_vllm_config
from vllm.config.lora import LoRAConfig from vllm.config.lora import LoRAConfig
from vllm.distributed.utils import divide from vllm.distributed.utils import divide
from vllm.forward_context import (
ForwardContext,
get_forward_context,
is_forward_context_available,
)
from vllm.model_executor.layers.linear import ( from vllm.model_executor.layers.linear import (
ColumnParallelLinear, ColumnParallelLinear,
LinearBase, LinearBase,
...@@ -14,24 +21,88 @@ from vllm.model_executor.layers.linear import ( ...@@ -14,24 +21,88 @@ from vllm.model_executor.layers.linear import (
RowParallelLinear, RowParallelLinear,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.multi_stream_utils import maybe_execute_in_parallel
from vllm.utils.torch_utils import direct_register_custom_op
from .base import BaseLayerWithLoRA from .base import BaseLayerWithLoRA
from .utils import _get_lora_device from .utils import _get_lora_device
if envs.VLLM_LORA_ENABLE_DUAL_STREAM:
_lora_aux_cuda_stream: torch.cuda.Stream | None = None
def _get_lora_aux_cuda_stream() -> torch.cuda.Stream | None:
global _lora_aux_cuda_stream
if _lora_aux_cuda_stream is None and current_platform.is_cuda_alike():
_lora_aux_cuda_stream = torch.cuda.Stream()
return _lora_aux_cuda_stream
def lora_linear_async(
layer_name: str,
output_size: int,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
return self._apply_async_impl(x, bias)
def lora_linear_async_fake(
layer_name: str,
output_size: int,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
# The real function reshapes output back to the original 3D shape
# when the input has an extra batch dimension (transformers backend).
if x.ndim == 3:
return torch.empty(
(x.size(0), x.size(1), output_size),
device=x.device,
dtype=x.dtype,
)
return torch.empty(
(x.size(0), output_size),
device=x.device,
dtype=x.dtype,
)
direct_register_custom_op(
op_name="lora_linear_async",
op_func=lora_linear_async,
fake_impl=lora_linear_async_fake,
)
class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
def __init__(self, base_layer: LinearBase): def __init__(self, base_layer: LinearBase):
super().__init__() super().__init__()
self._enable_aux_cuda_stream = envs.VLLM_LORA_ENABLE_DUAL_STREAM
self.base_layer = base_layer self.base_layer = base_layer
self.input_size = self.base_layer.input_size self.input_size = self.base_layer.input_size
# Ensure tp_size and tp_rank consistency with the base_layer. # Ensure tp_size and tp_rank consistency with the base_layer.
self.tp_size = self.base_layer.tp_size self.tp_size = self.base_layer.tp_size
self.tp_rank = self.base_layer.tp_rank self.tp_rank = self.base_layer.tp_rank
self.device = _get_lora_device(self.base_layer) self.device = _get_lora_device(self.base_layer)
self._init_lora_stream_context()
self.output_slices: tuple[int, ...] self.output_slices: tuple[int, ...]
self.output_size: int self.output_size: int
self.n_slices: int self.n_slices: int
def _init_lora_stream_context(self) -> None:
if not self._enable_aux_cuda_stream:
return
vllm_config = get_current_vllm_config()
self._lora_stream = _get_lora_aux_cuda_stream()
assert current_platform.is_cuda_alike()
self._events = [torch.cuda.Event(), torch.cuda.Event()]
# lora_linear avoids prefix conflicts with the base layer
self.layer_name = self.base_layer.prefix + ".lora_linear_async"
compilation_config = vllm_config.compilation_config
if self.layer_name in compilation_config.static_forward_context:
raise ValueError("Duplicate layer name: {}".format(self.layer_name))
compilation_config.static_forward_context[self.layer_name] = self
def create_lora_weights( def create_lora_weights(
self, self,
max_loras: int, max_loras: int,
...@@ -39,7 +110,6 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): ...@@ -39,7 +110,6 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
model_config: PretrainedConfig | None = None, model_config: PretrainedConfig | None = None,
) -> None: ) -> None:
self.lora_config = lora_config self.lora_config = lora_config
#
if isinstance(self.base_layer, ReplicatedLinear): if isinstance(self.base_layer, ReplicatedLinear):
lora_a_out_size = lora_config.max_lora_rank lora_a_out_size = lora_config.max_lora_rank
lora_b_out_size = self.output_size lora_b_out_size = self.output_size
...@@ -120,6 +190,18 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): ...@@ -120,6 +190,18 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
) )
def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor: def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
# is_forward_context_available for tower modules
if self._enable_aux_cuda_stream and is_forward_context_available():
output_size = sum(self.output_slices)
return torch.ops.vllm.lora_linear_async(
self.layer_name, output_size, x, bias
)
else:
return self._apply_sync(x, bias)
def _apply_sync(
self, x: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias) output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
original_shape = output.shape if output.ndim == 3 else None original_shape = output.shape if output.ndim == 3 else None
...@@ -144,6 +226,72 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA): ...@@ -144,6 +226,72 @@ class BaseLinearLayerWithLoRA(BaseLayerWithLoRA):
return output return output
def _apply_async_impl(
self, x: torch.Tensor, bias: torch.Tensor | None = None
) -> torch.Tensor:
"""
Forward pass with base linear and LoRA on separate CUDA streams
for overlap, using maybe_execute_in_parallel.
Base layer runs on default stream; LoRA runs on aux stream.
"""
assert envs.VLLM_LORA_ENABLE_DUAL_STREAM
assert x.ndim in (2, 3)
num_tokens = x.size(0) if x.ndim == 2 else x.size(1)
output_size = sum(self.output_slices)
def base_fn() -> torch.Tensor:
return self.base_layer.quant_method.apply(self.base_layer, x, bias)
def lora_fn() -> torch.Tensor:
# Must be zeros, not empty: _lora_expand_kernel exits early (without
# writing) when lora_id == -1 (no active LoRA). If uninitialized,
# output.add_(lora_result) below would corrupt the base output.
lora_output = torch.zeros(
(num_tokens, output_size),
device=self.device,
dtype=x.dtype,
)
# Flatten the batch dimension for the transformers backend
# (which uses shape (1, seq_len, hidden)), matching _apply_sync.
x_2d = x.flatten(0, 1) if x.ndim == 3 else x
self.punica_wrapper.add_lora_linear(
lora_output,
x_2d,
self.lora_a_stacked,
self.lora_b_stacked,
1.0,
self.output_slices,
add_inputs=False,
)
return lora_output
output, lora_result = maybe_execute_in_parallel(
base_fn,
lora_fn,
self._events[0],
self._events[1],
self._lora_stream,
)
original_shape = output.shape if output.ndim == 3 else None
# In transformers backend, x and output have extra batch dimension like
# (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
# therefore we need to flatten the batch dimensions.
if x.ndim == 3 and output.ndim == 3:
output = output.flatten(0, 1)
x = x.flatten(0, 1)
output.add_(lora_result)
# Reshape the flattened output back to its original shape,
# as some MM encoders cannot handle flattened inputs.
if original_shape is not None:
output = output.reshape(original_shape)
return output
@property @property
def weight(self) -> torch.Tensor: def weight(self) -> torch.Tensor:
# unquantizedLinear # unquantizedLinear
......
...@@ -9,8 +9,13 @@ https://arxiv.org/abs/2310.18547 ...@@ -9,8 +9,13 @@ https://arxiv.org/abs/2310.18547
import torch import torch
from vllm import envs
from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel from vllm.lora.ops.triton_ops.kernel_utils import do_expand_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_b_ptr, get_lora_op_configs from vllm.lora.ops.triton_ops.utils import (
_get_lora_b_ptr,
get_lora_op_configs,
supports_pdl,
)
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
...@@ -237,9 +242,9 @@ def _lora_expand( ...@@ -237,9 +242,9 @@ def _lora_expand(
NUM_SLICES, NUM_SLICES,
num_active_loras.item(), num_active_loras.item(),
) )
# We disable PDL temporarily because LoRA kernels are not launching back-to-back,
# making PDL invalid and affecting the kernel performance. # PDL only works when dual-stream is being used.
use_gdc = False # supports_pdl(inputs.device) use_gdc = supports_pdl(inputs.device) and envs.VLLM_LORA_ENABLE_DUAL_STREAM
_lora_expand_kernel[grid]( _lora_expand_kernel[grid](
inputs, inputs,
lora_ptr_tensor, lora_ptr_tensor,
......
...@@ -9,8 +9,13 @@ https://arxiv.org/abs/2310.18547 ...@@ -9,8 +9,13 @@ https://arxiv.org/abs/2310.18547
import torch import torch
from vllm import envs
from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel from vllm.lora.ops.triton_ops.kernel_utils import do_shrink_kernel
from vllm.lora.ops.triton_ops.utils import _get_lora_a_ptr, get_lora_op_configs from vllm.lora.ops.triton_ops.utils import (
_get_lora_a_ptr,
get_lora_op_configs,
supports_pdl,
)
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
...@@ -220,9 +225,9 @@ def _lora_shrink( ...@@ -220,9 +225,9 @@ def _lora_shrink(
NUM_SLICES, NUM_SLICES,
num_active_loras.item(), num_active_loras.item(),
) )
# We disable PDL temporarily because LoRA kernels are not launching back-to-back,
# making PDL invalid and affecting the kernel performance. # PDL only works when dual-stream is being used.
use_gdc = False # supports_pdl(inputs.device) use_gdc = supports_pdl(inputs.device) and envs.VLLM_LORA_ENABLE_DUAL_STREAM
_lora_shrink_kernel[grid]( _lora_shrink_kernel[grid](
inputs, inputs,
lora_ptr_tensor, lora_ptr_tensor,
......
...@@ -144,7 +144,9 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -144,7 +144,9 @@ class PunicaWrapperGPU(PunicaWrapperBase):
x (torch.Tensor): Input tensors x (torch.Tensor): Input tensors
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
output_slices (tuple[int, ...]): Every slice's size output_slices (tuple[int, ...]): Every slice's size
add_inputs (bool): Defaults to True. add_inputs (bool): If True, add LoRA output to y; if False, write
LoRA-only output to y (used for dual-stream when base and LoRA
run on different CUDA streams). Defaults to True.
""" """
y_org = y y_org = y
y = y.view(-1, y.shape[-1]) y = y.view(-1, y.shape[-1])
...@@ -161,7 +163,7 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -161,7 +163,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
num_tokens, self.lora_config.specialize_active_lora num_tokens, self.lora_config.specialize_active_lora
), ),
offset_start=offset_start, offset_start=offset_start,
add_inputs=True, add_inputs=add_inputs,
) )
y = y.view_as(y_org) y = y.view_as(y_org)
...@@ -244,7 +246,7 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -244,7 +246,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
buffer = torch.empty( buffer = torch.empty(
(len(output_slices), x.size(0), r), dtype=torch.float32, device=x.device (len(output_slices), x.size(0), r), dtype=torch.float32, device=x.device
) )
add_inputs = kwargs.pop("add_inputs", True)
self.add_shrink( self.add_shrink(
buffer, # type: ignore buffer, # type: ignore
x, x,
...@@ -257,7 +259,7 @@ class PunicaWrapperGPU(PunicaWrapperBase): ...@@ -257,7 +259,7 @@ class PunicaWrapperGPU(PunicaWrapperBase):
buffer, # type: ignore buffer, # type: ignore
lora_b_stacked, lora_b_stacked,
output_slices, output_slices,
add_inputs=True, add_inputs=add_inputs,
**kwargs, **kwargs,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment