"...ssh:/git@developer.sourcefind.cn:2222/OpenDAS/dynamo.git" did not exist on "a9c0e0c7fa7adfaa8276227dbd020a4e919da46b"
Commit 0de4f1dc authored by zhuwenwen's avatar zhuwenwen
Browse files

add int8

parent b9e12416
...@@ -168,7 +168,7 @@ set(VLLM_EXT_SRC ...@@ -168,7 +168,7 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu" "csrc/layernorm_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/gptq/q_gemm.cu"
# "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
# "csrc/quantization/fp8/common.cu" # "csrc/quantization/fp8/common.cu"
"csrc/cuda_utils_kernels.cu" "csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu" "csrc/moe_align_block_size_kernels.cu"
......
...@@ -94,8 +94,8 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a, ...@@ -94,8 +94,8 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
#endif #endif
// void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input, void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input,
// float scale); float scale);
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor lookup_table); torch::Tensor lookup_table);
......
...@@ -67,8 +67,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ...@@ -67,8 +67,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Aligning the number of tokens to be processed by each expert such " "Aligning the number of tokens to be processed by each expert such "
"that it is divisible by the block size."); "that it is divisible by the block size.");
// ops.def("static_scaled_int8_quant", &static_scaled_int8_quant, ops.def("static_scaled_int8_quant", &static_scaled_int8_quant,
// "Compute int8 quantized tensor for given scaling factor"); "Compute int8 quantized tensor for given scaling factor");
// Cache ops // Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
......
...@@ -264,21 +264,21 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, ...@@ -264,21 +264,21 @@ def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# int8 # int8
# def static_scaled_int8_quant(input: torch.Tensor, def static_scaled_int8_quant(input: torch.Tensor,
# scale: float) -> torch.Tensor: scale: float) -> torch.Tensor:
# """ """
# Quantize the input tensor to int8 and return the quantized tensor. Quantize the input tensor to int8 and return the quantized tensor.
# Args: Args:
# input: The input tensor to be quantized to int8. input: The input tensor to be quantized to int8.
# scale: Scaling factor for the int8 quantization. scale: Scaling factor for the int8 quantization.
# Returns: Returns:
# torch.Tensor: Output tensor in int8. torch.Tensor: Output tensor in int8.
# """ """
# q = torch.empty_like(input, dtype=torch.int8) q = torch.empty_like(input, dtype=torch.int8)
# vllm_ops.static_scaled_int8_quant(q, input, scale) vllm_ops.static_scaled_int8_quant(q, input, scale)
# return q return q
# moe # moe
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment