"vscode:/vscode.git/clone" did not exist on "07286ec5a6edf3a7a3623ddfc8db6a24b1d3c70a"
Commit 8f80b711 authored by zhuwenwen
Browse files

Merge remote-tracking branch 'origin/v0.9.2-dev_yql_marlin' into v0.9.2-dev

parents cd3ed273 9820d063
...@@ -10,16 +10,16 @@ import vllm.envs as envs ...@@ -10,16 +10,16 @@ import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType from vllm.scalar_type import ScalarType
from vllm.utils import direct_register_custom_op
try: try:
from lmslim import quant_ops from lmslim import quant_ops
from lmslim import quant_tools from lmslim import quant_tools
except Exception: except Exception:
print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n") print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n")
try: try:
import marlin import lightop
except Exception: except Exception:
print("INFO: Please install marlin if you want to infer awq of marlin.\n") print("INFO: Please install lightop if you want to infer awq of marlin.\n")
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -766,6 +766,14 @@ def awq_gemm(input: torch.Tensor, weight: torch.Tensor, ...@@ -766,6 +766,14 @@ def awq_gemm(input: torch.Tensor, weight: torch.Tensor,
splikspace, splikspace,
splikspacesize) splikspacesize)
def awq_gemm_fake(input: torch.Tensor, weight: torch.Tensor,
                  zeros_and_scales: torch.Tensor,
                  m: int, n: int, k: int,
                  group_size: int, padding_group: int,
                  splikspace: torch.Tensor,
                  splikspacesize: int) -> torch.Tensor:
    """Shape-only counterpart of ``awq_gemm``.

    Allocates an uninitialized ``(m, n)`` tensor that takes its dtype and
    device from ``input``. No computation is performed; every other
    argument is accepted only so the signature mirrors ``awq_gemm`` and is
    never read.

    Returns:
        An uninitialized tensor of shape ``(m, n)``.
    """
    out_shape = (m, n)
    return torch.empty(out_shape, dtype=input.dtype, device=input.device)
def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor, def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
group_size: int): group_size: int):
return quant_ops.convert_s4(qw,qz,s,group_size) return quant_ops.convert_s4(qw,qz,s,group_size)
...@@ -1477,7 +1485,7 @@ def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, ...@@ -1477,7 +1485,7 @@ def awq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor,
device=b_q_weight.device, device=b_q_weight.device,
dtype=b_q_weight.dtype) dtype=b_q_weight.dtype)
for e in range(num_experts): for e in range(num_experts):
output[e] = torch.ops.marlin.awq_marlin_repack(b_q_weight[e], size_k, output[e] = lightop.awq_marlin_repack(b_q_weight[e], size_k,
size_n, num_bits) size_n, num_bits)
return output return output
...@@ -2437,3 +2445,10 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"): ...@@ -2437,3 +2445,10 @@ if hasattr(torch.ops._C, "int8_scaled_mm_with_quant"):
M = mat1.size(0) M = mat1.size(0)
N = mat2.size(0) N = mat2.size(0)
return torch.empty((M, N), dtype=out_dtype) return torch.empty((M, N), dtype=out_dtype)
# Register awq_gemm as a custom op named "awq_gemm". awq_gemm_fake is passed
# as its fake implementation, and the op is declared to mutate none of its
# arguments (mutates_args is empty).
# NOTE(review): indentation is lost in this view, so it is unclear whether
# this call sits at module level or inside the preceding `if hasattr(...)`
# block; confirm against the real file.
direct_register_custom_op(
op_name="awq_gemm",
op_func=awq_gemm,
mutates_args=[],
fake_impl=awq_gemm_fake,
)
\ No newline at end of file
...@@ -6,9 +6,10 @@ from typing import Optional ...@@ -6,9 +6,10 @@ from typing import Optional
import torch import torch
try: try:
import marlin import lightop
except Exception: except Exception:
print("INFO: Please install marlin if you want to infer awq moe of marlin.\n") print("INFO: Please install lightop if you want to infer awq of marlin.\n")
import vllm.envs as envs import vllm.envs as envs
import vllm._custom_ops as ops import vllm._custom_ops as ops
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
...@@ -28,8 +29,8 @@ def fused_marlin_moe( ...@@ -28,8 +29,8 @@ def fused_marlin_moe(
hidden_states: torch.Tensor, # 32, 7168 hidden_states: torch.Tensor, # 32, 7168
w1: torch.Tensor, # 256, 512, 7168 --> 32*8, 512 --> 32*8, 256 w1: torch.Tensor, # 256, 512, 7168 --> 32*8, 512 --> 32*8, 256
w2: torch.Tensor, # 256, 256, 7168 w2: torch.Tensor, # 256, 256, 7168
w1_scale: torch.Tensor, w1_scale_zero: torch.Tensor,
w2_scale: torch.Tensor, w2_scale_zero: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
...@@ -41,7 +42,7 @@ def fused_marlin_moe( ...@@ -41,7 +42,7 @@ def fused_marlin_moe(
sort_indices2: Optional[torch.Tensor] = None, sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None, w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None, w2_zeros: Optional[torch.Tensor] = None,
# workspace: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None,
num_bits: int = 4, num_bits: int = 4,
is_k_full: bool = True, is_k_full: bool = True,
inplace: bool = False) -> torch.Tensor: inplace: bool = False) -> torch.Tensor:
...@@ -94,48 +95,31 @@ def fused_marlin_moe( ...@@ -94,48 +95,31 @@ def fused_marlin_moe(
assert w1.is_contiguous(), "Expert weights1 must be contiguous" assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous" assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [torch.float16, torch.bfloat16] assert hidden_states.dtype in [torch.float16, torch.bfloat16]
# assert num_bits in [4, 8] # assert num_bits in [4]
# 目前只支持 uint4的量化结果
assert num_bits in [4] assert num_bits in [4]
M, K = hidden_states.shape # 32, 7168 num_tokens, K = hidden_states.shape # 32, 7168
E = w1.shape[0] # 256 E = w1.shape[0] # 256
N = w2.shape[1] * 16 # 256 N = w2.shape[1] * 16 # 256
topk = topk_ids.shape[1] # 8 topk = topk_ids.shape[1] # 8
# # 计算 topk_weights 和 topk_ids
# topk_weights, topk_ids = fused_topk(hidden_states, score, topk, False)
# 选择 block_size_m 的逻辑按照 Marlin来设置 #暂时固定为16384
for block_size_m in [16, 32, 48, 64, 80]: CHUNK_SIZE = 16384
if M * topk / E / block_size_m < 0.9:
break
# print("m: ", M, "; block_m: ", block_size_m)
if global_num_experts == -1: M = min(num_tokens, CHUNK_SIZE)
global_num_experts = E
sorted_token_ids, expert_ids, num_tokens_post_padded = \ if workspace is None:
moe_align_block_size(topk_ids, block_size_m, global_num_experts, sms = torch.cuda.get_device_properties(device='cuda').multi_processor_count
expert_map) workspace = torch.zeros(sms * 3,
# max_num = num_tokens_post_padded.item() dtype=torch.int,
# print("max_num: ", max_num) device=hidden_states.device,
# 输出 requires_grad=False)
# for i in range(0, max_num, block_size_m):
# print(i / block_size_m, sorted_token_ids[i:(i + block_size_m)])
# if workspace is None:
# max_workspace_size = (max(2 * N, K) // 64) * \
# (sorted_token_ids.size(0) // block_size_m)
# device = hidden_states.device
# sms = torch.cuda.get_device_properties(device).multi_processor_count
# max_workspace_size = min(max_workspace_size, sms * 4)
# workspace = torch.zeros(max_workspace_size,
# dtype=torch.int,
# device=device,
# requires_grad=False)
scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None) scalar_type1 = get_scalar_type(num_bits, w1_zeros is not None)
scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None) scalar_type2 = get_scalar_type(num_bits, w2_zeros is not None)
if global_num_experts == -1:
intermediate_cache2 = torch.empty( # [32*8, 256] global_num_experts = E
intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N), (M * topk_ids.shape[1], N),
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
...@@ -145,64 +129,89 @@ def fused_marlin_moe( ...@@ -145,64 +129,89 @@ def fused_marlin_moe(
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype, dtype=hidden_states.dtype,
) )
intermediate_cache1 = intermediate_cache13[:M * topk_ids.shape[1] * 2 * N] # [32*8, 512] intermediate_cache1 = intermediate_cache13[:M * topk_ids.shape[1] * 2 * N]
intermediate_cache1 = intermediate_cache1.view(-1, 2 * N) intermediate_cache1 = intermediate_cache1.view(-1, 2 * N)
intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K] # # [32*8, 7168] intermediate_cache3 = intermediate_cache13[:M * topk_ids.shape[1] * K]
intermediate_cache3 = intermediate_cache3.view(-1, K) intermediate_cache3 = intermediate_cache3.view(-1, K)
use_atomic_add = hidden_states.dtype == torch.half or \ use_atomic_add = hidden_states.dtype == torch.half or \
torch.cuda.get_device_capability(hidden_states.device)[0] >= 9 torch.cuda.get_device_capability(hidden_states.device)[0] >= 9
intermediate_cache1.zero_() if inplace:
intermediate_cache1 = torch.ops.marlin.moe_wna16_marlin_gemm( out_hidden_states = hidden_states
hidden_states, # [32, 7168] # arg0: torch.Tensor, else:
intermediate_cache1, # [32*8, 512] # arg1: Optional[torch.Tensor] out_hidden_states = torch.empty_like(hidden_states)
w1, # arg2: torch.Tensor
w1_scale, # arg3: torch.Tensor for chunk in range((num_tokens // CHUNK_SIZE) + 1):
# w1_zeros, # arg4: Optional[torch.Tensor]
g_idx1, # arg5: Optional[torch.Tensor] begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE,
sort_indices1, # arg6: Optional[torch.Tensor] min((chunk + 1) * CHUNK_SIZE,
# workspace, # arg7: torch.Tensor num_tokens))
sorted_token_ids, # arg8: torch.Tensor curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
expert_ids, # arg9: torch.Tensor tokens_in_chunk, _ = curr_hidden_states.size()
num_tokens_post_padded, # arg10: torch.Tensor
topk_weights, #arg11: torch.Tensor, if tokens_in_chunk == 0:
block_size_m,# arg12: int, break
topk, # arg13: int, intermediate_cache3 = intermediate_cache3.view(-1, K)
False, # arg14: bool, if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
expert_map is not None, # arg15: bool, intermediate_cache1 = intermediate_cache1[:tokens_in_chunk * topk, :]
scalar_type1.id, # arg16: int intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * topk, :]
M, # arg17: int, intermediate_cache3 = intermediate_cache3[:tokens_in_chunk * topk, :]
2 * N, # arg18: int M = tokens_in_chunk
K, # arg19: int,
is_k_full, # arg20: bool, # Select block_size_m
use_atomic_add, # arg21: bool, for block_size_m in [16, 32, 48, 64, 80]:
True, # arg22: bool if M * topk / E / block_size_m < 0.9:
break
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(curr_topk_ids, block_size_m, global_num_experts, expert_map)
intermediate_cache1 = lightop.moe_marlin_w4a16(
curr_hidden_states,
intermediate_cache1,
w1,
w1_scale_zero,
g_idx1,
sort_indices1,
workspace,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
curr_topk_weights,
block_size_m,
topk,
False,
expert_map is not None,
M,
2 * N,
K,
is_k_full,
use_atomic_add,
True,
False False
) # arg23: bool )
# [32*8, 512] --> [32*8, 256] torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1)
torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, 2 * N)) intermediate_cache3 = lightop.moe_marlin_w4a16(
intermediate_cache3.zero_() intermediate_cache2,
intermediate_cache3 = torch.ops.marlin.moe_wna16_marlin_gemm( intermediate_cache3,
intermediate_cache2, # [32*8, 256]
intermediate_cache3, # [32*8, 7168]
w2, w2,
w2_scale, w2_scale_zero,
# w2_zeros,
g_idx2, g_idx2,
sort_indices2, sort_indices2,
# workspace, workspace,
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
num_tokens_post_padded, num_tokens_post_padded,
topk_weights, curr_topk_weights,
block_size_m, block_size_m,
1, 1,
True, True,
expert_map is not None, expert_map is not None,
scalar_type2.id,
M * topk, M * topk,
K, K,
N, N,
...@@ -212,19 +221,16 @@ def fused_marlin_moe( ...@@ -212,19 +221,16 @@ def fused_marlin_moe(
False False
).view(-1, topk, K) ).view(-1, topk, K)
output = hidden_states if inplace else torch.empty_like(hidden_states) ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), out_hidden_states[begin_chunk_idx:end_chunk_idx])
# return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
# dim=1, return out_hidden_states
# out=output)
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), output)
return output
def fused_marlin_moe_fake( def fused_marlin_moe_fake(
hidden_states: torch.Tensor, # 32, 7168 hidden_states: torch.Tensor, # 32, 7168
w1: torch.Tensor, # 256, 512, 7168 --> 32*8, 512 --> 32*8, 256 w1: torch.Tensor, # 256, 512, 7168 --> 32*8, 512 --> 32*8, 256
w2: torch.Tensor, # 256, 256, 7168 w2: torch.Tensor, # 256, 256, 7168
w1_scale: torch.Tensor, w1_scale_zero: torch.Tensor,
w2_scale: torch.Tensor, w2_scale_zero: torch.Tensor,
gating_output: torch.Tensor, gating_output: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
...@@ -236,7 +242,7 @@ def fused_marlin_moe_fake( ...@@ -236,7 +242,7 @@ def fused_marlin_moe_fake(
sort_indices2: Optional[torch.Tensor] = None, sort_indices2: Optional[torch.Tensor] = None,
w1_zeros: Optional[torch.Tensor] = None, w1_zeros: Optional[torch.Tensor] = None,
w2_zeros: Optional[torch.Tensor] = None, w2_zeros: Optional[torch.Tensor] = None,
# workspace: Optional[torch.Tensor] = None, workspace: Optional[torch.Tensor] = None,
num_bits: int = 4, num_bits: int = 4,
is_k_full: bool = True, is_k_full: bool = True,
inplace: bool = False) -> torch.Tensor: inplace: bool = False) -> torch.Tensor:
......
...@@ -67,8 +67,14 @@ def default_execution(k,n): ...@@ -67,8 +67,14 @@ def default_execution(k,n):
def getspec_config(M, N, K):
    """Look up the stored config for an (M, N, K) problem size.

    ``M`` values above 16 are rounded up to the next power of two before the
    lookup; values of 16 or less are used unchanged. The lookup key is the
    string ``"{M}_{N}_{K}"`` built from the (possibly rounded) ``M``.

    Returns:
        The matching entry of ``triton_configs_dict``, or ``None`` when no
        entry exists for that key.
    """
    bucket_m = M
    if M > 16:
        # Round M up to the nearest power of two.
        bucket_m = 1
        while bucket_m < M:
            bucket_m *= 2

    key = f"{bucket_m}_{N}_{K}"
    if key not in triton_configs_dict:
        return None
    return triton_configs_dict[key]
...@@ -336,14 +342,11 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -336,14 +342,11 @@ class AWQLinearMethod(LinearMethodBase):
padding_group=0 padding_group=0
if envs.VLLM_USE_TRITON_AWQ: if envs.VLLM_USE_TRITON_AWQ:
if m>16:
m = 1 << (m - 1).bit_length()
best_config=getspec_config(m,n,k) best_config=getspec_config(m,n,k)
out = awq_gemm_triton(reshaped_x, qweight, scales, qzeros, pack_factor, best_config) out = awq_gemm_triton(reshaped_x, qweight, scales, qzeros, pack_factor, best_config)
out_shape = (x.shape[:-1] + (qweight.shape[1] * 8, )) out_shape = (x.shape[:-1] + (qweight.shape[1] * 8, ))
else: else:
out = ops.awq_gemm(reshaped_x, out = torch.ops.vllm.awq_gemm(reshaped_x,
qweight, qweight,
zeros_and_scales, zeros_and_scales,
m, m,
......
...@@ -401,7 +401,7 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -401,7 +401,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_qzeros, extra_weight_attrs) set_weight_attrs(w2_qzeros, extra_weight_attrs)
device = layer.w13_qweight.device device = layer.w13_qweight.device
layer.workspace = marlin_make_workspace_new(device, 4) layer.workspace = marlin_make_workspace_new(device, 3)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
num_experts = layer.w13_qweight.shape[0] num_experts = layer.w13_qweight.shape[0]
...@@ -546,6 +546,6 @@ class AWQMoEMethod(FusedMoEMethodBase): ...@@ -546,6 +546,6 @@ class AWQMoEMethod(FusedMoEMethodBase):
expert_map=expert_map, expert_map=expert_map,
w1_zeros=layer.w13_qzeros, w1_zeros=layer.w13_qzeros,
w2_zeros=layer.w2_qzeros, w2_zeros=layer.w2_qzeros,
# workspace=layer.workspace workspace=layer.workspace,
num_bits=4 num_bits=4
) )
...@@ -294,7 +294,7 @@ def awq_gemm_triton(input: torch.Tensor, ...@@ -294,7 +294,7 @@ def awq_gemm_triton(input: torch.Tensor,
scales: torch.Tensor, scales: torch.Tensor,
qzeros: torch.Tensor, qzeros: torch.Tensor,
split_k_iters: int, split_k_iters: int,
config) -> torch.Tensor: config=None) -> torch.Tensor:
M, K = input.shape M, K = input.shape
N = qweight.shape[1] * 8 N = qweight.shape[1] * 8
group_size = qweight.shape[0] // qzeros.shape[0] group_size = qweight.shape[0] // qzeros.shape[0]
......
...@@ -14,10 +14,6 @@ from vllm.platforms import current_platform ...@@ -14,10 +14,6 @@ from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types from vllm.scalar_type import ScalarType, scalar_types
from .quant_utils import pack_cols, unpack_cols from .quant_utils import pack_cols, unpack_cols
try:
import marlin
except Exception:
print("INFO: Please install marlin if you want to infer awq moe of marlin.\n")
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -21,7 +21,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( ...@@ -21,7 +21,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.compilation.decorators import support_torch_compile
from .deepseek_v2 import (DeepseekV2DecoderLayer, from .deepseek_v2 import (DeepseekV2DecoderLayer,
get_spec_layer_idx_from_weight_name) get_spec_layer_idx_from_weight_name)
from .interfaces import SupportsPP from .interfaces import SupportsPP
...@@ -150,7 +150,7 @@ class DeepSeekMultiTokenPredictor(nn.Module): ...@@ -150,7 +150,7 @@ class DeepSeekMultiTokenPredictor(nn.Module):
sampling_metadata) sampling_metadata)
return logits return logits
#@support_torch_compile
class DeepSeekMTP(nn.Module, SupportsPP): class DeepSeekMTP(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment