Commit 8824ae6a authored by 王敏's avatar 王敏
Browse files

merge 092-dev分支近期修改

parents f9f1887d c0707728
......@@ -171,8 +171,6 @@ void paged_attention_v2_opt_tc_with_mask(
const c10::optional<torch::Tensor>& attn_masks,
const int64_t attn_masks_stride=0);
#ifndef USE_ROCM
void merge_attn_states(torch::Tensor& output,
std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output,
......@@ -180,6 +178,8 @@ void merge_attn_states(torch::Tensor& output,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse);
#ifndef USE_ROCM
void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
......
......@@ -216,7 +216,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" int attn_masks_stride) -> ()");
ops.impl("paged_attention_v2_opt_tc_with_mask", torch::kCUDA, &paged_attention_v2_opt_tc_with_mask);
#ifndef USE_ROCM
// Merge attn states
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
......@@ -230,6 +229,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor suffix_lse) -> ()");
ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states);
#ifndef USE_ROCM
ops.def(
"convert_vertical_slash_indexes("
" Tensor! block_count, Tensor! block_offset, "
......
......@@ -559,10 +559,10 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha is None:
sha = get_sha(vllm_root)
if (major, minor) >= ('2', '5'):
version = 'das.opt1.rc1.' + sha[:7]
version = 'das.opt1.rc2.' + sha[:7]
else:
if (major, minor) >= ('2', '5'):
version = 'das.opt1.rc1'
version = 'das.opt1.rc2'
# dtk version
......
......@@ -8,12 +8,11 @@ from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.sampling_params import SamplingParams
from ..utils import models_path_prefix
from vllm.utils import SUPPORT_TC, gpuname
import vllm.envs as envs
@pytest.mark.parametrize("model", [os.path.join(models_path_prefix, "distilbert/distilgpt2")])
@pytest.mark.parametrize("block_size", [64] if gpuname.startswith('BW') and envs.VLLM_FLASH_ATTN_BACKEND else [16])
@pytest.mark.parametrize("block_size", [64] if envs.VLLM_USE_FLASH_ATTN_PA else [16])
def test_computed_prefix_blocks(model: str, block_size: int):
# This test checks if we are able to run the engine to completion
# without triggering asserts.
......
......@@ -14,7 +14,6 @@ from vllm.executor.uniproc_executor import UniProcExecutor
from vllm.sampling_params import SamplingParams
import os
from ..utils import models_path_prefix
from vllm.utils import SUPPORT_TC, gpuname
import vllm.envs as envs
......@@ -60,7 +59,7 @@ def test_custom_executor(model, tmp_path):
model=model,
distributed_executor_backend=CustomUniExecutor,
enforce_eager=True, # reduce test time
block_size=64 if gpuname.startswith('BW') and envs.VLLM_FLASH_ATTN_BACKEND else 16,
block_size=64 if envs.VLLM_USE_FLASH_ATTN_PA else 16,
)
engine = LLMEngine.from_engine_args(engine_args)
sampling_params = SamplingParams(max_tokens=1)
......@@ -84,7 +83,7 @@ def test_custom_executor_async(model, tmp_path):
model=model,
distributed_executor_backend=CustomUniExecutorAsync,
enforce_eager=True, # reduce test time
block_size=64 if gpuname.startswith('BW') and envs.VLLM_FLASH_ATTN_BACKEND else 16,
block_size=64 if envs.VLLM_USE_FLASH_ATTN_PA else 16,
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
sampling_params = SamplingParams(max_tokens=1)
......@@ -111,7 +110,7 @@ def test_respect_ray(model):
model=model,
distributed_executor_backend="ray",
enforce_eager=True, # reduce test time
block_size=64 if gpuname.startswith('BW') and envs.VLLM_FLASH_ATTN_BACKEND else 16,
block_size=64 if envs.VLLM_USE_FLASH_ATTN_PA else 16,
)
engine = LLMEngine.from_engine_args(engine_args)
assert engine.model_executor.uses_ray
\ No newline at end of file
......@@ -100,29 +100,30 @@ def test_local_workers() -> None:
assert isinstance(e, ChildProcessError)
def test_local_workers_clean_shutdown() -> None:
"""Test clean shutdown"""
# @TODO
# def test_local_workers_clean_shutdown() -> None:
# """Test clean shutdown"""
workers, worker_monitor = _start_workers()
# workers, worker_monitor = _start_workers()
assert worker_monitor.is_alive()
assert all(worker.process.is_alive() for worker in workers)
# assert worker_monitor.is_alive()
# assert all(worker.process.is_alive() for worker in workers)
# Clean shutdown
worker_monitor.close()
# # Clean shutdown
# worker_monitor.close()
worker_monitor.join(20)
# worker_monitor.join(20)
# Ensure everything is stopped
assert not worker_monitor.is_alive()
assert all(not worker.process.is_alive() for worker in workers)
# # Ensure everything is stopped
# assert not worker_monitor.is_alive()
# assert all(not worker.process.is_alive() for worker in workers)
# Further attempts to submit tasks should fail
try:
_result = workers[0].execute_method("worker_method", "test")
pytest.fail("task should fail once workers have been shut down")
except Exception as e:
assert isinstance(e, ChildProcessError)
# # Further attempts to submit tasks should fail
# try:
# _result = workers[0].execute_method("worker_method", "test")
# pytest.fail("task should fail once workers have been shut down")
# except Exception as e:
# assert isinstance(e, ChildProcessError)
@pytest.mark.asyncio
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, List, Optional
import pytest
import os
from vllm import CompletionOutput, LLMEngine, SamplingParams
from ..utils import models_path_prefix
MODEL = os.path.join(models_path_prefix, "meta-llama/llama-2-7b-hf")
MAX_TOKENS = 200
IS_ASYNC = False
@pytest.fixture(scope="session")
def vllm_model(vllm_runner):
with vllm_runner(MODEL) as vllm_model:
yield vllm_model
def _test_stopping(llm_engine: LLMEngine,
expected_output: str,
expected_reason: Any,
stop: Optional[List[str]] = None,
stop_token_ids: Optional[List[int]] = None,
include_in_output: bool = False,
use_async_output_proc: bool = False) -> None:
llm_engine.add_request(
"id", "A story about vLLM:\n",
SamplingParams(
temperature=0.0,
max_tokens=MAX_TOKENS,
stop=stop,
stop_token_ids=stop_token_ids,
include_stop_str_in_output=include_in_output,
), None)
output: Optional[CompletionOutput] = None
output_text = ""
stop_reason = None
if use_async_output_proc:
llm_engine.step()
while llm_engine.has_unfinished_requests():
(request_output, ) = llm_engine.step()
(output, ) = request_output.outputs
# Ensure we don't backtrack
assert output.text.startswith(output_text)
output_text = output.text
stop_reason = output.stop_reason
assert output is not None
assert output_text == expected_output
assert stop_reason == expected_reason
def _set_async_mode(llm_engine, is_async):
llm_engine.scheduler[0].use_async_output_proc = is_async
def _stop_basic(llm_engine, is_async):
_test_stopping(llm_engine,
stop=["."],
include_in_output=False,
expected_output="VLLM is a 100% volunteer organization",
expected_reason=".",
use_async_output_proc=is_async)
_test_stopping(llm_engine,
stop=["."],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organization.",
expected_reason=".",
use_async_output_proc=is_async)
def _stop_multi_tokens(llm_engine, is_async):
_test_stopping(
llm_engine,
stop=["group of peo", "short"],
include_in_output=False,
expected_output="VLLM is a 100% volunteer organization. We are a ",
expected_reason="group of peo",
use_async_output_proc=is_async)
_test_stopping(
llm_engine,
stop=["group of peo", "short"],
include_in_output=True,
expected_output=
"VLLM is a 100% volunteer organization. We are a group of peo",
expected_reason="group of peo",
use_async_output_proc=is_async)
def _stop_partial_token(llm_engine, is_async):
_test_stopping(llm_engine,
stop=["gani"],
include_in_output=False,
expected_output="VLLM is a 100% volunteer or",
expected_reason="gani",
use_async_output_proc=is_async)
_test_stopping(llm_engine,
stop=["gani"],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organi",
expected_reason="gani",
use_async_output_proc=is_async)
def _stop_token_id(llm_engine, is_async):
# token id 13013 => " organization"
_test_stopping(llm_engine,
stop_token_ids=[13013],
include_in_output=False,
expected_output="VLLM is a 100% volunteer",
expected_reason=13013,
use_async_output_proc=is_async)
_test_stopping(llm_engine,
stop_token_ids=[13013],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organization",
expected_reason=13013,
use_async_output_proc=is_async)
@pytest.mark.skip_global_cleanup
def test_stop_basic(vllm_model):
_set_async_mode(vllm_model.model.llm_engine, True)
_stop_basic(vllm_model.model.llm_engine, is_async=True)
_set_async_mode(vllm_model.model.llm_engine, False)
_stop_basic(vllm_model.model.llm_engine, is_async=False)
@pytest.mark.skip_global_cleanup
def test_stop_multi_tokens(vllm_model):
_set_async_mode(vllm_model.model.llm_engine, True)
_stop_multi_tokens(vllm_model.model.llm_engine, is_async=True)
_set_async_mode(vllm_model.model.llm_engine, False)
_stop_multi_tokens(vllm_model.model.llm_engine, is_async=False)
@pytest.mark.skip_global_cleanup
def test_stop_partial_token(vllm_model):
_set_async_mode(vllm_model.model.llm_engine, True)
_stop_partial_token(vllm_model.model.llm_engine, is_async=True)
_set_async_mode(vllm_model.model.llm_engine, False)
_stop_partial_token(vllm_model.model.llm_engine, is_async=False)
@pytest.mark.skip_global_cleanup
def test_stop_token_id(vllm_model):
_set_async_mode(vllm_model.model.llm_engine, True)
_stop_token_id(vllm_model.model.llm_engine, is_async=True)
_set_async_mode(vllm_model.model.llm_engine, False)
_stop_token_id(vllm_model.model.llm_engine, is_async=False)
......@@ -1221,6 +1221,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
s_lse = s_lse.view(q_len, q_heads, 1).transpose(0, 2).float()
return res, s_lse
if not current_platform.is_rocm():
output, softmax_lse = flash_attn_varlen_func(
q=query_states,
k=key_states,
......@@ -1238,6 +1239,24 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
block_table=block_table.unsqueeze(0),
return_softmax_lse=True,
)
else:
output, softmax_lse = flash_attn_varlen_func(
q=query_states,
k=key_states,
v=value_states,
softmax_scale=softmax_scale,
cu_seqlens_q=torch.tensor([0, query_states.shape[0]],
dtype=torch.int32,
device=query_states.device),
max_seqlen_q=query_states.shape[0],
cu_seqlens_k=torch.tensor([0, max_seqlen_k],
dtype=torch.int32,
device=query_states.device),
max_seqlen_k=max_seqlen_k,
causal=causal,
block_table=block_table.unsqueeze(0),
return_attn_probs=True,
)
softmax_lse = softmax_lse.view(q_len, q_heads, 1).transpose(0,
2).float()
return output, softmax_lse
......
......@@ -1043,9 +1043,12 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
self._pad_v = self.vllm_flash_attn_version is None or torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count == 120 or not (
if not current_platform.is_rocm():
self._pad_v = self.vllm_flash_attn_version is None or not (
self.vllm_flash_attn_version == 3
and current_platform.get_device_capability()[0] == 9)
else:
self._pad_v = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count == 120
def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale,
return_softmax_lse, **kwargs):
......
......@@ -7,6 +7,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.two_batch_overlap.v1.two_batch_overlap_v1 import tbo_maybe_save_kv_layer_to_connector
import vllm.envs as envs
from vllm.attention import AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
......@@ -412,6 +413,9 @@ def unified_attention(
output = self.impl.forward(self, query, key, value, kv_cache,
attn_metadata)
if envs.VLLM_ENABLE_TBO:
tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
else:
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
return output
......@@ -457,7 +461,9 @@ def unified_attention_with_output(
attn_metadata,
output=output,
output_scale=output_scale)
if envs.VLLM_ENABLE_TBO:
tbo_maybe_save_kv_layer_to_connector(layer_name, kv_cache)
else:
maybe_save_kv_layer_to_connector(layer_name, kv_cache)
......
......@@ -5,6 +5,7 @@ from typing import Optional
import torch
from vllm.platforms import current_platform
from vllm import envs
def merge_attn_states(
......@@ -31,7 +32,7 @@ def merge_attn_states(
return headdim % 4 == 0
return headdim % 8 == 0
if (current_platform.is_cuda() and supported_dtypes(output)
if (current_platform.is_cuda() or envs.VLLM_USE_MERGE_ATTN_STATES_OPT and supported_dtypes(output)
and supported_headdim(output)):
from vllm._custom_ops import merge_attn_states
return merge_attn_states(output, prefix_output, prefix_lse,
......
......@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
import regex as re
import torch
from vllm import envs
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
......@@ -215,9 +216,18 @@ class P2pNcclConnector(KVConnectorBase_V1):
inject_kv_into_layer(kv_cache_layer, kv_cache,
request.slot_mapping, request.request_id)
tensor = self.p2p_nccl_engine.recv_store.pop(request.request_id + "#" + layer_name, None)
if tensor is not None:
del tensor
tensor_id = request.request_id + "#" + layer_name
if tensor_id in self.p2p_nccl_engine.recv_store:
tensor = self.p2p_nccl_engine.recv_store.pop(tensor_id, None)
self.p2p_nccl_engine.send_request_id_to_tensor_ids.pop(
request.request_id, None)
self.p2p_nccl_engine.recv_request_id_to_tensor_ids.pop(
request.request_id, None)
addr = 0
if isinstance(tensor, tuple):
addr, _, _ = tensor
self.p2p_nccl_engine.pool.free(addr)
def wait_for_layer_load(self, layer_name: str) -> None:
"""Blocking until the KV for a specific layer is loaded into vLLM's
......@@ -258,6 +268,8 @@ class P2pNcclConnector(KVConnectorBase_V1):
Assume the shape of the layer is (2, num_pages, page_size, xxx)
if MLA is not used, and (num_pages, page_size, xxx) otherwise.
"""
if envs.VLLM_ENABLE_TBO:
slot_mapping = slot_mapping.pin_memory().to(device=layer.device, non_blocking=True)
if isinstance(attn_metadata, MLACommonMetadata):
num_pages, page_size = layer.shape[0], layer.shape[1]
return layer.reshape(num_pages * page_size, -1)[slot_mapping,
......
......@@ -326,10 +326,6 @@ class P2pNcclEngine:
# Store Tensor in memory pool
addr = self.pool.store_tensor(tensor)
tensor = (addr, tensor.dtype, tensor.shape)
logger.warning(
"🔴[PUT]Recv Tensor, Out Of Threshold, "
"%s👈%s, data:%s, addr:%d", self.zmq_address,
remote_address.decode(), data, addr)
else:
self.buffer_size += tensor_size
......
......@@ -7,6 +7,7 @@ from collections.abc import Iterable
import torch
import torch.nn.functional as F
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.config import get_current_vllm_config
......@@ -259,7 +260,7 @@ class EPMoE(FusedMoE):
hidden_dim=self.hidden_size,
scale_dim=0,
scale_type_size=vllm_config.model_config.dtype.itemsize,
max_num_inp_token_per_rank=5120,
max_num_inp_token_per_rank=4096,
num_experts_per_rank=self.local_num_experts,
num_experts_per_token=self.top_k,
max_token_type_size=2,
......@@ -294,7 +295,9 @@ class EPMoE(FusedMoE):
dist.barrier()
def forward(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
router_logits: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None):
return torch.ops.vllm.ep_moe_forward(hidden_states, router_logits,
self.layer_name)
......@@ -318,7 +321,10 @@ class EPMoE(FusedMoE):
]
def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None):
topk_weights, topk_ids = self.select_experts(
hidden_states=hidden_states,
......@@ -334,7 +340,11 @@ class EPMoE(FusedMoE):
indices_type=torch.int64,
routed_scaling_factor=self.routed_scaling_factor,
use_fused_gate=self.use_fused_gate)
if not self.ep_moe_config.moe_shared_expert_overlap and self.shared_experts is not None:
if envs.USE_FUSED_RMS_QUANT:
shared_output, new_resi = self.shared_experts(hidden_states, rms_weight, residual, update_hd=True)
else:
shared_output = self.shared_experts(hidden_states)
topk_ids = topk_ids.to(torch.int32)
......@@ -345,6 +355,8 @@ class EPMoE(FusedMoE):
device=hidden_states.device,
)
#dist.barrier()
(
dispatch_output,
dispatch_weights,
......@@ -360,13 +372,17 @@ class EPMoE(FusedMoE):
#self.sync()
#dispatch_recv_num_token = dispatch_recv_num_token[0].item()
# dispatch_recv_num_token = dispatch_recv_num_token.cpu()[0]
# #dispatch_recv_num_token = dispatch_recv_num_token.item()
# dispatch_output = dispatch_output[:dispatch_recv_num_token]
# dispatch_weights = dispatch_weights[:dispatch_recv_num_token]
# dispatch_indices = dispatch_indices[:dispatch_recv_num_token]
# dispatch_recv_num_token = dispatch_recv_num_token.item()
# dispatch_output = torch.narrow(dispatch_output, dim=0, start=0, length=dispatch_recv_num_token)
# dispatch_weights = torch.narrow(dispatch_weights, dim=0, start=0, length=dispatch_recv_num_token)
# dispatch_indices = torch.narrow(dispatch_indices, dim=0, start=0, length=dispatch_recv_num_token)
# valid_mask = ((dispatch_indices <= 255) & (dispatch_indices >= 0)).all(dim=1)
# dispatch_output = dispatch_output[valid_mask]
......@@ -418,26 +434,31 @@ class EPMoE(FusedMoE):
final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor)
return final_hidden_states
if envs.USE_FUSED_RMS_QUANT:
return final_hidden_states, new_resi
else:
return final_hidden_states, None
def ep_moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor:
layer_name: str, rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None) -> tuple[torch.Tensor, torch.Tensor]:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
assert self.quant_method is not None
return self.forward_impl(hidden_states, router_logits)
return self.forward_impl(hidden_states, router_logits, rms_weight, residual)
def ep_moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
layer_name: str) -> torch.Tensor:
return torch.empty_like(hidden_states)
layer_name: str, rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None) -> tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(hidden_states), torch.empty_like(hidden_states)
direct_register_custom_op(
op_name="ep_moe_forward",
op_func=ep_moe_forward,
mutates_args=["hidden_states"],
mutates_args=["hidden_states", "router_logits", "rms_weight", "residual"],
fake_impl=ep_moe_forward_fake,
dispatch_key=current_platform.dispatch_key,
tags=(torch.Tag.needs_fixed_stride_order, ),
......
......@@ -234,7 +234,7 @@ def moe_align_block_size(
if envs.VLLM_USE_LIGHT_OP:
op.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad)
expert_ids, num_tokens_post_pad, None)
else:
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad)
......
......@@ -33,6 +33,12 @@ from vllm.platforms import current_platform
import os
from vllm.model_executor.utils import gemm_bank_conf
if envs.USE_FUSED_RMS_QUANT:
try:
from lmslim.quantize.quant_ops import lm_faster_rmsquant
except Exception as e:
print(f"Error: Import fused rmsquant error: {e}")
logger = init_logger(__name__)
WEIGHT_LOADER_V2_SUPPORTED = [
......@@ -327,6 +333,7 @@ class ReplicatedLinear(LinearBase):
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
eps: Optional[float] = 1e-6,
prefix: str = "",
*,
return_bias: bool = True,
......@@ -338,6 +345,7 @@ class ReplicatedLinear(LinearBase):
quant_config,
prefix=prefix,
return_bias=return_bias)
self.eps = eps
# All the linear layer supports quant method.
assert self.quant_method is not None
......@@ -385,11 +393,49 @@ class ReplicatedLinear(LinearBase):
param.data.copy_(loaded_weight)
def forward(
self, x: torch.Tensor
self,
input_: torch.Tensor,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
quant_args: Optional[list] = None,
update_hd: Optional[bool] = True
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if envs.USE_FUSED_RMS_QUANT and (rms_weight is not None or quant_args is not None):
if quant_args is not None:
input_quant_args = quant_args
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, x, bias)
output = self.quant_method.apply(self, input_, bias, input_quant_args)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
else:
i_q, _scales = lm_faster_rmsquant(input=input_,
rms_weight=rms_weight,
epsilon=self.eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hd
)
new_residual = residual
input_quant_args = [i_q, _scales]
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, input_, bias, input_quant_args)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, new_residual, output_bias, input_quant_args
else:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, input_, bias)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
......@@ -436,6 +482,7 @@ class ColumnParallelLinear(LinearBase):
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
output_sizes: Optional[list[int]] = None,
eps: Optional[float] = 1e-6,
prefix: str = "",
*,
return_bias: bool = True,
......@@ -459,7 +506,7 @@ class ColumnParallelLinear(LinearBase):
quant_config,
prefix,
return_bias=return_bias)
self.eps = eps
self.gather_output = gather_output
if output_sizes is None:
......@@ -543,10 +590,37 @@ class ColumnParallelLinear(LinearBase):
param.load_column_parallel_weight(loaded_weight=loaded_weight)
def forward(
self, input_
self, input_,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = True
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
input_quant_args = None
assert rms_weight is not None
i_q, _scales = lm_faster_rmsquant(input=input_,
rms_weight=rms_weight,
epsilon=self.eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hd)
new_residual = residual
input_quant_args = [i_q, _scales]
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args)
if self.gather_output:
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, new_residual, output_bias
else:
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias)
......@@ -593,6 +667,54 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
return_bias: If true, return bias together with outputs in forward pass.
"""
def forward(
self, input_,
rms_weight: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
update_hd: Optional[bool] = True
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
if envs.USE_FUSED_RMS_QUANT and rms_weight is not None:
input_quant_args = None
assert residual is not None and rms_weight is not None
i_q, _scales = lm_faster_rmsquant(input=input_,
rms_weight=rms_weight,
epsilon=self.eps,
quant_dtype=torch.int8,
residual=residual,
update_input=update_hd)
new_residual = residual
input_quant_args = [i_q, _scales]
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias, input_quant_args)
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, new_residual, output_bias
else: # not USE_FUSED_RMS_QUANT
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output:
# All-gather across the partitions.
output = tensor_model_parallel_all_gather(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
def __init__(
self,
input_size: int,
......@@ -602,10 +724,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
eps: Optional[float] = 1e-6,
prefix: str = "",
*,
return_bias: bool = True,
):
self.eps = eps
self.output_sizes = output_sizes
tp_size = get_tensor_model_parallel_world_size()
assert all(output_size % tp_size == 0 for output_size in output_sizes)
......@@ -856,7 +980,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
shard_offset=shard_offset,
shard_size=shard_size)
class QKVParallelLinear(ColumnParallelLinear):
"""Linear layers for the attention's QKV transformation.
......
......@@ -130,7 +130,7 @@ class AWQConfig(QuantizationConfig):
return "awq"
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.half]
return [torch.half, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
......@@ -293,7 +293,7 @@ class AWQLinearMethod(LinearMethodBase):
pad_group=2
dim_n = layer.scales.data.shape[1]
dim_k = layer.qweight.data.shape[0]
_qw, _sz=ops.convert_s4(layer.qweight,layer.qzeros,layer.scales,int(group_size))
_qw, _sz=ops.convert_s4(layer.qweight,layer.qzeros,layer.scales.to(torch.float16),int(group_size))
sz = ops.sz_permute(_sz).reshape(-1,dim_n)
sz = sz.reshape(dim_n,-1)
_qw = _qw.reshape(dim_n,-1)
......
......@@ -10,7 +10,8 @@ import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
set_weight_attrs)
......@@ -140,6 +141,9 @@ class AWQMarlinConfig(QuantizationConfig):
self.full_config).get_quant_method(layer, prefix)
return AWQMarlinLinearMethod(self)
elif isinstance(layer, FusedMoE):
if is_layer_skipped_awq(
prefix, getattr(self, "modules_to_not_convert", [])):
return UnquantizedFusedMoEMethod(layer.moe_config)
from vllm.model_executor.layers.quantization.moe_wna16 import (
MoeWNA16Config)
if not check_moe_marlin_supports_layer(layer, self.group_size):
......@@ -436,7 +440,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
# Why does this take the intermediate size for size_k?
marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_scales,
s=layer.w13_scales.to(torch.float16),
size_k=layer.intermediate_size_per_partition,
size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size,
......@@ -445,7 +449,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
#replace_parameter(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_scales,
s=layer.w2_scales.to(torch.float16),
size_k=layer.intermediate_size_per_partition,
size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size,
......
......@@ -7,7 +7,8 @@ import torch
import os
from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported,
UnquantizedFusedMoEMethod)
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
......@@ -18,7 +19,8 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.model_executor.layers.fused_moe import fused_experts
from vllm.model_executor.layers.quantization.awq import (
is_layer_skipped_awq)
from lmslim.layers.fused_moe.fuse_moe_int4 import fused_experts_w4a16
os.environ['W4A16_MOE_CUDA'] = os.environ.get('W4A16_MOE_CUDA', '0')
......@@ -139,9 +141,9 @@ class MoeWNA16Config(QuantizationConfig):
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
return UnquantizedLinearMethod()
elif isinstance(layer, LinearBase):
# Avoid circular import
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.awq_marlin import (
......@@ -167,6 +169,9 @@ class MoeWNA16Config(QuantizationConfig):
else:
raise ValueError("moe_wna16 only support gptq and awq.")
elif isinstance(layer, FusedMoE):
if is_layer_skipped_awq(
prefix, getattr(self, "modules_to_not_convert", [])):
return UnquantizedFusedMoEMethod(layer.moe_config)
return MoeWNA16Method(self)
return None
......
......@@ -21,6 +21,8 @@ from vllm.utils import W8a8GetCacheJSON
import os
from vllm import _custom_ops as ops
from vllm import envs
try:
from lmslim.layers.fused_moe.fuse_moe_w4a8 import fused_experts_impl_w4a8_ep
except Exception:
......@@ -156,7 +158,12 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
input_quant_args: Optional[list[torch.Tensor]] = None
):
if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
assert len(input_quant_args) == 2
x_q, x_scale = input_quant_args
else:
x_q, x_scale = per_token_quant_int8(x)
if self.w8a8_strategy==1:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment