Unverified commit 5aa1ebd2 authored by Peng Zhang, committed by GitHub

[2/n] Decouple quantization implementation from vLLM dependency (#8112)


Co-authored-by: walker-ai <yiyun.wyt@antgroup.com>
Co-authored-by: leoneo <1320612015@qq.com>
parent 4dbf4360
@@ -3,8 +3,8 @@
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "gptq_marlin/marlin.cuh"
#include "gptq_marlin/marlin_dtypes.cuh"
#include "gemm/marlin/marlin.cuh"
#include "gemm/marlin/marlin_dtypes.cuh"
#include "scalar_type.hpp"
#define MARLIN_KERNEL_PARAMS \
......
@@ -18,13 +18,12 @@
/*
* Adapted from https://github.com/IST-DASLab/marlin
*/
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "gptq_marlin/marlin.cuh"
#include "gptq_marlin/marlin_dtypes.cuh"
#include "gemm/marlin/marlin.cuh"
#include "gemm/marlin/marlin_dtypes.cuh"
#include "scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
......
@@ -23,7 +23,6 @@
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif
#include "core/registration.h"
#include "kernel.h"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
@@ -50,8 +49,7 @@ __global__ void permute_cols_kernel(
int size_m,
int size_k,
int top_k) {};
-} // namespace marlin
+}
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a,
......
@@ -298,6 +298,7 @@ static inline constexpr auto kS8 = ScalarType::int_(8);
static inline constexpr auto kU8 = ScalarType::uint(8);
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);
static inline constexpr auto kFE2M1f = ScalarType::float_(2, 1, true, ScalarType::NAN_NONE);
+static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE4M3fn = ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
@@ -313,6 +314,7 @@ static inline constexpr auto kInt8 = kS8;
static inline constexpr auto kUint8 = kU8;
static inline constexpr auto kUint8b128 = kU8B128;
static inline constexpr auto kFloat4_e2m1f = kFE2M1f;
+static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
static inline constexpr auto kFloat8_e5m2 = kFE5M2;
......
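The ScalarType::float_(exponent_bits, mantissa_bits, ...) entries above fully determine each format's range. As a rough illustration only (not part of this diff), a small Python sketch of the maximum finite value, assuming an IEEE-style bias of 2^(e-1)-1 and that NAN_NONE formats such as kFE2M1f and kFE3M2f use the all-ones exponent for ordinary values:

def max_finite(exp_bits: int, mant_bits: int) -> float:
    # IEEE-style bias; NAN_NONE formats keep the top exponent code for normal numbers.
    bias = 2 ** (exp_bits - 1) - 1
    max_exp = (2**exp_bits - 1) - bias
    max_mantissa = 1 + (2**mant_bits - 1) / 2**mant_bits
    return max_mantissa * 2.0**max_exp

print(max_finite(2, 1))  # kFE2M1f (fp4 e2m1): 6.0
print(max_finite(3, 2))  # kFE3M2f (fp6 e3m2): 28.0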
@@ -224,6 +224,40 @@ void dsv3_router_gemm(torch::Tensor& output, const torch::Tensor& mat_a, const t
void dsv3_fused_a_gemm(torch::Tensor& output, torch::Tensor const& mat_a, torch::Tensor const& mat_b);
+torch::Tensor gptq_marlin_gemm(
+    torch::Tensor& a,
+    std::optional<torch::Tensor> c_or_none,
+    torch::Tensor& b_q_weight,
+    torch::Tensor& b_scales,
+    std::optional<torch::Tensor> const& global_scale_or_none,
+    std::optional<torch::Tensor> const& b_zeros_or_none,
+    std::optional<torch::Tensor> const& g_idx_or_none,
+    std::optional<torch::Tensor> const& perm_or_none,
+    torch::Tensor& workspace,
+    sglang::ScalarTypeId const& b_q_type_id,
+    int64_t size_m,
+    int64_t size_n,
+    int64_t size_k,
+    bool is_k_full,
+    bool use_atomic_add,
+    bool use_fp32_reduce,
+    bool is_zp_float);
+torch::Tensor gptq_gemm(
+    torch::Tensor a,
+    torch::Tensor b_q_weight,
+    torch::Tensor b_gptq_qzeros,
+    torch::Tensor b_gptq_scales,
+    torch::Tensor b_g_idx,
+    bool use_shuffle,
+    int64_t bit);
+void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
+torch::Tensor
+gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits);
+torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits);
/*
* From csrc/moe
*/
@@ -340,15 +374,6 @@ void scaled_fp4_experts_quant(
torch::Tensor const& input_offset_by_experts,
torch::Tensor const& output_scale_offset_by_experts);
-namespace marlin_moe_wna16 {
-torch::Tensor
-gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, int64_t size_k, int64_t size_n, int64_t num_bits);
-torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits);
-} // namespace marlin_moe_wna16
/*
* From csrc/speculative
*/
......
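The gptq_shuffle/gptq_gemm pair declared above follows the usual exllama-style GPTQ flow: reorder the packed weight once at load time according to g_idx, then run the dequantize-and-GEMM kernel on every forward pass. A minimal sketch of that flow through the Python wrappers added later in this diff; the 4-bit packing factor and tensor shapes are assumptions for illustration, not values taken from this commit:

import torch
from sgl_kernel import gptq_gemm, gptq_shuffle

bits, group_size, m, k, n = 4, 128, 8, 4096, 4096
pack = 32 // bits  # assumed layout: each int32 packs 32/bits quantized values

qweight = torch.randint(0, 2**31, (k // pack, n), dtype=torch.int32, device="cuda")
qzeros = torch.randint(0, 2**31, (k // group_size, n // pack), dtype=torch.int32, device="cuda")
scales = torch.rand(k // group_size, n, dtype=torch.float16, device="cuda")
g_idx = torch.arange(k, dtype=torch.int32, device="cuda") // group_size

# One-time weight reordering at load; qweight is then reused for every forward.
gptq_shuffle(qweight, g_idx, bits)

x = torch.rand(m, k, dtype=torch.float16, device="cuda")
out = gptq_gemm(x, qweight, qzeros, scales, g_idx, True, bits)  # use_shuffle=True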
@@ -44,6 +44,9 @@ from sgl_kernel.gemm import (
dsv3_router_gemm,
fp8_blockwise_scaled_mm,
fp8_scaled_mm,
+    gptq_gemm,
+    gptq_marlin_gemm,
+    gptq_shuffle,
int8_scaled_mm,
qserve_w4a8_per_chn_gemm,
qserve_w4a8_per_group_gemm,
......
@@ -2,6 +2,7 @@ import functools
from typing import Optional
import torch
+from sgl_kernel import silu_and_mul
def get_scalar_type(num_bits: int, has_zp: bool):
@@ -165,7 +166,7 @@ def fused_marlin_moe(
is_zp_float=False,
)
-torch.ops._C.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N))
+silu_and_mul(intermediate_cache1.view(-1, 2 * N), intermediate_cache2)
if expert_map is not None:
intermediate_cache3.zero_()
......
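Note that the two activation calls above differ in argument order: vLLM's torch.ops._C.silu_and_mul takes (out, input), while sgl_kernel's silu_and_mul takes (input, out), as the replacement line shows. For comparison only, a pure-PyTorch sketch of the same gated activation:

import torch

def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # x has shape [..., 2 * N]: the first half is the gate, the second half the up-projection.
    gate, up = x.chunk(2, dim=-1)
    return torch.nn.functional.silu(gate) * up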
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple
import torch
from sgl_kernel.scalar_type import ScalarType
from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
@@ -353,3 +354,62 @@ def scaled_fp4_experts_quant(
)
output_scales = output_scales.view(torch.float8_e4m3fn)
return output, output_scales
+# GPTQ kernels
+def gptq_marlin_gemm(
+    a: torch.Tensor,
+    c: Optional[torch.Tensor],
+    b_q_weight: torch.Tensor,
+    b_scales: torch.Tensor,
+    global_scale: Optional[torch.Tensor],
+    b_zeros: Optional[torch.Tensor],
+    g_idx: Optional[torch.Tensor],
+    perm: Optional[torch.Tensor],
+    workspace: torch.Tensor,
+    b_q_type: ScalarType,
+    size_m: int,
+    size_n: int,
+    size_k: int,
+    is_k_full: bool = True,
+    use_atomic_add: bool = False,
+    use_fp32_reduce: bool = False,
+    is_zp_float: bool = False,
+) -> torch.Tensor:
+    return torch.ops.sgl_kernel.gptq_marlin_gemm(
+        a,
+        c,
+        b_q_weight,
+        b_scales,
+        global_scale,
+        b_zeros,
+        g_idx,
+        perm,
+        workspace,
+        b_q_type.id,
+        size_m,
+        size_n,
+        size_k,
+        is_k_full,
+        use_atomic_add,
+        use_fp32_reduce,
+        is_zp_float,
+    )
+def gptq_gemm(
+    a: torch.Tensor,
+    b_q_weight: torch.Tensor,
+    b_gptq_qzeros: torch.Tensor,
+    b_gptq_scales: torch.Tensor,
+    b_g_idx: torch.Tensor,
+    use_shuffle: bool,
+    bit: int,
+) -> torch.Tensor:
+    return torch.ops.sgl_kernel.gptq_gemm(
+        a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit
+    )
+def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None:
+    torch.ops.sgl_kernel.gptq_shuffle(q_weight, q_perm, bit)
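A hedged usage sketch for the gptq_marlin_gemm wrapper above. The Marlin-layout weight/scale shapes, the workspace sizing (an int32 scratch buffer derived from size_n with the usual 64/16 thread/parallelism constants), and ScalarType.uint(4, 8) for 4-bit GPTQ are assumptions carried over from the upstream Marlin integration, not values defined in this diff:

import torch
from sgl_kernel import gptq_marlin_gemm
from sgl_kernel.scalar_type import ScalarType

m, n, k, group_size = 16, 4096, 4096, 128
a = torch.rand(m, k, dtype=torch.float16, device="cuda")
# Weights must already be repacked into Marlin layout (see gptq_marlin_repack below).
marlin_qweight = torch.randint(0, 2**31, (k // 16, n * 2), dtype=torch.int32, device="cuda")
marlin_scales = torch.rand(k // group_size, n, dtype=torch.float16, device="cuda")
# Assumed sizing: (n // min_thread_n) * max_parallel with min_thread_n=64, max_parallel=16.
workspace = torch.zeros(n // 64 * 16, dtype=torch.int32, device="cuda")

out = gptq_marlin_gemm(
    a,
    None,  # c: let the kernel allocate the output
    marlin_qweight,
    marlin_scales,
    None,  # global_scale
    None,  # b_zeros (symmetric quantization)
    None,  # g_idx (no act-order)
    None,  # perm
    workspace,
    ScalarType.uint(4, 8),  # 4-bit GPTQ, zero point folded in as bias 8
    m,
    n,
    k,
)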
@@ -7,8 +7,8 @@ def gptq_marlin_repack(
size_k,
size_n,
num_bits,
-):
-torch.ops.sgl_kernel.gptq_marlin_repack.default(
+) -> torch.Tensor:
+return torch.ops.sgl_kernel.gptq_marlin_repack(
b_q_weight,
perm,
size_k,
......
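With this change the wrapper returns the repacked tensor instead of invoking the .default overload for its side effect. A short sketch of producing a Marlin-layout weight from a GPTQ-layout one; the top-level import path, the 4-bit packing, and the empty perm (no act-order) are assumptions for illustration:

import torch
from sgl_kernel import gptq_marlin_repack  # assuming the existing top-level re-export

k, n, bits = 4096, 4096, 4
pack = 32 // bits
qweight = torch.randint(0, 2**31, (k // pack, n), dtype=torch.int32, device="cuda")
perm = torch.empty(0, dtype=torch.int32, device="cuda")  # no act-order reordering

marlin_qweight = gptq_marlin_repack(qweight, perm, k, n, bits)  # now returns the result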