OpenDAS / AutoAWQ · Commit e3b8d9c4

Update setup

Authored Sep 20, 2023 by Casper
Parent: 6dba61b8

4 changed files with 54 additions and 36 deletions:

  awq/modules/fused/attn.py   +7   -2
  awq_cuda/pybind_awq.cpp     +14  -13
  awq_cuda/pybind_ft.cpp      +11  -0
  setup.py                    +22  -21
awq/modules/fused/attn.py
@@ -5,7 +5,11 @@ import torch.nn as nn
 import awq_inference_engine
 from torch.nn import functional as F
 
-have_single_query_attention = hasattr(awq_inference_engine, 'single_query_attention')
+try:
+    import ft_inference_engine
+    FT_INSTALLED = True
+except:
+    FT_INSTALLED = False
 
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
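For intuition, here is a worked example of the inverse-frequency table that precompute_freqs_cis builds from the expression above (a standalone sketch; dim=8 is an arbitrary illustrative size):

import torch

# RoPE inverse frequencies: theta^(-2i/dim) for i = 0 .. dim/2 - 1.
dim, theta = 8, 10000.0
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
print(freqs)  # approximately [1.0, 0.1, 0.01, 0.001]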
@@ -158,7 +162,7 @@ class QuantAttentionFused(nn.Module):
         xk = self.attention_shapes["xk_slice"](xqkv)
         xv = self.attention_shapes["xv_slice"](xqkv)
 
-        if seqlen > 1 or not (have_single_query_attention):
+        if seqlen > 1 or not FT_INSTALLED:
             xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
             xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
             xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])
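The two changes above reduce to a single dispatch rule: use the fused single-query kernel only for one-token decode steps and only when the optional FT extension was built. A minimal sketch of that rule (use_fused_decode is a hypothetical helper, not part of the repo):

import importlib.util

# Mirrors the commit's feature detection: ft_inference_engine is optional.
FT_INSTALLED = importlib.util.find_spec("ft_inference_engine") is not None

def use_fused_decode(seqlen: int) -> bool:
    # The fused single-query kernel handles exactly one new token per step,
    # and only exists where the FT extension builds (i.e. not on Windows).
    return seqlen == 1 and FT_INSTALLED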
@@ -183,6 +187,7 @@ class QuantAttentionFused(nn.Module):
             xv = self.cache_v[:bsz, :, : self.start_pos + seqlen, :].transpose(1, 2).contiguous()
             xk = self.cache_k[:bsz, :, :, : self.start_pos + seqlen, :].transpose(2, 3).contiguous()
+            xk = xk.reshape(xk.shape[:-2] + (self.head_dim,)).transpose(1, 2).contiguous()
 
             keys = xk
             values = xv
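A toy shape walk-through of the reshape added above. The packed K-cache layout (bsz, n_heads, head_dim // pack, max_seq, pack) with pack = 8 for fp16 is an assumption based on the FasterTransformer kernel's vectorized cache format; the sizes are illustrative only.

import torch

bsz, n_heads, head_dim, max_seq, pack = 1, 4, 64, 16, 8
cache_k = torch.zeros(bsz, n_heads, head_dim // pack, max_seq, pack)

start_pos, seqlen = 3, 1
xk = cache_k[:bsz, :, :, : start_pos + seqlen, :].transpose(2, 3).contiguous()
print(xk.shape)  # (1, 4, 4, 8, 8): tokens moved ahead of the packed head_dim
xk = xk.reshape(xk.shape[:-2] + (head_dim,)).transpose(1, 2).contiguous()
print(xk.shape)  # (1, 4, 4, 64): (bsz, tokens, n_heads, head_dim), un-packed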
awq_cuda/pybind_windows.cpp → awq_cuda/pybind_awq.cpp
 #include <pybind11/pybind11.h>
 #include <torch/extension.h>
 #include "layernorm/layernorm.h"
 #include "quantization/gemm_cuda.h"
 #include "quantization/gemv_cuda.h"
 #include "position_embedding/pos_encoding.h"
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
     m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel");
     m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel.");
+    m.def("gemmv2_forward_cuda", &gemmv2_forward_cuda, "Quantized v2 GEMM kernel.");
     m.def("gemv_forward_cuda", &gemv_forward_cuda, "Quantized GEMV kernel.");
     m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
 }
\ No newline at end of file
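For quick experiments, the same bindings can be JIT-compiled with torch.utils.cpp_extension.load instead of the setup.py build below. This is a sketch, not part of the commit; it assumes you run from the repo root with a working CUDA toolchain.

from torch.utils.cpp_extension import load

awq_inference_engine = load(
    name="awq_inference_engine",
    sources=[
        "awq_cuda/pybind_awq.cpp",
        "awq_cuda/quantization/gemm_cuda_gen.cu",
        "awq_cuda/layernorm/layernorm.cu",
        "awq_cuda/position_embedding/pos_encoding_kernels.cu",
        "awq_cuda/quantization/gemv_cuda.cu",
    ],
    verbose=True,
)
assert hasattr(awq_inference_engine, "gemmv2_forward_cuda")  # new in this commit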
awq_cuda/pybind_ft.cpp (new file, mode 100644)
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include "attention/ft_attention.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("single_query_attention", &single_query_attention, "Attention with a single query",
          py::arg("q"), py::arg("k"), py::arg("v"),
          py::arg("k_cache"), py::arg("v_cache"),
          py::arg("length_per_sample_"), py::arg("alibi_slopes_"),
          py::arg("timestep"),
          py::arg("rotary_embedding_dim") = 0,
          py::arg("rotary_base") = 10000.0f,
          py::arg("neox_rotary_style") = true);
}
\ No newline at end of file
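From Python, the binding surfaces with the keyword arguments and defaults declared above. A call sketch follows; the tensor shapes and dtypes are assumptions based on the FasterTransformer decoder-attention cache layout (declared in attention/ft_attention.h, which is not shown in this diff), and the sizes are illustrative.

import torch
import ft_inference_engine  # requires the Linux-only build from setup.py

bsz, n_heads, head_dim, max_seq, pack = 1, 32, 128, 2048, 8  # assumed sizes
q = torch.randn(bsz, n_heads, head_dim, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
k_cache = torch.zeros(bsz, n_heads, head_dim // pack, max_seq, pack,
                      dtype=torch.float16, device="cuda")
v_cache = torch.zeros(bsz, n_heads, max_seq, head_dim,
                      dtype=torch.float16, device="cuda")

out = ft_inference_engine.single_query_attention(
    q, k, v, k_cache, v_cache,
    None,  # length_per_sample_: no per-sample lengths
    None,  # alibi_slopes_: ALiBi disabled
    5,     # timestep: current decode position
)          # rotary_embedding_dim / rotary_base / neox_rotary_style use defaults
print(out.shape)  # expected (bsz, n_heads, head_dim)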
setup.py
@@ -97,18 +97,9 @@ arch_flags = get_compute_capabilities()
 if os.name == "nt":
     # Relaxed args on Windows
-    extensions = [
-        CUDAExtension(
-            "awq_inference_engine",
-            [
-                "awq_cuda/pybind_windows.cpp",
-                "awq_cuda/quantization/gemm_cuda_gen.cu",
-                "awq_cuda/layernorm/layernorm.cu",
-                "awq_cuda/position_embedding/pos_encoding_kernels.cu",
-                "awq_cuda/quantization/gemv_cuda.cu",
-            ]
-        )
-    ]
     extra_compile_args = {"nvcc": arch_flags}
 else:
     extra_compile_args = {
         "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"],
@@ -127,21 +118,31 @@ else:
             "--use_fast_math",
         ] + arch_flags + generator_flags
     }
 
-    extensions = [
-        CUDAExtension(
-            "awq_inference_engine",
-            [
-                "awq_cuda/pybind_linux.cpp",
-                "awq_cuda/quantization/gemm_cuda_gen.cu",
-                "awq_cuda/layernorm/layernorm.cu",
-                "awq_cuda/position_embedding/pos_encoding_kernels.cu",
-                "awq_cuda/quantization/gemv_cuda.cu",
-                "awq_cuda/attention/ft_attention.cpp",
-                "awq_cuda/attention/decoder_masked_multihead_attention.cu"
-            ], extra_compile_args=extra_compile_args
-        )
-    ]
+extensions = [
+    CUDAExtension(
+        "awq_inference_engine",
+        [
+            "awq_cuda/pybind_awq.cpp",
+            "awq_cuda/quantization/gemm_cuda_gen.cu",
+            "awq_cuda/layernorm/layernorm.cu",
+            "awq_cuda/position_embedding/pos_encoding_kernels.cu",
+            "awq_cuda/quantization/gemv_cuda.cu"
+        ], extra_compile_args=extra_compile_args
+    )
+]
+
+if os.name != "nt":
+    extensions.append(
+        CUDAExtension(
+            "ft_inference_engine",
+            [
+                "awq_cuda/pybind_ft.cpp",
+                "awq_cuda/attention/ft_attention.cpp",
+                "awq_cuda/attention/decoder_masked_multihead_attention.cu"
+            ], extra_compile_args=extra_compile_args
+        )
+    )
 
 additional_setup_kwargs = {
     "ext_modules": extensions,
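After installing, a short smoke test (a sketch, not part of the repo) can confirm the split build this commit sets up: the AWQ module everywhere, the FT module only off Windows.

import os

import awq_inference_engine

# Every kernel bound in awq_cuda/pybind_awq.cpp should be present.
for fn in ("layernorm_forward_cuda", "gemm_forward_cuda", "gemmv2_forward_cuda",
           "gemv_forward_cuda", "rotary_embedding_neox"):
    assert hasattr(awq_inference_engine, fn), f"missing binding: {fn}"

try:
    import ft_inference_engine
    assert hasattr(ft_inference_engine, "single_query_attention")
    print("ft_inference_engine available")
except ImportError:
    # Expected on Windows; setup.py only builds the FT extension elsewhere.
    print(f"ft_inference_engine not built (os.name={os.name})")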