Unverified Commit e1884728 authored by qwopqwop200, committed by GitHub

support windows

parent a5772f67
@@ -5,6 +5,8 @@ import torch.nn as nn
import awq_inference_engine
from torch.nn import functional as F

+have_single_query_attention = hasattr(awq_inference_engine, 'single_query_attention')

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
@@ -184,7 +186,7 @@ class QuantAttentionFused(nn.Module):
        xk = self.attention_shapes["xk_slice"](xqkv)
        xv = self.attention_shapes["xv_slice"](xqkv)

-        if seqlen > 1:
+        if seqlen > 1 and have_single_query_attention:
            xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
            xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
            xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])
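The new `have_single_query_attention` flag is an import-time feature probe: the Windows build of the extension omits the fused `single_query_attention` kernel, so the Python side checks for it once and branches on the flag instead of failing at call time. A minimal sketch of the pattern follows; the helper name `detect_optional_kernel` is hypothetical, not part of the repository.

```python
# Sketch of the import-time feature probe used above; detect_optional_kernel
# is a hypothetical helper name introduced only for illustration.
import importlib

def detect_optional_kernel(module_name: str, symbol: str) -> bool:
    """Return True if `module_name` imports cleanly and exposes `symbol`."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return False
    return hasattr(module, symbol)

# Equivalent to the flag defined in the diff: evaluates to False on Windows
# builds, where pybind_windows.cpp does not bind single_query_attention.
have_single_query_attention = detect_optional_kernel(
    "awq_inference_engine", "single_query_attention"
)
```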
awq_cuda/pybind_windows.cpp (new file)
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include "layernorm/layernorm.h"
#include "quantization/gemm_cuda.h"
#include "quantization/gemv_cuda.h"
#include "position_embedding/pos_encoding.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel");
    m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel.");
    m.def("gemv_forward_cuda", &gemv_forward_cuda, "Quantized GEMV kernel.");
    m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
}
\ No newline at end of file
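Since `pybind_windows.cpp` binds only the layernorm, GEMM, GEMV, and rotary-embedding kernels, one quick way to confirm what a finished build exposes is to probe the compiled module for each symbol. This is an illustrative check, not part of the commit:

```python
# Illustrative post-build check: report which kernels the compiled
# awq_inference_engine module actually exposes. On a Windows build made
# from pybind_windows.cpp, single_query_attention should print False.
import awq_inference_engine

for symbol in (
    "layernorm_forward_cuda",
    "gemm_forward_cuda",
    "gemv_forward_cuda",
    "rotary_embedding_neox",
    "single_query_attention",  # bound only by the non-Windows module
):
    print(f"{symbol}: {hasattr(awq_inference_engine, symbol)}")
```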
setup.py
@@ -97,9 +97,18 @@ arch_flags = get_compute_capabilities()
if os.name == "nt":
    # Relaxed args on Windows
    extra_compile_args={
        "nvcc": arch_flags
    }
    extensions = [
        CUDAExtension(
            "awq_inference_engine",
            [
                "awq_cuda/pybind_windows.cpp",
                "awq_cuda/quantization/gemm_cuda_gen.cu",
                "awq_cuda/layernorm/layernorm.cu",
                "awq_cuda/position_embedding/pos_encoding_kernels.cu",
                "awq_cuda/quantization/gemv_cuda.cu",
            ]
        )
    ]
else:
    extra_compile_args={
        "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"],
@@ -119,11 +128,11 @@ else:
        ] + arch_flags + generator_flags
    }
-extensions = [
+    extensions = [
        CUDAExtension(
            "awq_inference_engine",
            [
-                "awq_cuda/pybind.cpp",
+                "awq_cuda/pybind_linux.cpp",
                "awq_cuda/quantization/gemm_cuda_gen.cu",
                "awq_cuda/layernorm/layernorm.cu",
                "awq_cuda/position_embedding/pos_encoding_kernels.cu",
@@ -132,7 +141,7 @@ extensions = [
                "awq_cuda/attention/decoder_masked_multihead_attention.cu"
            ], extra_compile_args=extra_compile_args
        )
-]
+    ]

additional_setup_kwargs = {
    "ext_modules": extensions,