Unverified commit 83d1f4b3, authored by Casper, committed by GitHub

New kernels 0222 (#12)

* New kernels 0222

* Test build

* Add to setup

* Fix cc

* Separate pybindings

* Fix casting

* Fix name of pybindings

* Fix pybindings

* Build on tag

* Bump to 0.0.6

* Upgrade to 2.2.1
parent 8714b16c
@@ -91,7 +91,7 @@ jobs:
         # Install torch
         $cudaVersion = $env:CUDA_VERSION.Replace('.', '')
         $cudaVersionPytorch = $cudaVersion.Substring(0, $cudaVersion.Length - 1)
-        if ([int]$cudaVersionPytorch -gt 118) { $pytorchVersion = "torch==2.2.0" } else {$pytorchVersion = "torch==2.0.1"}
+        $pytorchVersion = "torch==2.2.1"
         python -m pip install --upgrade --no-cache-dir $pytorchVersion+cu$cudaVersionPytorch --index-url https://download.pytorch.org/whl/cu$cudaVersionPytorch
         python -m pip install build setuptools wheel ninja
......
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include "quantization_new/gemm/gemm_cuda.h"
#include "quantization_new/gemv/gemv_cuda.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("gemm_forward_cuda_prefill", &gemm_forward_cuda_prefill, "New quantized GEMM kernel.");
m.def("gemv_forward_cuda_decode", &gemv_forward_cuda_decode, "New quantized GEMV kernel.");
}
\ No newline at end of file
/*
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#include <cuda_fp16.h>
#pragma once
__inline__ __device__ void dequantize_s4_to_fp16x2(half2 const &source, uint4 *result)
{
// uint4 result;
uint32_t *h = reinterpret_cast<uint32_t *>(result);
uint32_t const i4s = reinterpret_cast<uint32_t const &>(source);
// First, we extract the i4s and construct an intermediate fp16 number.
static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t BOTTOM_MASK = 0x000f000f;
static constexpr uint32_t TOP_MASK = 0x00f000f0;
static constexpr uint32_t I4s_TO_F16s_MAGIC_NUM = 0x64006400;
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
// elt_67 to fp16 without having to shift them to the bottom bits before hand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
// immediately before required.
const uint32_t top_i4s = i4s >> 8;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[0])
: "r"(i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[1])
: "r"(i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[2])
: "r"(top_i4s), "n"(BOTTOM_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
: "=r"(h[3])
: "r"(top_i4s), "n"(TOP_MASK), "n"(I4s_TO_F16s_MAGIC_NUM), "n"(immLut));
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
// half2 ctor. In this case, I chose performance reliability over code readability.
// This is the half2 {1032, 1032} represented as an integer.
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64006400;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
static constexpr uint32_t ONE_SIXTEENTH = 0x2c002c00;
// This is the half2 {-72, -72} represented as an integer.
// static constexpr uint32_t NEG_72 = 0xd480d480;
// Haotian: Let's use {-64, -64}.
static constexpr uint32_t NEG_64 = 0xd400d400;
// Finally, we construct the output numbers.
// Convert elt_01
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[0]) : "r"(h[0]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_23
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[1]) : "r"(h[1]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// Convert elt_45
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(h[2]) : "r"(h[2]), "r"(FP16_TOP_MAGIC_NUM));
// Convert elt_67
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[3]) : "r"(h[3]), "r"(ONE_SIXTEENTH), "r"(NEG_64));
// return result;
}
\ No newline at end of file
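The comments above describe the magic-number trick: each 4-bit value is packed into the mantissa of an fp16 whose exponent field encodes 1024 (bit pattern 0x6400), and the bias is then removed arithmetically, with a single sub for the bottom nibble and an fma by 1/16 plus -64 for the top nibble. A minimal NumPy sketch (illustration only, not part of this diff) that checks the per-nibble arithmetic on the host:

import numpy as np

def fp16_from_bits(bits: int) -> float:
    # Reinterpret a 16-bit pattern as an IEEE half and return it as a Python float.
    return float(np.array([bits], dtype=np.uint16).view(np.float16)[0])

for v in range(16):
    # Bottom nibble: (0x6400 | v) is the half 1024 + v; subtracting 1024 recovers v.
    lo = fp16_from_bits(0x6400 | v) - 1024.0
    # Top nibble: (0x6400 | (v << 4)) is the half 1024 + 16*v; fma with 1/16 and -64 recovers v.
    hi = fp16_from_bits(0x6400 | (v << 4)) * (1.0 / 16.0) - 64.0
    assert lo == v and hi == v, (v, lo, hi)
print("magic-number dequantization reproduces all 16 nibble values")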
This diff is collapsed.
#include <torch/extension.h>
torch::Tensor gemm_forward_cuda_prefill(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scales, torch::Tensor _zeros);
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Implementation of a CTA-wide semaphore for inter-CTA synchronization.
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
// namespace cutlass {
/////////////////////////////////////////////////////////////////////////////////////////////////
/// CTA-wide semaphore for inter-CTA synchronization.
class Semaphore
{
public:
int *lock;
bool wait_thread;
int state;
public:
/// Implements a semaphore to wait for a flag to reach a given value
__host__ __device__ Semaphore(int *lock_, int thread_id) : lock(lock_),
wait_thread(thread_id < 0 || thread_id == 0),
state(-1)
{
}
/// Permit fetching the synchronization mechanism early
__device__ void fetch()
{
if (wait_thread)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#else
asm volatile("ld.global.cg.b32 %0, [%1];\n" : "=r"(state) : "l"(lock));
#endif
}
}
/// Gets the internal state
__device__ int get_state() const
{
return state;
}
/// Waits until the semaphore is equal to the given value
__device__ void wait(int status = 0)
{
while (__syncthreads_and(state != status))
{
fetch();
}
__syncthreads();
}
/// Updates the lock with the given result
__device__ void release(int status = 0)
{
__syncthreads();
if (wait_thread)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
asm volatile("st.global.release.gpu.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#else
asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status));
#endif
}
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// } // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
/*
* Modified from NVIDIA [TRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/tree/d37b507f41a87457fe9f10f7459d08f5db235745/cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv)
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#include <cuda_fp16.h>
#include <stdio.h>
#include <torch/extension.h>
#include "gemv_cuda.h"
#include "../dequantize.cuh"
#define PACK_FACTOR 8
#define WARP_SIZE 32
#define MEM_ACCESS_SIZE 128
static inline __device__ float to_float(half src)
{
return __half2float(src);
}
static inline __device__ float to_float(float src)
{
return src;
}
static inline __device__ half to_half(float src)
{
return __float2half(src);
}
static inline __device__ half to_half(half src)
{
return src;
}
// Reduce sum within the warp using the tree reduction algorithm.
template <int Num, int WarpSize>
__device__ __forceinline__ static void warp_reduce(half* psum, float (*out_smem)[Num * 4])
{
// kInterleave = 4
float fpsum[Num];
#pragma unroll
for (int i = 0; i < Num; ++i)
{
fpsum[i] = to_float(psum[i]);
}
#pragma unroll
for (int i = 0; i < Num; ++i)
{
// T0 + T1 + T8 + T9 + T16 + T17 + T24 + T25 (kInterleave = 4)
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 16);
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 8);
fpsum[i] += __shfl_xor_sync(~0, fpsum[i], 1);
}
__syncthreads();
int warp = threadIdx.x / WarpSize, lane = threadIdx.x % WarpSize;
if (lane == 0 || lane == 2 || lane == 4 || lane == 6)
{
#pragma unroll
for (int i = 0; i < Num; ++i)
{
out_smem[warp][i * 4 + lane / 2] = fpsum[i];
}
}
__syncthreads();
};
__device__ __forceinline__ int make_divisible(int c, int divisor){
return (c + divisor - 1) / divisor;
}
template <int NPerBlock, int Batch, int BlockSize, int GroupSize>
__global__ void gemv_kernel(
const half* inputs, const uint32_t* weight, const half* scales, const half* zeros, half* outputs,
const int IC, const int OC)
{
const int kStride = 64;
const int kElemsPerThread = MEM_ACCESS_SIZE / 4;
const int kThreadsNumPerTile = kStride / kElemsPerThread;
// assert(MEM_ACCESS_SIZE == 128);
static constexpr int kShuffleSize = 32;
static constexpr int kShuffleBasicTile = 2;
static constexpr int kShuffleContinous = 4;
static constexpr int kShuffleStrided = 4;
constexpr int Num = NPerBlock * Batch;
constexpr int kInterleave = 4;
half local_inputs[kElemsPerThread];
uint32_t local_qweights[MEM_ACCESS_SIZE / 32];
half half_weight_buffer[kElemsPerThread];
half dequantized_weight[kElemsPerThread * NPerBlock];
half local_scale[NPerBlock];
half local_scaled_zeros[NPerBlock];
half psum[Num];
for (int i = 0; i < Num; ++i)
psum[i] = to_half(0.f);
extern __shared__ uint8_t shmem[];
float(*out_smem)[Num * kInterleave] = reinterpret_cast<float(*)[Num * kInterleave]>(shmem);
const int blk_row_offset = blockIdx.x * NPerBlock * kInterleave;
const int thd_row_offset = (threadIdx.x / kThreadsNumPerTile) % kInterleave;
const int act_k_offset = threadIdx.x / (kThreadsNumPerTile * kInterleave) * kStride
+ (threadIdx.x % kThreadsNumPerTile) * kElemsPerThread;
const int group_offset = act_k_offset / GroupSize;
// TODO: use make_divisible
const uint32_t* blk_weight_ptr = weight + blk_row_offset * IC / PACK_FACTOR;
const half* scale_ptr = scales + blk_row_offset + thd_row_offset + group_offset * OC;
const half* zeros_ptr = zeros + blk_row_offset + thd_row_offset + group_offset * OC;
const half* inputs_ptr = inputs + act_k_offset;
const int act_forward_step = BlockSize * kElemsPerThread / kInterleave;
const int scale_forward_step = act_forward_step / GroupSize * OC;
// Main loop iteration, each block completes the outputs for several OCs
for (int kk = threadIdx.x * kElemsPerThread; kk < IC * kInterleave; kk += BlockSize * kElemsPerThread)
{
// Load qweight, scales and scaled_zeros
#pragma unroll
for (int idx = 0; idx < NPerBlock; ++idx)
{
// use float4 to load weights, each thread load 32 int4 numbers (1 x float4, 128 bit)
*((float4*)(local_qweights)) =
*((float4*)(blk_weight_ptr + (idx * kInterleave * IC + kk)/ PACK_FACTOR));
local_scale[idx] = *(scale_ptr + idx * kInterleave);
local_scaled_zeros[idx] = *(zeros_ptr + idx * kInterleave);
// Map int4 qweight to fp format
#pragma unroll
for (int i = 0; i < MEM_ACCESS_SIZE / 32; ++i)
{
// Converts 32 bits (8 x int4) to 8 fp16
dequantize_s4_to_fp16x2(*reinterpret_cast<half2 *>(local_qweights + i), reinterpret_cast<uint4 *>(half_weight_buffer + i * PACK_FACTOR));
}
// Dequantize (apply s/z) and shuffle elements to match the weight packing format
#pragma unroll
for (int i = 0; i < kShuffleContinous; ++i)
{
#pragma unroll
for (int j = 0; j < kShuffleStrided; ++j)
{
half2 w =
*reinterpret_cast<half2*>(
half_weight_buffer + (i + j * kShuffleContinous)* kShuffleBasicTile
);
w = __hfma2(w, __half2half2(local_scale[idx]), __half2half2(local_scaled_zeros[idx]));
dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 0)
* NPerBlock + idx]
= w.x;
dequantized_weight[((i * kShuffleStrided + j) * kShuffleBasicTile + 1)
* NPerBlock + idx]
= w.y;
}
}
}
#pragma unroll
for (int batch_idx = 0; batch_idx < Batch; ++batch_idx)
{
const half* local_inputs_ptr = inputs_ptr + batch_idx * IC;
#pragma unroll
for (int idx = 0; idx < kElemsPerThread / 8; ++idx)
{
// load activation, 8 halves (128 bits) / step.
*((float4*)(local_inputs + idx * 8)) = *((float4*)(local_inputs_ptr + idx * 8));
}
// Perform the MACs
#pragma unroll
for (int x = 0; x < NPerBlock / 2; ++x)
{
#pragma unroll
for (int y = 0; y < kElemsPerThread; ++y)
{
*reinterpret_cast<half2*>(psum + batch_idx * NPerBlock + x * 2)
= __hfma2(*reinterpret_cast<half2*>(dequantized_weight + y * NPerBlock + x * 2),
__half2half2(local_inputs[y]),
*reinterpret_cast<half2*>(psum + batch_idx * NPerBlock + x * 2));
}
}
}
inputs_ptr += act_forward_step;
scale_ptr += scale_forward_step;
zeros_ptr += scale_forward_step;
}
warp_reduce<Num, WARP_SIZE>(psum, out_smem);
// Num * Interleave = batch * NPerBlock * Interleave -> 1 thread_block write back num
for (int i = threadIdx.x; i < Num * kInterleave; i += BlockSize)
{
int batch_idx = i / (NPerBlock * kInterleave);
int oc_idx = i % (NPerBlock * kInterleave);
float acc = 0.f;
for (int j = 0; j < BlockSize / WARP_SIZE; ++j)
{
acc += out_smem[j][i];
}
outputs[batch_idx * OC + blk_row_offset + oc_idx] = to_half(acc);
}
}
/*
Computes batched GEMV for the decode path (PyTorch interface).
Args:
    _in_feats: fp16 tensor of shape [B, IC];
    _kernel: int32 tensor of packed 4-bit weights, shape [OC, IC // 8];
    _scaling_factors: fp16 tensor of per-group scales, one per (group, OC) pair;
    _zeros: fp16 tensor of scaled zero points, same layout as _scaling_factors;
    m: batch size B (1..7 supported);
    n: number of output channels OC;
    k: number of input channels IC;
    group_size: quantization group size G (only 128 supported).
Returns:
    out_feats: fp16 tensor of shape [B, OC];
*/
torch::Tensor gemv_forward_cuda_decode(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int m,
int n,
int k,
int group_size)
{
std::vector<int64_t> output_shape = _in_feats.sizes().vec();
output_shape.back() = n;
auto in_feats = reinterpret_cast<half*>(_in_feats.data_ptr<at::Half>());
auto kernel = reinterpret_cast<uint32_t*>(_kernel.data_ptr());
auto zeros = reinterpret_cast<half*>(_zeros.data_ptr<at::Half>());
auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
at::Tensor _out_feats = torch::empty(output_shape, options);
half * out_feats = reinterpret_cast<half *>(_out_feats.data_ptr());
static constexpr int N_PER_BLOCK = 2;
static constexpr int K_INTERLEAVE = 4;
static constexpr int BLOCK_SIZE = 256;
dim3 num_blocks(n / N_PER_BLOCK / K_INTERLEAVE);
dim3 num_threads(BLOCK_SIZE);
// if (group_size == 64)
// {
// gemv_kernel_g64<<<num_blocks, num_threads>>>(
// // pointers
// in_feats, kernel, zeros, scaling_factors, out_feats,
// // constants
// num_in_channels, num_out_channels
// );
// }
if (group_size == 128)
{
switch (m)
{
case 1:
gemv_kernel<N_PER_BLOCK, 1, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
);
break;
case 2:
gemv_kernel<N_PER_BLOCK, 2, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
);
break;
case 3:
gemv_kernel<N_PER_BLOCK, 3, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
);
break;
case 4:
gemv_kernel<N_PER_BLOCK, 4, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
);
break;
case 5:
gemv_kernel<N_PER_BLOCK, 5, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
);
break;
case 6:
gemv_kernel<N_PER_BLOCK, 6, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
);
break;
case 7:
gemv_kernel<N_PER_BLOCK, 7, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
);
break;
default:
throw std::runtime_error("Unsupported batch size for gemv kernel.\n");
}
}
else
{
throw std::runtime_error("Unsupported group size for gemv kernel.\n");
}
return _out_feats;
}
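warp_reduce above implements the tree reduction with xor shuffles at offsets 16, 8, and 1, so each lane ends up holding the sum over its eight interleave partners (T0 + T1 + T8 + T9 + T16 + T17 + T24 + T25 for lane 0, as the comment notes). A small NumPy sketch (illustration only, independent of the CUDA code) that emulates the shuffle pattern and checks which lanes get summed:

import numpy as np

# Emulate __shfl_xor_sync over a 32-lane warp: after xor offsets 16, 8, 1,
# lane l holds the sum over lanes {l ^ a ^ b ^ c : a in {0,16}, b in {0,8}, c in {0,1}}.
lanes = np.arange(32)
vals = np.arange(32, dtype=np.float64)  # stand-in for one fpsum[i] value per lane
acc = vals.copy()
for offset in (16, 8, 1):
    acc = acc + acc[lanes ^ offset]

partners_of_0 = sorted({a ^ b ^ c for a in (0, 16) for b in (0, 8) for c in (0, 1)})
print(partners_of_0)  # [0, 1, 8, 9, 16, 17, 24, 25], matching the comment in warp_reduce
assert acc[0] == vals[partners_of_0].sum()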
#pragma once
#include <torch/extension.h>
torch::Tensor gemv_forward_cuda_decode(
torch::Tensor _in_feats,
torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
int m,
int n,
int k,
int group_size);
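For orientation, a hedged usage sketch of the decode GEMV binding declared above. The module name awq_v2_ext is taken from the setup.py hunk further down; the exact packing and interleaving expected in qweight, scales, and zeros comes from AutoAWQ's packing code, which is not part of this diff, so the tensors below are dummies that only exercise shapes and argument order:

import torch
import awq_v2_ext  # extension built from pybind_awq_v2.cpp (see the setup.py hunk below)

B, IC, OC, group_size = 4, 4096, 4096, 128  # decode path: B in 1..7, group_size must be 128

x = torch.randn(B, IC, dtype=torch.float16, device="cuda")
# Packed 4-bit weights: 8 values per int32; the interleaved layout is assumed to come from AutoAWQ packing.
qweight = torch.zeros(OC, IC // 8, dtype=torch.int32, device="cuda")
# Scales and scaled zeros are fp16; the kernel indexes them group-major by output channel (assumed layout).
scales = torch.ones(IC // group_size, OC, dtype=torch.float16, device="cuda")
zeros = torch.zeros(IC // group_size, OC, dtype=torch.float16, device="cuda")

# Argument order follows the declaration in gemv_cuda.h: inputs, kernel, scales, zeros, m, n, k, group_size.
y = awq_v2_ext.gemv_forward_cuda_decode(x, qweight, scales, zeros, B, OC, IC, group_size)
print(y.shape)  # torch.Size([4, 4096])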
 #!/bin/bash
 # Set variables
-AWQ_KERNELS_VERSION="0.0.5"
+AWQ_KERNELS_VERSION="0.0.6"
 RELEASE_URL="https://api.github.com/repos/casper-hansen/AutoAWQ_kernels/releases/tags/v${AWQ_KERNELS_VERSION}"
 # Create a directory to download the wheels
......
@@ -7,7 +7,7 @@ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 os.environ["CC"] = "g++"
 os.environ["CXX"] = "g++"

-AUTOAWQ_KERNELS_VERSION = "0.0.5"
+AUTOAWQ_KERNELS_VERSION = "0.0.6"
 PYPI_BUILD = os.getenv("PYPI_BUILD", "0") == "1"
 CUDA_VERSION = os.getenv("CUDA_VERSION", None) or torch.version.cuda
 ROCM_VERSION = os.environ.get("ROCM_VERSION", None) or torch.version.hip
@@ -89,7 +89,9 @@ def get_generator_flag():
     return generator_flag


-def get_compute_capabilities():
+def get_compute_capabilities(
+    compute_capabilities={75, 80, 86, 89, 90}
+):
     capability_flags = []

     if CUDA_VERSION:
@@ -103,7 +105,6 @@ def get_compute_capabilities():
         )

     # Figure out compute capability
-    compute_capabilities = {75, 80, 86, 89, 90}

     for cap in compute_capabilities:
         capability_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"]
@@ -180,6 +181,22 @@ if CUDA_VERSION:
             )
         )

+    # only compatible with ampere
+    arch_flags = get_compute_capabilities({80, 86, 89, 90})
+    extra_compile_args_v2 = get_extra_compile_args(arch_flags, generator_flags)
+
+    extensions.append(
+        CUDAExtension(
+            "awq_v2_ext",
+            [
+                "awq_ext/pybind_awq_v2.cpp",
+                "awq_ext/quantization_new/gemv/gemv_cuda.cu",
+                "awq_ext/quantization_new/gemm/gemm_cuda.cu",
+            ],
+            extra_compile_args=extra_compile_args_v2,
+        )
+    )
+
     extensions.append(
         CUDAExtension(
             "exl_ext",
......
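The setup.py change lets get_compute_capabilities take the target architectures as a parameter, which is how the new awq_v2_ext extension is restricted to sm_80 and newer. A standalone sketch of the flag construction (function and variable names reused from setup.py for illustration only; the real helper also contains a CUDA version check that is elided here):

def get_compute_capabilities(compute_capabilities={75, 80, 86, 89, 90}):
    # One "-gencode arch=compute_XX,code=sm_XX" pair per requested architecture,
    # mirroring the loop in the modified setup.py helper.
    capability_flags = []
    for cap in sorted(compute_capabilities):
        capability_flags += ["-gencode", f"arch=compute_{cap},code=sm_{cap}"]
    return capability_flags

print(get_compute_capabilities({80, 86, 89, 90}))  # flags used for awq_v2_ext (Ampere and newer)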