Commit 0de4f1dc authored by zhuwenwen
Browse files

add int8

parent b9e12416
......@@ -168,7 +168,7 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
# "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
# "csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
......
......@@ -94,8 +94,8 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
#endif
// void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input,
// float scale);
// Quantization op exposed to Python as "static_scaled_int8_quant"
// ("Compute int8 quantized tensor for given scaling factor").
// The result is delivered through `out`; the Python wrapper allocates it as
// an int8 tensor via torch.empty_like(input) before calling. Returns nothing.
// NOTE(review): the kernel body is not visible here (presumably in
// int8_quant_kernels.cu) -- confirm rounding/saturation behaviour there.
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input,
float scale);
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor lookup_table);
......
......@@ -67,8 +67,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Aligning the number of tokens to be processed by each expert such "
"that it is divisible by the block size.");
// ops.def("static_scaled_int8_quant", &static_scaled_int8_quant,
// "Compute int8 quantized tensor for given scaling factor");
ops.def("static_scaled_int8_quant", &static_scaled_int8_quant,
"Compute int8 quantized tensor for given scaling factor");
// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
......
......@@ -264,21 +264,21 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# int8
# def static_scaled_int8_quant(input: torch.Tensor,
# scale: float) -> torch.Tensor:
# """
# Quantize the input tensor to int8 and return the quantized tensor.
# Args:
# input: The input tensor to be quantized to int8.
# scale: Scaling factor for the int8 quantization.
# Returns:
# torch.Tensor: Output tensor in int8.
# """
# q = torch.empty_like(input, dtype=torch.int8)
# vllm_ops.static_scaled_int8_quant(q, input, scale)
# return q
def static_scaled_int8_quant(input: torch.Tensor,
                             scale: float) -> torch.Tensor:
    """
    Produce an int8 tensor from ``input`` using a fixed scaling factor.

    Args:
        input: Tensor whose values are to be quantized.
        scale: Scaling factor forwarded to the quantization op.

    Returns:
        torch.Tensor: A newly allocated int8 tensor created with
        ``torch.empty_like(input)`` and filled by the extension op.
    """
    # The extension op writes into a caller-supplied buffer, so allocate
    # the destination before invoking it.
    quantized = torch.empty_like(input, dtype=torch.int8)
    vllm_ops.static_scaled_int8_quant(quantized, input, scale)
    return quantized
# moe
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment