Unverified commit d62aebfe, authored by Casper, committed by GitHub

Merge pull request #53 from qwopqwop200/main

support windows 
parents 72f954ce 14d4f8cb
@@ -5,6 +5,12 @@ import torch.nn as nn
 import awq_inference_engine
 from torch.nn import functional as F
 
+try:
+    import ft_inference_engine
+    FT_INSTALLED = True
+except:
+    FT_INSTALLED = False
+
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device)  # type: ignore
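The try/except above turns ft_inference_engine into a soft dependency: on Windows, where the FasterTransformer extension is not built (see the setup.py changes below), FT_INSTALLED stays False and attention falls back to the plain PyTorch path. A minimal, self-contained sketch of the same pattern; fast_kernels is a hypothetical stand-in, not part of this PR:

```python
import torch
from torch.nn import functional as F

# Hypothetical compiled extension standing in for ft_inference_engine;
# an import failure disables the fast path instead of crashing at import time.
try:
    import fast_kernels
    KERNELS_INSTALLED = True
except ImportError:
    KERNELS_INSTALLED = False

def attention(q, k, v):
    # q, k, v: [batch, n_heads, seq_len, head_dim]
    if KERNELS_INSTALLED and q.shape[2] == 1:
        # Fused single-token decode kernel (hypothetical signature).
        return fast_kernels.single_query_attention(q, k, v)
    # Pure-PyTorch fallback: used for prompts and on platforms without the kernels.
    return F.scaled_dot_product_attention(q, k, v)
```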
@@ -156,7 +162,7 @@ class QuantAttentionFused(nn.Module):
         xk = self.attention_shapes["xk_slice"](xqkv)
         xv = self.attention_shapes["xv_slice"](xqkv)
 
-        if seqlen > 1:
+        if seqlen > 1 or not FT_INSTALLED:
            xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
            xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
            xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])
@@ -177,6 +183,11 @@ class QuantAttentionFused(nn.Module):
             self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
             self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store
 
+            if seqlen == 1:
+                xv = self.cache_v[:bsz, :, : self.start_pos + seqlen, :].transpose(1, 2).contiguous()
+                xk = self.cache_k[:bsz, :, :, : self.start_pos + seqlen, :].transpose(2, 3).contiguous()
+                xk = xk.reshape(xk.shape[:-2] + (self.head_dim,)).transpose(1, 2).contiguous()
+
             keys = xk
             values = xv
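These added lines cover the decode step (seqlen == 1): values are cached as [bsz, n_heads, seq, head_dim], while keys sit in a tiled FasterTransformer-style layout [bsz, n_heads, head_dim // x, seq, x], so an extra reshape is needed to bring both back to [bsz, cached_len, n_heads, head_dim]. A shape-only sketch of that gather; the concrete sizes (and x = 8) are assumptions for illustration:

```python
import torch

bsz, n_heads, head_dim, max_seq = 1, 32, 128, 2048
x = 8     # elements packed per tile in the key cache (assumed)
cur = 17  # tokens cached so far, i.e. start_pos + seqlen

cache_v = torch.zeros(bsz, n_heads, max_seq, head_dim)
cache_k = torch.zeros(bsz, n_heads, head_dim // x, max_seq, x)  # tiled layout

xv = cache_v[:bsz, :, :cur, :].transpose(1, 2).contiguous()
xk = cache_k[:bsz, :, :, :cur, :].transpose(2, 3).contiguous()
xk = xk.reshape(xk.shape[:-2] + (head_dim,)).transpose(1, 2).contiguous()

print(xv.shape)  # torch.Size([1, 17, 32, 128])
print(xk.shape)  # torch.Size([1, 17, 32, 128])
```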
@@ -185,7 +196,6 @@ class QuantAttentionFused(nn.Module):
             values = torch.repeat_interleave(values, dim=2, repeats=self.n_kv_groups)
 
         past_key_value = (xk, xv) if use_cache else None
-        xq = xq.transpose(1, 2)
         keys = keys.transpose(1, 2)
         values = values.transpose(1, 2)
awq_cuda/pybind.cpp → awq_cuda/pybind_awq.cpp
@@ -1,6 +1,5 @@
 #include <pybind11/pybind11.h>
 #include <torch/extension.h>
-#include "attention/ft_attention.h"
 #include "layernorm/layernorm.h"
 #include "quantization/gemm_cuda.h"
 #include "quantization/gemv_cuda.h"
@@ -13,8 +12,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("gemmv2_forward_cuda", &gemmv2_forward_cuda, "Quantized v2 GEMM kernel.");
     m.def("gemv_forward_cuda", &gemv_forward_cuda, "Quantized GEMV kernel.");
     m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
-    m.def("single_query_attention", &single_query_attention, "Attention with a single query",
-        py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"),
-        py::arg("length_per_sample_"), py::arg("alibi_slopes_"), py::arg("timestep"), py::arg("rotary_embedding_dim")=0,
-        py::arg("rotary_base")=10000.0f, py::arg("neox_rotary_style")=true);
 }
\ No newline at end of file
awq_cuda/pybind_ft.cpp (new file)
@@ -0,0 +1,10 @@
+#include <pybind11/pybind11.h>
+#include <torch/extension.h>
+#include "attention/ft_attention.h"
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+    m.def("single_query_attention", &single_query_attention, "Attention with a single query",
+        py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"),
+        py::arg("length_per_sample_"), py::arg("alibi_slopes_"), py::arg("timestep"), py::arg("rotary_embedding_dim")=0,
+        py::arg("rotary_base")=10000.0f, py::arg("neox_rotary_style")=true);
+}
\ No newline at end of file
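With this binding in place, the FT decode kernel becomes reachable from Python roughly as below. The argument names come from the py::arg list above; the shapes, dtype, and the use of None for the two optional-looking arguments are assumptions inferred from the cache layout in the fused-attention diff, not a documented API:

```python
import torch
import ft_inference_engine  # only available where the extension was built

bsz, n_heads, head_dim, max_seq, x = 1, 32, 128, 2048, 8
dev, dt = "cuda", torch.float16

q = torch.randn(bsz, n_heads, head_dim, device=dev, dtype=dt)
k = torch.randn(bsz, n_heads, head_dim, device=dev, dtype=dt)
v = torch.randn(bsz, n_heads, head_dim, device=dev, dtype=dt)
k_cache = torch.zeros(bsz, n_heads, head_dim // x, max_seq, x, device=dev, dtype=dt)
v_cache = torch.zeros(bsz, n_heads, max_seq, head_dim, device=dev, dtype=dt)

out = ft_inference_engine.single_query_attention(
    q, k, v, k_cache, v_cache,
    None,       # length_per_sample_ (assumed optional)
    None,       # alibi_slopes_ (assumed optional)
    16,         # timestep: tokens already in the cache
    head_dim,   # rotary_embedding_dim; the default 0 disables rotary in-kernel
    10000.0,    # rotary_base
    True,       # neox_rotary_style
)
```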
setup.py
@@ -96,10 +96,13 @@ generator_flags = get_generator_flag()
 arch_flags = get_compute_capabilities()
 
 if os.name == "nt":
+    include_arch = os.getenv("INCLUDE_ARCH", "1") == "1"
+
     # Relaxed args on Windows
-    extra_compile_args={
-        "nvcc": arch_flags
-    }
+    if include_arch:
+        extra_compile_args={"nvcc": arch_flags}
+    else:
+        extra_compile_args={}
 else:
     extra_compile_args={
         "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"],
@@ -123,16 +126,26 @@ extensions = [
     CUDAExtension(
         "awq_inference_engine",
         [
-            "awq_cuda/pybind.cpp",
+            "awq_cuda/pybind_awq.cpp",
             "awq_cuda/quantization/gemm_cuda_gen.cu",
             "awq_cuda/layernorm/layernorm.cu",
             "awq_cuda/position_embedding/pos_encoding_kernels.cu",
-            "awq_cuda/quantization/gemv_cuda.cu",
-            "awq_cuda/attention/ft_attention.cpp",
-            "awq_cuda/attention/decoder_masked_multihead_attention.cu"
+            "awq_cuda/quantization/gemv_cuda.cu"
         ], extra_compile_args=extra_compile_args
     )
 ]
+
+if os.name != "nt":
+    extensions.append(
+        CUDAExtension(
+            "ft_inference_engine",
+            [
+                "awq_cuda/pybind_ft.cpp",
+                "awq_cuda/attention/ft_attention.cpp",
+                "awq_cuda/attention/decoder_masked_multihead_attention.cu"
+            ], extra_compile_args=extra_compile_args
+        )
+    )
 
 additional_setup_kwargs = {
     "ext_modules": extensions,