OpenDAS / AutoAWQ, commit a11c313a

Update kernel, remove unused code

Authored Sep 08, 2023 by Casper Hansen
Parent: fbfa9d82

Showing 5 changed files with 73 additions and 382 deletions (+73, -382):
awq/models/llama.py                                   +7   -57
awq/modules/fused/attn.py                             +12  -221
awq_cuda/position_embedding/pos_encoding.h            +2   -3
awq_cuda/position_embedding/pos_encoding_kernels.cu   +51  -100
awq_cuda/pybind.cpp                                   +1   -1
awq/models/llama.py
@@ -70,7 +70,7 @@ from typing import List, Tuple, Union
from awq.utils.utils import set_module_name
from awq.modules.fused.mlp import QuantLlamaMLP
from awq.modules.fused.norm import FTLlamaRMSNorm
from awq.modules.fused.attn import QuantLlamaAttention, QuantLlamaAttentionFused
from awq.modules.fused.attn import QuantLlamaAttentionFused
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP
@@ -96,16 +96,7 @@ class LlamaFuser:
    def fuse_attention(self):
        for name, module in self.attention_modules:
            qkv_layer: Union[WQLinear_GEMM, WQLinear_GEMV] = self._fuse_qkv2(module)
            # attn = QuantLlamaAttention(
            #     module.hidden_size,
            #     module.num_heads,
            #     module.num_key_value_heads,
            #     qkv_layer,
            #     module.o_proj,
            #     next(iter(qkv_layer.state_dict().values())).device,
            #     self.model.config.max_new_tokens
            # )
            qkv_layer: Union[WQLinear_GEMM, WQLinear_GEMV] = self._fuse_qkv(module)
            attn = QuantLlamaAttentionFused(
                module.hidden_size,
                module.num_heads,
@@ -116,21 +107,9 @@ class LlamaFuser:
            )
            set_module_name(self.model, name, attn)

    def _fuse_qkv2(self, module: LlamaAttention):
        q_proj = module.q_proj
        k_proj = module.k_proj
        v_proj = module.v_proj

        qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
        qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
        scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
        # g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0)
        g_idx = None
        bias = (
            torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0)
            if q_proj.bias is not None
            else None
        )

    def _fuse_qkv(self, module: LlamaAttention):
        q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None

        qkv_layer = WQLinear_GEMV(
            q_proj.w_bit,
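The set_module_name(self.model, name, attn) call above swaps the original LlamaAttention submodule for the fused module by its dotted name. A minimal sketch of that kind of helper, as a hypothetical stand-in rather than the actual awq.utils.utils implementation:

    import torch.nn as nn

    def set_module_by_name(model: nn.Module, name: str, new_module: nn.Module) -> None:
        # Walk the dotted path (e.g. "model.layers.0.self_attn") down to the parent module,
        # then swap the final attribute for the fused replacement.
        parts = name.split(".")
        parent = model
        for part in parts[:-1]:
            parent = getattr(parent, part)
        setattr(parent, parts[-1], new_module)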
@@ -140,39 +119,10 @@ class LlamaFuser:
            q_proj.bias is not None,
            q_proj.qweight.device,
        )
        qkv_layer.qweight = qweights
        qkv_layer.qzeros = qzeros
        qkv_layer.scales = scales
        qkv_layer.bias = bias
        qkv_layer.split_k_iters = q_proj.split_k_iters

        return qkv_layer

    def _fuse_qkv(self, module: LlamaAttention):
        # get qkv and bias
        q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None

        # create module
        if self.quant_config["version"] == 'GEMM':
            qkv_module = WQLinear_GEMM
        elif self.quant_config["version"] == 'GEMV':
            qkv_module = WQLinear_GEMV

        qkv_layer = qkv_module(
            q_proj.w_bit,
            q_proj.group_size,
            q_proj.in_features,
            q_proj.out_features + k_proj.out_features + v_proj.out_features,
            q_proj.bias is not None,
            next(iter(module.state_dict().values())).device
        )

        # replace buffers with real weights
        qkv_layer.qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
        qkv_layer.bias = bias
        qkv_layer.split_k_iters = q_proj.split_k_iters
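The _fuse_qkv shown above merges the three quantized projections into one WQLinear module by concatenating their qweight, qzeros and scales buffers along the output dimension. A minimal unquantized sketch of the same idea with plain nn.Linear layers, illustrative only and not the AutoAWQ code path:

    import torch
    import torch.nn as nn

    def fuse_qkv_linear(q_proj: nn.Linear, k_proj: nn.Linear, v_proj: nn.Linear) -> nn.Linear:
        # One matmul instead of three: stack the output rows of q, k and v.
        out_features = q_proj.out_features + k_proj.out_features + v_proj.out_features
        has_bias = q_proj.bias is not None
        fused = nn.Linear(q_proj.in_features, out_features, bias=has_bias)
        with torch.no_grad():
            fused.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))
            if has_bias:
                fused.bias.copy_(torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0))
        return fused

    # When the three projections share the same output size, the fused result
    # can be split back with: query, key, value = fused(x).chunk(3, dim=-1)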
awq/modules/fused/attn.py
@@ -4,74 +4,6 @@ import torch.nn as nn
import awq_inference_engine
from torch.nn import functional as F

try:
    from flash_attn import flash_attn_func
    FLASH_INSTALLED = True
except:
    FLASH_INSTALLED = False

class QuantLlamaRotary(nn.Module):
    def __init__(self, dim=4096, max_position_embeddings=2048, base=10000,
                 device=None, is_neox=True, num_heads=None, num_kv_heads=None):
        super().__init__()
        self.dim = dim
        self.is_neox = is_neox
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads

        # create cache
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device) / dim))
        t = torch.arange(max_position_embeddings, device=device).float()
        freqs = torch.einsum("i,j -> ij", t, inv_freq.float())
        cos = freqs.cos()
        sin = freqs.sin()
        cache = torch.cat((cos, sin), dim=-1).to(torch.get_default_dtype())

        # Embedding size: [max_position, rotary_dim]
        self.register_buffer("cos_sin_cache", cache.half(), persistent=False)

    def forward(self, qkv_states: torch.Tensor, position_ids: torch.Tensor,
                batch_size: int, q_len: int):
        # get qkv
        query, key, value = qkv_states.chunk(chunks=3, dim=-1)
        del qkv_states

        # [num_tokens, num_heads * head_size]
        query_batch_size, query_len, _ = query.shape
        query = query.view(query_len * query_batch_size, self.num_heads * self.dim)

        # [num_tokens, num_kv_heads * head_size]
        key_batch_size, key_len, _ = key.shape
        key = key.view(key_len * key_batch_size, self.num_kv_heads * self.dim)

        # [num_tokens]
        positions = position_ids.view(-1).to(query.device)

        # Apply rotary embedding to the query and key before passing them
        # to the attention op.
        query = query.contiguous()
        key = key.contiguous()
        awq_inference_engine.rotary_embedding(
            positions, query, key, self.dim, self.cos_sin_cache, self.is_neox
        )

        query = query.view(batch_size, q_len, self.num_heads, self.dim).transpose(1, 2)
        key = key.view(batch_size, q_len, self.num_heads, self.dim).transpose(1, 2)
        value = value.view(batch_size, q_len, self.num_heads, self.dim).transpose(1, 2)

        return query, key, value

class QuantLlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
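Both rotary classes above precompute a [max_position, rotary_dim] buffer whose first half of each row holds cos and whose second half holds sin. A standalone sketch of that cache construction, illustrative rather than an import from the package:

    import torch

    def build_cos_sin_cache(head_dim: int, max_position: int, base: float = 10000.0) -> torch.Tensor:
        # inv_freq has head_dim // 2 entries, one per rotated channel pair.
        inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
        t = torch.arange(max_position).float()
        freqs = torch.einsum("i,j->ij", t, inv_freq)            # [max_position, head_dim // 2]
        cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)   # [max_position, head_dim]
        return cache.half()

    cache = build_cos_sin_cache(head_dim=128, max_position=2048)
    print(cache.shape)  # torch.Size([2048, 128])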
@@ -120,139 +52,16 @@ class QuantLlamaRotaryEmbedding(nn.Module):
        # print(positions.shape, query.shape, key.shape, self.cos_sin_cache.shape)
        query = query.contiguous()
        key = key.contiguous()
        awq_inference_engine.rotary_embedding(
        awq_inference_engine.rotary_embedding_neox(
            positions,
            query,
            key,
            self.dim,
            self.cos_sin_cache,
            True
            self.cos_sin_cache
        )
        return query, key

class TorchAttention(nn.Module):
    def __init__(self, hidden_size, use_flash=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.use_flash = use_flash

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                use_cache: bool, past_key_value: torch.Tensor, batch_size: int, q_len: int):
        is_causal = past_key_value is None

        kv_seq_len = q_len
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        value = value.to(key.device)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key = torch.cat([past_key_value[0], key], dim=2)
            value = torch.cat([past_key_value[1], value], dim=2)

        if use_cache:
            # Since qkv_proj is fused, query etc will hold a reference to the original qkv_states tensor
            # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
            key = key.contiguous()
            value = value.contiguous()
            query = query.contiguous()

        past_key_value = (key, value) if use_cache else None

        if self.use_flash and FLASH_INSTALLED:
            query = query.transpose(1, 2)
            key = key.transpose(1, 2)
            value = value.transpose(1, 2)
            attn_output = flash_attn_func(query, key, value, causal=is_causal)
        else:
            attn_output = F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)
        del query, key, value

        attn_output = attn_output.transpose(1, 2).reshape(batch_size, q_len, self.hidden_size)

        return attn_output, past_key_value

class QuantLlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, hidden_size, num_heads, num_kv_heads, qkv_proj, o_proj, dev, max_new_tokens):
        super().__init__()
        self.qkv_proj = qkv_proj
        self.o_proj = o_proj
        self.attn = TorchAttention(hidden_size)
        self.rotary_emb = QuantLlamaRotary(
            dim=hidden_size // num_heads,
            max_position_embeddings=max_new_tokens,
            device=dev,
            is_neox=True,
            num_heads=num_heads,
            num_kv_heads=num_heads
        )

    def forward(self, hidden_states, past_key_value=None, attention_mask=None,
                position_ids=None, output_attentions=False, use_cache=False):
        """Input shape: Batch x Time x Channel"""
        batch_size, q_len, _ = hidden_states.size()

        qkv_states = self.qkv_proj(hidden_states)
        query, key, value = self.rotary_emb(qkv_states, position_ids, batch_size, q_len)
        attn_output, past_key_value = self.attn(query, key, value, use_cache, past_key_value, batch_size, q_len)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    xq_ = torch.view_as_complex(
        xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
    )
    xk_ = torch.view_as_complex(
        xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
    )
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

class QuantLlamaAttentionFused(nn.Module):
    def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_position_embeddings):
        super().__init__()
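precompute_freqs_cis and apply_rotary_emb above express the rotation as a complex multiplication, while the CUDA kernel further down writes it out as x*cos - y*sin and y*cos + x*sin. A small numerical check, illustrative only, that the two formulations agree on a single (x, y) pair:

    import torch

    theta = torch.tensor([0.3])
    x, y = torch.tensor([1.5]), torch.tensor([-0.7])

    # complex form: (x + i*y) * e^(i*theta), as produced by torch.polar
    rot = torch.polar(torch.ones_like(theta), theta)   # cos(theta) + i*sin(theta)
    z = torch.complex(x, y) * rot

    # explicit form used by the kernel
    x_out = x * torch.cos(theta) - y * torch.sin(theta)
    y_out = y * torch.cos(theta) + x * torch.sin(theta)

    assert torch.allclose(z.real, x_out) and torch.allclose(z.imag, y_out)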
@@ -263,42 +72,24 @@ class QuantLlamaAttentionFused(nn.Module):
        self.o_proj = o_proj
        self.start_pos = 0
        self.freqs_cis = precompute_freqs_cis(
            self.head_dim,
            max_position_embeddings * 2,
        )

        # following fastertransformer definition
        self.cache_v = (
            torch.zeros(
                (
                    1,
                    self.n_local_heads,
                    max_position_embeddings,
                    self.head_dim,
                )
            )
            .to(dev)
            .half()
        )
        # added to half
                (
                    1,
                    self.n_local_heads,
                    max_position_embeddings,
                    self.head_dim,
                )
            ).to(dev).half()
        )

        # 8: pack 8 fp16 in FT, if fp32 then use 4
        self.cache_k = (
            torch.zeros(
                (
                    1,
                    self.n_local_heads,
                    self.head_dim // 8,
                    max_position_embeddings,
                    8,
                )
            )
            .to(dev)
            .half()
        )
        # added to half
                (
                    1,
                    self.n_local_heads,
                    self.head_dim // 8,
                    max_position_embeddings,
                    8,
                )
            ).to(dev).half()
        )

        # dummy
        self.rotary_emb = QuantLlamaRotaryEmbedding(
            hidden_size // num_heads,
            max_position_embeddings=max_position_embeddings,
            base=10000,
            device=dev
            dim=hidden_size // num_heads,
            max_position_embeddings=max_position_embeddings,
            device=dev
        )

    def forward(
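The cache_k shape above, (1, n_local_heads, head_dim // 8, max_position_embeddings, 8), follows the FasterTransformer convention of packing 8 fp16 values along the innermost axis (4 for fp32, per the comment). A sketch of how a [1, n_heads, seq_len, head_dim] key tensor could map into that layout; the packing order here is an assumption for illustration, not taken from this diff:

    import torch

    n_heads, head_dim, max_seq, seq_len = 32, 128, 2048, 5
    cache_k = torch.zeros(1, n_heads, head_dim // 8, max_seq, 8, dtype=torch.half)

    k = torch.randn(1, n_heads, seq_len, head_dim, dtype=torch.half)
    # assumed packing: split head_dim into (head_dim // 8, 8) and move the sequence
    # axis between the two factors
    packed = k.view(1, n_heads, seq_len, head_dim // 8, 8).permute(0, 1, 3, 2, 4)
    cache_k[:, :, :, :seq_len, :] = packed

    # unpacking reverses the permute and view
    restored = cache_k[:, :, :, :seq_len, :].permute(0, 1, 3, 2, 4).reshape(1, n_heads, seq_len, head_dim)
    assert torch.equal(restored, k)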
awq_cuda/position_embedding/pos_encoding.h
#pragma once
#include <torch/extension.h>
void rotary_embedding(
void rotary_embedding_neox(
  torch::Tensor& positions,
  torch::Tensor& query,
  torch::Tensor& key,
  int head_size,
  torch::Tensor& cos_sin_cache,
  bool is_neox);
\ No newline at end of file
  torch::Tensor& cos_sin_cache);
\ No newline at end of file
awq_cuda/position_embedding/pos_encoding_kernels.cu
@@ -9,56 +9,15 @@ https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu
#include <ATen/cuda/CUDAContext.h>
#include "pos_encoding.h"
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
template <typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
  scalar_t* __restrict__ arr,
  const scalar_t* __restrict__ cos_ptr,
  const scalar_t* __restrict__ sin_ptr,
  int rot_offset,
  int embed_dim)
{
  int x_index, y_index;
  scalar_t cos, sin;
  if (IS_NEOX) {
    // GPT-NeoX style rotary embedding.
    x_index = rot_offset;
    y_index = embed_dim + rot_offset;
    cos = __ldg(cos_ptr + x_index);
    sin = __ldg(sin_ptr + x_index);
  } else {
    // GPT-J style rotary embedding.
    x_index = 2 * rot_offset;
    y_index = 2 * rot_offset + 1;
    cos = __ldg(cos_ptr + x_index / 2);
    sin = __ldg(sin_ptr + x_index / 2);
  }

  const scalar_t x = arr[x_index];
  const scalar_t y = arr[y_index];
  arr[x_index] = x * cos - y * sin;
  arr[y_index] = y * cos + x * sin;
}

template <typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
template <typename scalar_t>
__global__ void rotary_embedding_neox_kernel(
  const int64_t* __restrict__ positions,        // [num_tokens]
  scalar_t* __restrict__ query,                 // [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key,                   // [num_tokens, num_kv_heads, head_size]
  scalar_t* __restrict__ key,                   // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
  const int rot_dim,
  const int query_stride,
  const int key_stride,
  const int stride,
  const int num_heads,
  const int num_kv_heads,
  const int head_size)
{
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
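The apply_rotary_embedding helper above dispatches on IS_NEOX: GPT-NeoX style pairs channel r with channel embed_dim + r (two half blocks), while GPT-J style pairs adjacent channels 2r and 2r + 1. The two index patterns spelled out in a tiny illustrative snippet:

    embed_dim = 4  # rot_dim // 2

    # GPT-NeoX style: first half paired with second half
    neox_pairs = [(r, embed_dim + r) for r in range(embed_dim)]
    # GPT-J style: interleaved even/odd channels
    gptj_pairs = [(2 * r, 2 * r + 1) for r in range(embed_dim)]

    print(neox_pairs)  # [(0, 4), (1, 5), (2, 6), (3, 7)]
    print(gptj_pairs)  # [(0, 1), (2, 3), (4, 5), (6, 7)]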
@@ -66,72 +25,64 @@ __global__ void rotary_embedding_kernel(
  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
  const int embed_dim = rot_dim / 2;
  const scalar_t* cos_ptr = cache_ptr;
  const scalar_t* sin_ptr = cache_ptr + embed_dim;

  const int nq = num_heads * embed_dim;
  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
  const int n = num_heads * embed_dim;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int token_head = token_idx * query_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
                                              sin_ptr, rot_offset, embed_dim);
  }
    const int token_head = token_idx * stride + head_idx * head_size;

  const int nk = num_kv_heads * embed_dim;
  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int token_head = token_idx * key_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
                                              sin_ptr, rot_offset, embed_dim);
    const int x_index = rot_offset;
    const int y_index = embed_dim + rot_offset;

    const int out_x = token_idx * stride + head_idx * head_size + x_index;
    const int out_y = token_idx * stride + head_idx * head_size + y_index;

    const scalar_t cos = __ldg(cache_ptr + x_index);
    const scalar_t sin = __ldg(cache_ptr + y_index);

    const scalar_t q_x = query[token_head + x_index];
    const scalar_t q_y = query[token_head + y_index];
    query[out_x] = q_x * cos - q_y * sin;
    query[out_y] = q_y * cos + q_x * sin;

    const scalar_t k_x = key[token_head + x_index];
    const scalar_t k_y = key[token_head + y_index];
    key[out_x] = k_x * cos - k_y * sin;
    key[out_y] = k_y * cos + k_x * sin;
  }
}

void rotary_embedding(
  torch::Tensor& positions,         // [num_tokens]
  torch::Tensor& query,             // [num_tokens, num_heads * head_size]
  torch::Tensor& key,               // [num_tokens, num_kv_heads * head_size]
void rotary_embedding_neox(
  torch::Tensor& positions,         // [b, num_tokens]
  torch::Tensor& query,             // [b, num_tokens, 1, num_heads, head_size]
  torch::Tensor& key,               // [b, num_tokens, 1, num_heads, head_size]
  int head_size,
  torch::Tensor& cos_sin_cache,     // [max_position, rot_dim]
  bool is_neox)
{
  int num_tokens = query.size(0);
  torch::Tensor& cos_sin_cache)     // [max_position, rot_dim]
{
  int num_tokens = query.size(0) * query.size(1);
  int rot_dim = cos_sin_cache.size(1);
  int num_heads = query.size(1) / head_size;
  int num_kv_heads = key.size(1) / head_size;
  int query_stride = query.stride(0);
  int key_stride = key.stride(0);
  int num_heads = query.size(-2);
  int stride = num_heads * head_size;
  // TORCH_CHECK(stride == key.stride(0));
  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * rot_dim / 2, 512));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
  AT_DISPATCH_FLOATING_TYPES_AND2(
    at::ScalarType::Half,
    at::ScalarType::BFloat16,
    query.scalar_type(),
    "rotary_embedding",
    "rotary_embedding_neox",
    [&] {
      if (is_neox) {
        rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(),
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_sin_cache.data_ptr<scalar_t>(),
          rot_dim,
          query_stride,
          key_stride,
          num_heads,
          num_kv_heads,
          head_size);
      } else {
        rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
          positions.data_ptr<int64_t>(),
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_sin_cache.data_ptr<scalar_t>(),
          rot_dim,
          query_stride,
          key_stride,
          num_heads,
          num_kv_heads,
          head_size);
      }
      rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
        positions.data_ptr<int64_t>(),
        query.data_ptr<scalar_t>(),
        key.data_ptr<scalar_t>(),
        cos_sin_cache.data_ptr<scalar_t>(),
        rot_dim,
        stride,
        num_heads,
        head_size);
    });
}
\ No newline at end of file
}
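For validating the rewritten kernel, its per-token math can be mirrored in a few lines of PyTorch. This is a reference sketch under the layout described above (query and key viewed as [num_tokens, num_heads * head_size], cos in the first half of each cache row and sin in the second), not the CUDA implementation itself:

    import torch

    def rotary_embedding_neox_ref(positions, query, key, head_size, cos_sin_cache):
        # positions: [num_tokens]; query/key: [num_tokens, num_heads * head_size]
        # cos_sin_cache: [max_position, rot_dim]
        num_tokens = query.shape[0]
        rot_dim = cos_sin_cache.shape[1]
        embed_dim = rot_dim // 2
        cos = cos_sin_cache[positions, :embed_dim].unsqueeze(1)  # [num_tokens, 1, embed_dim]
        sin = cos_sin_cache[positions, embed_dim:].unsqueeze(1)

        def rotate(x):
            x = x.view(num_tokens, -1, head_size)
            x1, x2, rest = x[..., :embed_dim], x[..., embed_dim:rot_dim], x[..., rot_dim:]
            out = torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin, rest], dim=-1)
            return out.reshape(num_tokens, -1)

        return rotate(query), rotate(key)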
awq_cuda/pybind.cpp
@@ -11,7 +11,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
    m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel");
    m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel.");
    m.def("gemv_forward_cuda", &gemv_forward_cuda, "Quantized GEMV kernel.");
    m.def("rotary_embedding", &rotary_embedding, "Apply rotary embedding to query and key");
    m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
    m.def("single_query_attention", &single_query_attention, "Attention with a single query",
          py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"),
          py::arg("length_per_sample_"), py::arg("alibi_slopes_"), py::arg("timestep"),
          py::arg("rotary_embedding_dim") = 0,
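With the binding renamed, Python code reaches the kernel through the compiled extension exactly as awq/modules/fused/attn.py does above. A minimal call sketch, assuming a CUDA device, half-precision tensors, and that the awq_inference_engine extension has been built; the shapes follow the comments in the new launcher:

    import torch
    import awq_inference_engine  # compiled from awq_cuda/, assumed to be installed

    bsz, seqlen, num_heads, head_size, max_pos = 1, 8, 32, 128, 2048
    query = torch.randn(bsz, seqlen, 1, num_heads, head_size, dtype=torch.half, device="cuda")
    key = torch.randn(bsz, seqlen, 1, num_heads, head_size, dtype=torch.half, device="cuda")
    positions = torch.arange(seqlen, dtype=torch.int64, device="cuda").unsqueeze(0)  # [b, num_tokens]

    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_size, 2, device="cuda").float() / head_size))
    freqs = torch.outer(torch.arange(max_pos, device="cuda").float(), inv_freq)
    cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1).half()  # [max_position, rot_dim]

    # rotates query and key in place, GPT-NeoX style
    awq_inference_engine.rotary_embedding_neox(positions, query, key, head_size, cos_sin_cache)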