OpenDAS / AutoAWQ · Commits

Commit e1884728 (unverified), authored Sep 15, 2023 by qwopqwop200, committed via GitHub on Sep 15, 2023
Parent: a5772f67

suppport windows
Changes: 4 files, 193 additions and 168 deletions (+193, -168)
awq/modules/fused/attn.py    +3   -1
awq_cuda/pybind_linux.cpp    +19  -19
awq_cuda/pybind_windows.cpp  +14  -0
setup.py                     +157 -148
awq/modules/fused/attn.py
@@ -5,6 +5,8 @@ import torch.nn as nn
 import awq_inference_engine
 from torch.nn import functional as F
 
+have_single_query_attention = hasattr(awq_inference_engine, 'single_query_attention')
+
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device)  # type: ignore
@@ -184,7 +186,7 @@ class QuantAttentionFused(nn.Module):
         xk = self.attention_shapes["xk_slice"](xqkv)
         xv = self.attention_shapes["xv_slice"](xqkv)
 
-        if seqlen > 1:
+        if seqlen > 1 and have_single_query_attention:
             xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
             xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
             xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])
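The new module-level flag turns kernel availability into a feature probe: the fused path is only taken when the compiled extension actually exposes single_query_attention, which a Windows build may not. Below is a minimal sketch of the same hasattr-based fallback pattern, using a stand-in object instead of the real awq_inference_engine module; the fast_attention name and the fallback body are illustrative, not from the repository.

import types

# Stand-in for a compiled extension that may or may not ship an optional
# kernel. In AutoAWQ the real module is awq_inference_engine; this stub only
# illustrates the hasattr() feature detection used in attn.py.
engine = types.SimpleNamespace()

have_single_query_attention = hasattr(engine, "single_query_attention")

def fast_attention(x):
    # Hypothetical wrapper: call the optional CUDA binding when it exists,
    # otherwise fall back to a generic path (as a Windows build would).
    if have_single_query_attention:
        return engine.single_query_attention(x)
    return x  # placeholder for the fallback computation

print(fast_attention([1.0, 2.0]))  # prints the input: the stub has no kernel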
awq_cuda/pybind_linux.cpp (+19, -19; diff not expanded on this page)
awq_cuda/pybind_windows.cpp (new file, mode 100644)
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include "layernorm/layernorm.h"
#include "quantization/gemm_cuda.h"
#include "quantization/gemv_cuda.h"
#include "position_embedding/pos_encoding.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel");
    m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel.");
    m.def("gemv_forward_cuda", &gemv_forward_cuda, "Quantized GEMV kernel.");
    m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
}
\ No newline at end of file
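The four m.def calls above are the only bindings this Windows module exports; notably there is no single_query_attention, which is exactly what the new hasattr guard in awq/modules/fused/attn.py accounts for. A small sketch of probing the compiled module from Python follows; it assumes the extension built and imports cleanly, and the attribute names are taken directly from the m.def calls above.

import awq_inference_engine

# Bindings declared in pybind_windows.cpp.
expected = [
    "layernorm_forward_cuda",
    "gemm_forward_cuda",
    "gemv_forward_cuda",
    "rotary_embedding_neox",
]

for name in expected:
    print(f"{name}: {hasattr(awq_inference_engine, name)}")

# On a Windows build this is expected to be False, which triggers the
# fallback path in attn.py.
print("single_query_attention:", hasattr(awq_inference_engine, "single_query_attention"))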
setup.py
@@ -97,9 +97,18 @@ arch_flags = get_compute_capabilities()
 
 if os.name == "nt":
     # Relaxed args on Windows
-    extra_compile_args = {
-        "nvcc": arch_flags
-    }
+    extensions = [
+        CUDAExtension(
+            "awq_inference_engine",
+            [
+                "awq_cuda/pybind_windows.cpp",
+                "awq_cuda/quantization/gemm_cuda_gen.cu",
+                "awq_cuda/layernorm/layernorm.cu",
+                "awq_cuda/position_embedding/pos_encoding_kernels.cu",
+                "awq_cuda/quantization/gemv_cuda.cu",
+            ]
+        )
+    ]
 else:
     extra_compile_args = {
         "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"],
@@ -119,11 +128,11 @@ else:
         ] + arch_flags + generator_flags
     }
 
     extensions = [
         CUDAExtension(
             "awq_inference_engine",
             [
-                "awq_cuda/pybind.cpp",
+                "awq_cuda/pybind_linux.cpp",
                 "awq_cuda/quantization/gemm_cuda_gen.cu",
                 "awq_cuda/layernorm/layernorm.cu",
                 "awq_cuda/position_embedding/pos_encoding_kernels.cu",
@@ -132,7 +141,7 @@ extensions = [
             "awq_cuda/attention/decoder_masked_multihead_attention.cu"
         ], extra_compile_args=extra_compile_args
     )
 ]
 
 additional_setup_kwargs = {
     "ext_modules": extensions,
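Taken together, the setup.py hunks make the source list platform-dependent: on Windows the extension is built from pybind_windows.cpp with no extra_compile_args, while the Linux branch keeps its compiler flags and now binds through pybind_linux.cpp instead of pybind.cpp. The sketch below condenses that selection logic; paths and flags are copied from the hunks above, and everything else (the elided nvcc flags, the remaining kernel sources, the final setup() call) is an assumption, not the repository's exact code.

import os
from torch.utils.cpp_extension import CUDAExtension

# In the real setup.py these come from get_compute_capabilities() and the
# generator detection earlier in the file; empty lists keep the sketch short.
arch_flags, generator_flags = [], []

if os.name == "nt":
    # Windows: reduced source list, no extra compile args.
    extensions = [
        CUDAExtension(
            "awq_inference_engine",
            [
                "awq_cuda/pybind_windows.cpp",
                "awq_cuda/quantization/gemm_cuda_gen.cu",
                "awq_cuda/layernorm/layernorm.cu",
                "awq_cuda/position_embedding/pos_encoding_kernels.cu",
                "awq_cuda/quantization/gemv_cuda.cu",
            ]
        )
    ]
else:
    # Linux: keep the original flags and the fuller kernel list.
    extra_compile_args = {
        "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"],
        "nvcc": [
            # nvcc flags elided in the hunk above
        ] + arch_flags + generator_flags,
    }
    extensions = [
        CUDAExtension(
            "awq_inference_engine",
            [
                "awq_cuda/pybind_linux.cpp",
                "awq_cuda/quantization/gemm_cuda_gen.cu",
                "awq_cuda/layernorm/layernorm.cu",
                "awq_cuda/position_embedding/pos_encoding_kernels.cu",
                # further kernel sources elided in the hunk above
                "awq_cuda/attention/decoder_masked_multihead_attention.cu",
            ],
            extra_compile_args=extra_compile_args,
        )
    ]

additional_setup_kwargs = {"ext_modules": extensions}  # the real dict may carry more keys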