OpenDAS / AutoAWQ · Commit d62aebfe (unverified)
Authored Sep 21, 2023 by Casper; committed by GitHub on Sep 21, 2023

Merge pull request #53 from qwopqwop200/main

support windows
Parents: 72f954ce, 14d4f8cb

Showing 4 changed files with 186 additions and 157 deletions (+186 -157)
awq/modules/fused/attn.py   +13  -3
awq_cuda/pybind_awq.cpp      +0  -5
awq_cuda/pybind_ft.cpp      +11  -0
setup.py                   +162 -149
awq/modules/fused/attn.py

...
@@ -5,6 +5,12 @@ import torch.nn as nn
 import awq_inference_engine
 from torch.nn import functional as F
 
+try:
+    import ft_inference_engine
+    FT_INSTALLED = True
+except:
+    FT_INSTALLED = False
+
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
     freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
     t = torch.arange(end, device=freqs.device)  # type: ignore
...
@@ -156,7 +162,7 @@ class QuantAttentionFused(nn.Module):
         xk = self.attention_shapes["xk_slice"](xqkv)
         xv = self.attention_shapes["xv_slice"](xqkv)
 
-        if seqlen > 1:
+        if seqlen > 1 or not FT_INSTALLED:
             xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
             xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
             xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])
...
@@ -177,6 +183,11 @@ class QuantAttentionFused(nn.Module):
             self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
             self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store
 
+            if seqlen == 1:
+                xv = self.cache_v[:bsz, :, : self.start_pos + seqlen, :].transpose(1, 2).contiguous()
+                xk = self.cache_k[:bsz, :, :, : self.start_pos + seqlen, :].transpose(2, 3).contiguous()
+                xk = xk.reshape(xk.shape[:-2] + (self.head_dim,)).transpose(1, 2).contiguous()
+
             keys = xk
             values = xv
...
@@ -185,7 +196,6 @@ class QuantAttentionFused(nn.Module):
                 values = torch.repeat_interleave(values, dim=2, repeats=self.n_kv_groups)
 
             past_key_value = (xk, xv) if use_cache else None
-
             xq = xq.transpose(1, 2)
             keys = keys.transpose(1, 2)
             values = values.transpose(1, 2)
...
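The attn.py change gates the fused FasterTransformer path on whether the optional ft_inference_engine extension imported successfully, so single-token decoding (seqlen == 1) falls back to the regular PyTorch attention path when the extension is missing, e.g. on a Windows build. Below is a minimal, self-contained sketch of that fallback pattern with assumed toy shapes; it is not the repo's QuantAttentionFused.forward.

# Minimal sketch (assumed shapes, placeholder logic) of the fallback introduced above:
# when ft_inference_engine is unavailable, even single-token decoding takes the
# generic PyTorch attention path.
import torch
import torch.nn.functional as F

try:
    import ft_inference_engine   # optional CUDA extension, skipped on Windows
    FT_INSTALLED = True
except ImportError:
    FT_INSTALLED = False

def naive_attention(xq, keys, values):
    # xq, keys, values: (batch, heads, seq, head_dim)
    scores = (xq @ keys.transpose(-2, -1)) / (xq.shape[-1] ** 0.5)
    return F.softmax(scores, dim=-1) @ values

bsz, n_heads, seqlen, head_dim = 1, 8, 1, 64           # toy shapes, assumed
xq = torch.randn(bsz, n_heads, seqlen, head_dim)
kv = torch.randn(bsz, n_heads, 16, head_dim)           # stand-in for the KV cache

if seqlen > 1 or not FT_INSTALLED:
    out = naive_attention(xq, kv, kv)                  # path taken on Windows
else:
    # With the extension present, the fused single-query kernel would run here;
    # this sketch just reuses the naive path for illustration.
    out = naive_attention(xq, kv, kv)

print(out.shape)  # torch.Size([1, 8, 1, 64])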
awq_cuda/pybind.cpp → awq_cuda/pybind_awq.cpp

 #include <pybind11/pybind11.h>
 #include <torch/extension.h>
-#include "attention/ft_attention.h"
 #include "layernorm/layernorm.h"
 #include "quantization/gemm_cuda.h"
 #include "quantization/gemv_cuda.h"
...
@@ -13,8 +12,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("gemmv2_forward_cuda", &gemmv2_forward_cuda, "Quantized v2 GEMM kernel.");
     m.def("gemv_forward_cuda", &gemv_forward_cuda, "Quantized GEMV kernel.");
     m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
-    m.def("single_query_attention", &single_query_attention, "Attention with a single query",
-          py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"),
-          py::arg("length_per_sample_"), py::arg("alibi_slopes_"), py::arg("timestep"),
-          py::arg("rotary_embedding_dim")=0, py::arg("rotary_base")=10000.0f, py::arg("neox_rotary_style")=true);
 }
\ No newline at end of file
awq_cuda/pybind_ft.cpp (new file, mode 100644)

+#include <pybind11/pybind11.h>
+#include <torch/extension.h>
+#include "attention/ft_attention.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+{
+    m.def("single_query_attention", &single_query_attention, "Attention with a single query",
+          py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"),
+          py::arg("length_per_sample_"), py::arg("alibi_slopes_"), py::arg("timestep"),
+          py::arg("rotary_embedding_dim")=0, py::arg("rotary_base")=10000.0f, py::arg("neox_rotary_style")=true);
+}
\ No newline at end of file
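Splitting the bindings this way moves the single_query_attention kernel into its own ft_inference_engine extension, which setup.py only builds off Windows, while the quantized GEMM/GEMV and layernorm kernels stay in awq_inference_engine. A quick, hypothetical check (not part of the repo) shows which of the two modules is importable on the current platform:

# Hypothetical helper, assuming the package was installed with this commit applied.
import importlib.util

for name in ("awq_inference_engine", "ft_inference_engine"):
    found = importlib.util.find_spec(name) is not None
    print(f"{name}: {'available' if found else 'not built on this platform'}")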
setup.py

...
@@ -96,10 +96,13 @@ generator_flags = get_generator_flag()
 arch_flags = get_compute_capabilities()
 
 if os.name == "nt":
+    include_arch = os.getenv("INCLUDE_ARCH", "1") == "1"
     # Relaxed args on Windows
-    extra_compile_args={
-        "nvcc": arch_flags
-    }
+    if include_arch:
+        extra_compile_args={
+            "nvcc": arch_flags
+        }
+    else:
+        extra_compile_args={}
 else:
     extra_compile_args={
         "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"],
...
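The new INCLUDE_ARCH environment variable only matters on Windows: with the default of "1" the nvcc arch flags from get_compute_capabilities() are kept, while setting INCLUDE_ARCH=0 in the shell before building (e.g. "set INCLUDE_ARCH=0" in cmd) drops them entirely. A condensed sketch of the same gate; the arch flag shown is only a hypothetical example of what get_compute_capabilities() might return:

# Condensed sketch of the Windows branch above, with a made-up arch flag.
import os

arch_flags = ["-gencode=arch=compute_80,code=sm_80"]   # hypothetical example value

if os.name == "nt":
    include_arch = os.getenv("INCLUDE_ARCH", "1") == "1"
    extra_compile_args = {"nvcc": arch_flags} if include_arch else {}
else:
    extra_compile_args = {
        "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"],
    }

print(extra_compile_args)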
@@ -123,16 +126,26 @@ extensions = [
     CUDAExtension(
         "awq_inference_engine",
         [
-            "awq_cuda/pybind.cpp",
+            "awq_cuda/pybind_awq.cpp",
             "awq_cuda/quantization/gemm_cuda_gen.cu",
             "awq_cuda/layernorm/layernorm.cu",
             "awq_cuda/position_embedding/pos_encoding_kernels.cu",
-            "awq_cuda/quantization/gemv_cuda.cu",
-            "awq_cuda/attention/ft_attention.cpp",
-            "awq_cuda/attention/decoder_masked_multihead_attention.cu"
+            "awq_cuda/quantization/gemv_cuda.cu"
         ], extra_compile_args=extra_compile_args
     )
 ]
 
+if os.name != "nt":
+    extensions.append(
+        CUDAExtension(
+            "ft_inference_engine",
+            [
+                "awq_cuda/pybind_ft.cpp",
+                "awq_cuda/attention/ft_attention.cpp",
+                "awq_cuda/attention/decoder_masked_multihead_attention.cu"
+            ], extra_compile_args=extra_compile_args
+        )
+    )
+
 additional_setup_kwargs = {
     "ext_modules": extensions,
...
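With the extension list assembled conditionally, the rest of the build stays platform-agnostic: on Windows only awq_inference_engine is compiled, elsewhere both extensions are. A hedged sketch of how such a kwargs dict is typically passed on to setuptools; the actual setup() call sits outside this hunk, and the metadata and BuildExtension cmdclass here are assumptions, not the file's contents:

# Hedged sketch, not the file's actual setup() call. Assumes torch is installed
# and that additional_setup_kwargs carries the ext_modules assembled above.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension

additional_setup_kwargs = {
    "ext_modules": [],                              # placeholder for the extensions list
    "cmdclass": {"build_ext": BuildExtension},      # standard way to compile CUDAExtension
}

setup(
    name="autoawq",                                 # placeholder metadata
    **additional_setup_kwargs,
)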