Commit 3b2aefb1 authored by zhuwenwen's avatar zhuwenwen
Browse files

remove gemm and moe nn layout

parent b09a0d7b
......@@ -292,7 +292,6 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu"
"csrc/fused_qknorm_rope_kernel.cu"
# "csrc/layernorm_quant_kernels.cu"
"csrc/opt/transpose_kernels.cu"
"csrc/sampler.cu"
"csrc/cuda_view.cu"
"csrc/quantization/gptq/q_gemm.cu"
......
......@@ -178,8 +178,6 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input);
void gelu_quick(torch::Tensor& out, torch::Tensor& input);
void trans_w16_gemm(torch::Tensor dst, torch::Tensor src, int64_t row, int64_t col);
void cutlass_mla_decode(torch::Tensor const& out, torch::Tensor const& q_nope,
torch::Tensor const& q_pe,
torch::Tensor const& kv_c_and_k_pe_cache,
......
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
namespace vllm {
template <typename T>
__global__ void trans_w16_gemm_cudakernel(int64_t num_kernels,T* dst,const T* src,int64_t row,int64_t col)
{
int64_t id = blockIdx.x * blockDim.x + threadIdx.x;
if(id >= num_kernels) return;
int64_t j=id%row;
int64_t i=id/row;
dst[i*row+j]=src[j*col+i];
}
void trans_w16_gemm_cuda(half* dst,const half* src,int64_t row,int64_t col){
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
int64_t num_kernels=row*col;
int block_size=256;
trans_w16_gemm_cudakernel<<<(num_kernels+block_size-1)/block_size,block_size, 0, stream>>>(num_kernels,dst,src,row,col);
}
} // namespace vllm
void trans_w16_gemm(torch::Tensor dst,torch::Tensor src,int64_t row,int64_t col){
const at::cuda::OptionalCUDAGuard device_guard(device_of(src));
vllm::trans_w16_gemm_cuda(
(half*)dst.data_ptr(),
(const half*)src.data_ptr(),
row,
col
);
}
\ No newline at end of file
......@@ -231,10 +231,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor cos_sin_cache, bool is_neox) -> ()");
ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding);
// trans w16
ops.def("trans_w16_gemm(Tensor! dst, Tensor src, int row, int col) -> ()");
ops.impl("trans_w16_gemm", torch::kCUDA, &trans_w16_gemm);
// Quantization ops
#ifndef USE_ROCM
// Quantized GEMM for AWQ.
......
......@@ -415,12 +415,6 @@ def apply_repetition_penalties(
logits, prompt_mask, output_mask, repetition_penalties
)
# trans_w16
def trans_w16_gemm(dst: torch.Tensor, src: torch.Tensor,
row:int, col:int) -> None :
torch.ops._C.trans_w16_gemm(dst,src,row,col)
# fused quant layer norm ops
def rms_norm_dynamic_per_token_quant(
input: torch.Tensor,
......
......@@ -252,7 +252,6 @@ if TYPE_CHECKING:
VLLM_USE_V2_MODEL_RUNNER: bool = False
VLLM_LOG_MODEL_INSPECTION: bool = False
VLLM_DEBUG_MFU_METRICS: bool = False
VLLM_USE_FLASH_MLA: bool = False
def get_default_cache_root():
......@@ -1612,9 +1611,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_DEBUG_MFU_METRICS": lambda: bool(
int(os.getenv("VLLM_DEBUG_MFU_METRICS", "0"))
),
# If set, vLLM will use FLASH MLA attention optimizations.
"VLLM_USE_FLASH_MLA":
lambda: bool(int(os.getenv("VLLM_USE_FLASH_MLA", "0"))),
}
# --8<-- [end:env-vars-definition]
......
......@@ -757,7 +757,6 @@ def invoke_fused_moe_triton_kernel(
per_channel_quant: bool,
block_shape: list[int] | None = None,
B_bias: torch.Tensor | None = None,
use_nn_moe: bool | None = False,
):
assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1
......@@ -794,7 +793,7 @@ def invoke_fused_moe_triton_kernel(
EM = num_tokens * config["BLOCK_SIZE_M"]
grid = lambda META: (
triton.cdiv(EM, META["BLOCK_SIZE_M"])
* triton.cdiv(B.size(1) if not use_nn_moe else B.size(2), META["BLOCK_SIZE_N"]),
* triton.cdiv(B.size(1)),
)
HAS_BIAS = B_bias is not None
......@@ -1031,7 +1030,7 @@ def zero_experts_compute_triton(
# Adapted from: https://github.com/sgl-project/sglang/pull/2628
def get_config_file_name(
E: int, N: int, dtype: str | None, block_shape: list[int] | None = None, use_nn_moe: bool | None = False,
E: int, N: int, dtype: str | None, block_shape: list[int] | None = None,
) -> str:
device_name = current_platform.get_device_name().replace(" ", "_")
# Set device_name to H200 if a device from the H200 family is detected
......@@ -1041,10 +1040,7 @@ def get_config_file_name(
block_shape_selector = (
"" if not block_shape or not all(block_shape) else f",block_shape={block_shape}"
).replace(" ", "")
if not use_nn_moe:
return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501
else:
return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}_nn.json"
# Adapted from: https://github.com/sgl-project/sglang/pull/2628
@functools.lru_cache
......@@ -1054,7 +1050,6 @@ def get_moe_configs(
dtype: str | None,
block_n: int | None = None,
block_k: int | None = None,
use_nn_moe: bool | None = False,
) -> dict[int, Any] | None:
"""
Return optimized configurations for the fused MoE kernel.
......@@ -1072,7 +1067,7 @@ def get_moe_configs(
# First look up if an optimized configuration is available in the configs
# directory
block_shape = [block_n, block_k] if block_n and block_k else None
json_file_name = get_config_file_name(E, N, dtype, block_shape, use_nn_moe=use_nn_moe)
json_file_name = get_config_file_name(E, N, dtype, block_shape)
config_file_paths = []
......@@ -1246,7 +1241,6 @@ def get_default_config(
topk: int,
dtype: str | None,
block_shape: list[int] | None = None,
use_nn_moe: bool | None =False,
) -> dict[str, int]:
if vllm_is_batch_invariant():
config = {
......@@ -1302,9 +1296,6 @@ def get_default_config(
"GROUP_SIZE_M": 8,
"SPLIT_K": 1,
}
if use_nn_moe:
config["num_ldmatrixes"] = 1
return config
......@@ -1315,7 +1306,6 @@ def try_get_optimal_moe_config(
dtype: str | None,
M: int,
block_shape: list[int] | None = None,
use_nn_moe: bool | None = False,
) -> dict[str, int]:
from vllm.model_executor.layers.fused_moe import get_config
......@@ -1324,15 +1314,12 @@ def try_get_optimal_moe_config(
config = override_config
else:
# First try to load optimal config from the file
if not use_nn_moe:
E, _, N = w2_shape
else:
E, N, _ = w2_shape
if dtype == "int4_w4a16":
N = N * 2
block_n = block_shape[0] if block_shape else 0
block_k = block_shape[1] if block_shape else 0
configs = get_moe_configs(E, N, dtype, block_n, block_k, use_nn_moe=use_nn_moe)
configs = get_moe_configs(E, N, dtype, block_n, block_k)
if configs:
# If an optimal configuration map has been found, look up the
......@@ -1340,7 +1327,7 @@ def try_get_optimal_moe_config(
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Else use the default config
config = get_default_config(M, E, N, w1_shape[2] if not use_nn_moe else w1_shape[1], top_k, dtype, block_shape, use_nn_moe=use_nn_moe)
config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, block_shape)
return config
......@@ -1738,7 +1725,6 @@ def inplace_fused_experts(
block_shape: list[int] | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
use_nn_moe: bool | None = False,
) -> None:
fused_experts_impl(
hidden_states,
......@@ -1794,7 +1780,6 @@ def inplace_fused_experts_fake(
block_shape: list[int] | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
use_nn_moe: bool | None = False,
) -> None:
pass
......@@ -1837,7 +1822,6 @@ def outplace_fused_experts(
block_shape: list[int] | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
use_nn_moe: bool | None = False,
) -> torch.Tensor:
return fused_experts_impl(
hidden_states,
......@@ -1938,7 +1922,6 @@ def fused_experts(
expert_map: torch.Tensor | None = None,
quant_config: FusedMoEQuantConfig | None = None,
allow_deep_gemm: bool = False,
use_nn_moe: bool | None = False,
) -> torch.Tensor:
if quant_config is None:
quant_config = FUSED_MOE_UNQUANTIZED_CONFIG
......@@ -1997,7 +1980,6 @@ def fused_experts(
block_shape=quant_config.block_shape,
w1_bias=quant_config.w1_bias,
w2_bias=quant_config.w2_bias,
use_nn_moe=use_nn_moe,
)
......@@ -2052,13 +2034,10 @@ def fused_experts_impl(
block_shape: list[int] | None = None,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
use_nn_moe: bool | None= False,
) -> torch.Tensor:
# Check constraints.
if use_int4_w4a16:
assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch"
elif use_nn_moe:
assert hidden_states.size(1) == w1.size(1), "Hidden size mismatch"
elif ocp_mx_scheme is not None:
if ocp_mx_scheme in {
"w_mxfp4_a_mxfp4",
......@@ -2088,9 +2067,6 @@ def fused_experts_impl(
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
num_tokens = hidden_states.size(0)
if use_nn_moe:
E, _, N = w1.size()
else:
E, N, _ = w1.size()
K = w2.size(1)
if global_num_experts == -1:
......@@ -2124,7 +2100,6 @@ def fused_experts_impl(
top_k_num,
config_dtype,
block_shape=block_shape,
use_nn_moe=use_nn_moe,
)
config = get_config_func(M)
......@@ -2132,12 +2107,12 @@ def fused_experts_impl(
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
cache13 = torch.empty(
M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]),
M * top_k_num * max(N, K),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache1 = cache13[: M * top_k_num * N].view(M, top_k_num, N)
intermediate_cache3 = cache13[: M * top_k_num * (K if not use_nn_moe else w2.shape[2])].view(M, top_k_num, K if not use_nn_moe else w2.shape[2])
intermediate_cache3 = cache13[: M * top_k_num * K].view(M, top_k_num, K)
# This needs separate memory since it's used concurrently with cache1
activation_out_dim = mk.FusedMoEPermuteExpertsUnpermute.adjust_N_for_activation(
......@@ -2284,7 +2259,6 @@ def fused_experts_impl(
per_channel_quant=per_channel_quant,
block_shape=block_shape,
B_bias=w1_bias,
use_nn_moe=use_nn_moe,
)
apply_moe_activation(
......@@ -2324,7 +2298,6 @@ def fused_experts_impl(
per_channel_quant=per_channel_quant,
block_shape=block_shape,
B_bias=w2_bias,
use_nn_moe=use_nn_moe,
)
ops.moe_sum(
......
......@@ -657,13 +657,6 @@ class FusedMoE(CustomOp):
"EPLB is only supported for FP8 quantization for now."
)
if quant_config is None:
# Not considering quant for now, temporarily
# self.use_nn_moe = int(os.environ.get('MOE_NN', 1)) == 1
self.use_nn_moe = os.environ.get('MOE_NN') == '1'
else:
self.use_nn_moe = False
moe_quant_params = {
"num_experts": self.local_num_experts,
"hidden_size": hidden_size,
......@@ -671,7 +664,6 @@ class FusedMoE(CustomOp):
"params_dtype": params_dtype,
"weight_loader": self.weight_loader,
"global_num_experts": self.global_num_experts,
"use_nn_moe": self.use_nn_moe,
}
# need full intermediate size pre-sharding for WNA16 act order
if self.quant_method.__class__.__name__ in (
......@@ -1046,7 +1038,7 @@ class FusedMoE(CustomOp):
shard_size = expert_data.shape[shard_dim]
if not load_full:
loaded_weight = loaded_weight.narrow(
shard_dim if not self.use_nn_moe else ~shard_dim, shard_size * tp_rank, shard_size
shard_dim
)
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
......@@ -1056,10 +1048,7 @@ class FusedMoE(CustomOp):
else:
assert shard_id == "w3"
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
if not self.use_nn_moe:
expert_data.copy_(loaded_weight)
else:
expert_data.copy_(loaded_weight.T)
def _load_w2(
self,
......@@ -1075,13 +1064,10 @@ class FusedMoE(CustomOp):
shard_size = expert_data.shape[shard_dim]
if not load_full:
loaded_weight = loaded_weight.narrow(
shard_dim if not self.use_nn_moe else ~shard_dim, shard_size * tp_rank, shard_size
shard_dim
)
# w2, down_proj: Load into only logical weight of w2.
if not self.use_nn_moe:
expert_data.copy_(loaded_weight)
else:
expert_data.copy_(loaded_weight.T)
def _load_single_value(
self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, expert_id: int
......@@ -1089,10 +1075,7 @@ class FusedMoE(CustomOp):
param_data = param.data
# Input scales can be loaded directly and should be equal.
if not self.use_nn_moe:
param_data[expert_id] = loaded_weight
else:
param_data[expert_id] = loaded_weight.T
def _load_g_idx(
self,
......@@ -1111,10 +1094,7 @@ class FusedMoE(CustomOp):
)
else:
assert shard_id in ("w1", "w3")
if not self.use_nn_moe:
expert_data.copy_(loaded_weight)
else:
expert_data.copy_(loaded_weight.T)
def _map_global_expert_id_to_local_expert_id(self, expert_id: int) -> int:
if self._expert_map is None:
......@@ -1242,7 +1222,7 @@ class FusedMoE(CustomOp):
# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
# should be whatever dimension intermediate_size_per_partition is
is_transposed = getattr(param, "is_transposed", False) or self.use_nn_moe
is_transposed = getattr(param, "is_transposed", False)
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
if is_transposed:
shard_dim = int(not shard_dim)
......@@ -2019,7 +1999,6 @@ class FusedMoE(CustomOp):
if do_naive_dispatch_combine
else hidden_states,
router_logits=router_logits,
use_nn_moe=self.use_nn_moe,
)
if has_separate_shared_experts:
......
......@@ -292,14 +292,12 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
use_nn_moe: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
return self.forward(
router=router,
layer=layer,
x=x,
router_logits=router_logits,
use_nn_moe=use_nn_moe,
)
def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig:
......@@ -317,7 +315,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
router: FusedMoERouter,
x: torch.Tensor,
router_logits: torch.Tensor,
use_nn_moe: bool | None = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
topk_weights, topk_ids = router.select_experts(
hidden_states=x,
......
......@@ -1445,7 +1445,6 @@ class DeepseekV2ForCausalLM(
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
def set_moe_parameters(self):
self.expert_weights = []
......@@ -1705,35 +1704,6 @@ class DeepseekV2ForCausalLM(
if not is_fusion_moe_shared_experts_layer:
loaded_params.add(name)
if self.use_llama_nn and self.quant_method is None:
lay_key_words = [
"self_attn.q_a_proj.weight",
"self_attn.kv_a_proj_with_mqa.weight",
"mlp.gate.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj.weight",
"shared_experts.gate_up_proj.weight",
"shared_experts.down_proj.weight",
"self_attn.q_proj.weight",
"self_attn.q_b_proj.weight",
"self_attn.kv_b_proj.weight",
"self_attn.o_proj.weight",
"lm_head.weight"
]
combined_words = "|".join(lay_key_words)
for layername in loaded_params:
weight = params_dict[layername]
matches = re.findall(combined_words, layername)
if matches:
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
return loaded_params
......
......@@ -953,8 +953,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# functools.partial(flash_attn_varlen_func,
# fa_version=self.vllm_flash_attn_version)
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
......@@ -1058,17 +1056,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
if self.use_llama_nn and self.kv_b_proj.quant_method is None:
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj)
assert kv_b_proj_weight.shape == (
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
self.kv_lora_rank,), (
f"{kv_b_proj_weight.shape=}, "
f"{self.kv_lora_rank=}, "
f"{self.num_heads=}, "
f"{self.qk_nope_head_dim=}, "
f"{self.v_head_dim=}")
else:
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment