OpenDAS / AutoAWQ / Commits / 198ba2fb

Commit 198ba2fb (unverified), authored Sep 06, 2023 by Casper, committed by GitHub Sep 06, 2023

Merge pull request #28 from casper-hansen/fix_rotary_emb
[BUG] Fix illegal memory access + Quantized Multi-GPU support

Parents: 85430ddc 562c0d52

Showing 8 changed files with 211 additions and 112 deletions (+211 -112)
awq/entry.py                                         +23  -25
awq/models/base.py                                   +12  -7
awq/models/llama.py                                  +4   -3
awq/modules/fused_attn.py                            +65  -22
awq_cuda/position_embedding/pos_encoding.h           +3   -2
awq_cuda/position_embedding/pos_encoding_kernels.cu  +100 -51
awq_cuda/pybind.cpp                                  +1   -1
setup.py                                             +3   -1
awq/entry.py  (view file @ 198ba2fb)

@@ -3,11 +3,11 @@ import time
 import torch
 import argparse
 from lm_eval import evaluator
-from transformers import AutoTokenizer
 from awq import AutoAWQForCausalLM
 from awq.quantize.auto_clip import apply_clip
 from awq.quantize.auto_scale import apply_scale
 from awq.utils.lm_eval_adaptor import LMEvalAdaptor
+from transformers import AutoTokenizer, GenerationConfig


 def load_search_result_into_memory(model, search_path):
@@ -80,22 +80,6 @@ def run_speed(model_path, quant_file, device, n_generate=128, n_context=256, bat
         out = func()
         return out, time.time() - start

-    def _generate(model, model_out, n_generate, batch_size):
-        past_key_values = model_out.past_key_values
-
-        for i in range(n_generate):
-            logits = model_out.logits[:, -1, :]
-
-            new_tokens = []
-            for batch_index in range(batch_size):
-                probs = torch.softmax(logits[batch_index], dim=-1)
-                token = torch.multinomial(probs, num_samples=1)
-                new_tokens.append(token)
-
-            tokens = torch.as_tensor(new_tokens, device=device).unsqueeze(-1)
-            model_out = model(tokens, use_cache=True, past_key_values=past_key_values)
-
     def _warmup(device: str):
         warm_up = torch.randn((4096, 4096)).to(device)
         torch.mm(warm_up, warm_up)
@@ -114,19 +98,36 @@ def run_speed(model_path, quant_file, device, n_generate=128, n_context=256, bat
     ids = torch.randint(0, tokenizer.vocab_size, (batch_size, n_context)).cuda()

     # Context stage
-    model_out, context_time = _timer(lambda: model(ids, use_cache=True))
+    _, context_time = _timer(lambda: model.generate(
+        ids,
+        generation_config=GenerationConfig(
+            max_new_tokens=0,
+            min_new_tokens=0,
+            use_cache=True
+        )
+    ))

     # Generation stage
-    _, generation_time = _timer(lambda: _generate(model, model_out, n_generate, batch_size))
+    _, generation_time = _timer(lambda: model.generate(
+        ids,
+        generation_config=GenerationConfig(
+            max_new_tokens=n_context,
+            min_new_tokens=n_context,
+            forced_eos_token_id=-100,
+            pad_token_id=tokenizer.pad_token_id,
+            eos_token_id=-100,
+            use_cache=True
+        )
+    ))

     # Prints
     memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
     context_tokens_per_second = n_context / context_time * batch_size
-    context_ms_per_token = (context_time * 1000) / n_context * batch_size
+    context_ms_per_token = (context_time * 1000) / n_context / batch_size
     inference_tokens_per_second = n_generate / generation_time * batch_size
-    inference_ms_per_token = (generation_time * 1000) / n_generate * batch_size
+    inference_ms_per_token = (generation_time * 1000) / n_generate / batch_size

-    print(f"[======] Model summary: {model_path} [======]")
+    print(f"[=] Model summary: {model_path} [=]")
     print(f"[*] Load time: {load_time:.2f} seconds")
     print(f"[*] Context speed: {context_tokens_per_second:.2f} tokens/second ({context_ms_per_token:.2f} ms/token)")
     print(f"[*] Generation speed: {inference_tokens_per_second:.2f} tokens/second ({inference_ms_per_token:.2f} ms/token)")
@@ -185,9 +186,6 @@ if __name__ == '__main__':
         run_eval(args.model_path, args.quant_file, args.device,
                  args.tasks, args.task_batch_size, args.task_n_shot, args.task_use_pretrained)
     elif args.entry_type == 'speed':
-        if args.batch_size > 1 and not args.disable_fused_layers:
-            raise Exception('Fused layers only support batch_size=1. Pass --disable_fused_layers to run batch_size>1 (much slower).')
-
         run_speed(args.model_path, args.quant_file, args.device, args.n_generate, args.n_context, args.batch_size, args.disable_fused_layers)
     else:
         raise Exception('--entry_type must be one of (search|quant|eval|speed)')
awq/models/base.py  (view file @ 198ba2fb)

@@ -297,21 +297,26 @@ class BaseAWQForCausalLM(nn.Module):
         model.tie_weights()

+        device_map = infer_auto_device_map(
+            model,
+            no_split_module_classes=[self.layer_type],
+            dtype=torch_dtype
+        )
+
         # Load model weights
         if is_quantized:
-            model = load_checkpoint_and_dispatch(model, model_filename, device_map=device, no_split_module_classes=[self.layer_type])
+            model = load_checkpoint_and_dispatch(
+                model,
+                model_filename,
+                device_map=device_map,
+                no_split_module_classes=[self.layer_type]
+            )

             if fuse_layers:
                 self.fuse_layers(model)
         else:
             # If not quantized, must load with AutoModelForCausalLM
-            device_map = infer_auto_device_map(model, no_split_module_classes=[self.layer_type], dtype=torch_dtype)
             del model

             # Load model weights
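For context, a minimal sketch of the accelerate loading pattern the change above standardizes on, written against the public accelerate/transformers APIs; the model id and checkpoint filename are hypothetical and error handling is omitted:

import torch
from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM

model_path = "my-org/my-llama-model"        # hypothetical model id
checkpoint = "awq-model-w4-g128.pt"         # hypothetical quantized checkpoint

config = AutoConfig.from_pretrained(model_path)
with init_empty_weights():
    # Build the module tree without allocating real weights.
    model = AutoModelForCausalLM.from_config(config)
model.tie_weights()

# Decide placement across the available GPUs/CPU before loading weights,
# keeping each decoder block whole on a single device.
device_map = infer_auto_device_map(
    model, no_split_module_classes=["LlamaDecoderLayer"], dtype=torch.float16
)

# Stream the checkpoint onto the devices chosen above, rather than passing
# a single device string as before.
model = load_checkpoint_and_dispatch(
    model, checkpoint, device_map=device_map,
    no_split_module_classes=["LlamaDecoderLayer"]
)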
awq/models/llama.py  (view file @ 198ba2fb)

@@ -99,9 +99,10 @@ class LlamaFuser:
             attn = QuantLlamaAttention(
                 module.hidden_size,
                 module.num_heads,
                 module.num_key_value_heads,
                 qkv_layer,
                 module.o_proj,
-                qkv_layer.qweight.device,
+                next(iter(qkv_layer.state_dict().values())).device,
                 self.model.config.max_new_tokens
             )
             set_module_name(self.model, name, attn)
@@ -110,7 +111,7 @@ class LlamaFuser:
         # get qkv and bias
         q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
         bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None

         # create module
         qkv_layer = WQLinear(
             q_proj.w_bit,
@@ -118,7 +119,7 @@ class LlamaFuser:
             q_proj.in_features,
             q_proj.out_features + k_proj.out_features + v_proj.out_features,
             q_proj.bias is not None,
-            q_proj.qweight.device
+            next(iter(module.state_dict().values())).device
         )

         # replace buffers with real weights
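The device lookup change above matters because a fused or quantized layer may expose qweight/qzeros/scales rather than a plain weight, and accelerate may have placed that particular layer on any GPU. A small helper in the same spirit (names here are illustrative, not part of the repo):

import torch
import torch.nn as nn

def first_param_device(module: nn.Module) -> torch.device:
    # Grab whatever tensor the module registered first (weight, qweight, ...)
    # and report the device it was actually dispatched to.
    return next(iter(module.state_dict().values())).device

layer = nn.Linear(16, 16)
print(first_param_device(layer))   # cpu here; cuda:N after dispatch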
awq/modules/fused_attn.py  (view file @ 198ba2fb)

@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 import awq_inference_engine
 from torch.nn import functional as F
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, LlamaRotaryEmbedding


 class QuantLlamaRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
@@ -29,6 +30,7 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         sin = freqs.sin()
         cache = torch.cat((cos, sin), dim=-1)

+        # [max_position, rot_dim]
         self.register_buffer("cos_sin_cache", cache.half(), persistent=False)

     def forward(
@@ -41,13 +43,16 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         # to the attention op.
         query = query.contiguous()
         key = key.contiguous()
-        awq_inference_engine.rotary_embedding_neox(
+        awq_inference_engine.rotary_embedding(
             positions,
             query,
             key,
             self.dim,
             self.cos_sin_cache,
+            True # is_neox
         )
         return query, key


 class QuantLlamaAttention(nn.Module):
@@ -57,22 +62,30 @@ class QuantLlamaAttention(nn.Module):
         self,
         hidden_size,
         num_heads,
         num_kv_heads,
         qkv_proj,
         o_proj,
         dev,
-        max_new_tokens
+        max_new_tokens,
+        use_hf_rotary=False
     ):
         super().__init__()
         self.hidden_size = hidden_size
         self.num_heads = num_heads
         self.num_kv_heads = num_kv_heads
         self.head_dim = hidden_size // num_heads
+        self.use_hf_rotary = use_hf_rotary

         if (self.head_dim * num_heads) != self.hidden_size:
             raise ValueError(
                 f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                 f" and `num_heads`: {num_heads})."
             )
         self.qkv_proj = qkv_proj
         self.o_proj = o_proj
-        self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_new_tokens, device=dev)
+
+        if use_hf_rotary:
+            self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_new_tokens, device=dev)
+        else:
+            self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_new_tokens, device=dev)

     def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
         """Input shape: Batch x Time x Channel"""
@@ -80,42 +93,72 @@ class QuantLlamaAttention(nn.Module):
         bsz, q_len, _ = hidden_states.size()

         qkv_states = self.qkv_proj(hidden_states)
-        qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim)

-        # This updates the query and key states in-place, saving VRAM.
-        query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2)
-        query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)
+        if self.use_hf_rotary:
+            # get qkv
+            qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim)
+            query, key, value = torch.split(qkv_states, 1, dim=2)
+            del qkv_states
+
+            # reshape for hf rotary
+            query = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            key = key.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            value = value.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+            kv_seq_len = key.shape[-2]
+            if past_key_value is not None:
+                kv_seq_len += past_key_value[0].shape[-2]
+
+            cos, sin = self.rotary_emb(value, seq_len=kv_seq_len)
+            query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids)
+        else:
+            # get qkv
+            query, key, value = qkv_states.chunk(chunks=3, dim=-1)
+            del qkv_states
+
+            # [num_tokens, num_heads * head_size]
+            query_batch_size, query_len, _ = query.shape
+            query = query.view(query_len * query_batch_size, self.num_heads * self.head_dim)
+
+            # [num_tokens, num_kv_heads * head_size]
+            key_batch_size, key_len, _ = key.shape
+            key = key.view(key_len * key_batch_size, self.num_kv_heads * self.head_dim)
+
+            # [num_tokens]
+            positions = position_ids.view(-1).to(query.device)
+
+            query, key = self.rotary_emb(query, key, positions)

-        del qkv_states
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            query = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            key = key.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            value = value.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

         is_causal = past_key_value is None

         kv_seq_len = q_len
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]

-        value_states = value_states.to(key_states.device)
+        value = value.to(key.device)

         if past_key_value is not None:
             # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+            key = torch.cat([past_key_value[0], key], dim=2)
+            value = torch.cat([past_key_value[1], value], dim=2)

         if use_cache:
-            # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor
+            # Since qkv_proj is fused, query etc will hold a reference to the original qkv_states tensor
             # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
-            key_states = key_states.contiguous()
-            value_states = value_states.contiguous()
-            query_states = query_states.contiguous()
+            key = key.contiguous()
+            value = value.contiguous()
+            query = query.contiguous()

-        past_key_value = (key_states, value_states) if use_cache else None
+        past_key_value = (key, value) if use_cache else None

         # with torch.backends.cuda.sdp_kernel(enable_math=False):
-        attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal)
-        del query_states, key_states, value_states
+        attn_output = F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)
+        del query, key, value

         attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
         attn_output = self.o_proj(attn_output)
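The `contiguous()` calls kept in the cache path above address a real pitfall: query/key/value produced by splitting the fused qkv output are views that keep the entire qkv buffer alive in the KV cache. A small illustration (PyTorch 2.x, shapes made up):

import torch

qkv = torch.randn(1, 8, 3 * 4096)             # stands in for the fused qkv projection output
key = qkv[..., 4096:2 * 4096]                 # a view: shares storage with the full qkv tensor

# Caching the view would pin the whole qkv buffer in memory.
print(key.untyped_storage().size() == qkv.untyped_storage().size())   # True

key = key.contiguous()                        # copy out just the key slice
print(key.untyped_storage().size() < qkv.untyped_storage().size())    # True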
awq_cuda/position_embedding/pos_encoding.h  (view file @ 198ba2fb)

 #pragma once
 #include <torch/extension.h>

-void rotary_embedding_neox(
+void rotary_embedding(
   torch::Tensor& positions,
   torch::Tensor& query,
   torch::Tensor& key,
   int head_size,
-  torch::Tensor& cos_sin_cache);
\ No newline at end of file
+  torch::Tensor& cos_sin_cache,
+  bool is_neox);
\ No newline at end of file
awq_cuda/position_embedding/pos_encoding_kernels.cu  (view file @ 198ba2fb)

@@ -9,15 +9,56 @@ https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu
 #include <ATen/cuda/CUDAContext.h>
 #include "pos_encoding.h"

-template<typename scalar_t>
-__global__ void rotary_embedding_neox_kernel(
+#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)              \
+  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)      \
+  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
+
+#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)       \
+  AT_DISPATCH_SWITCH(                                       \
+    TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
+
+template<typename scalar_t, bool IS_NEOX>
+inline __device__ void apply_rotary_embedding(
+  scalar_t* __restrict__ arr,
+  const scalar_t* __restrict__ cos_ptr,
+  const scalar_t* __restrict__ sin_ptr,
+  int rot_offset,
+  int embed_dim)
+{
+  int x_index, y_index;
+  scalar_t cos, sin;
+  if (IS_NEOX) {
+    // GPT-NeoX style rotary embedding.
+    x_index = rot_offset;
+    y_index = embed_dim + rot_offset;
+    cos = __ldg(cos_ptr + x_index);
+    sin = __ldg(sin_ptr + x_index);
+  } else {
+    // GPT-J style rotary embedding.
+    x_index = 2 * rot_offset;
+    y_index = 2 * rot_offset + 1;
+    cos = __ldg(cos_ptr + x_index / 2);
+    sin = __ldg(sin_ptr + x_index / 2);
+  }
+
+  const scalar_t x = arr[x_index];
+  const scalar_t y = arr[y_index];
+  arr[x_index] = x * cos - y * sin;
+  arr[y_index] = y * cos + x * sin;
+}
+
+template<typename scalar_t, bool IS_NEOX>
+__global__ void rotary_embedding_kernel(
   const int64_t* __restrict__ positions,        // [num_tokens]
   scalar_t* __restrict__ query,                 // [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key,                   // [num_tokens, num_heads, head_size]
+  scalar_t* __restrict__ key,                   // [num_tokens, num_kv_heads, head_size]
   const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
   const int rot_dim,
-  const int stride,
+  const int query_stride,
+  const int key_stride,
   const int num_heads,
+  const int num_kv_heads,
   const int head_size) {
   // Each thread block is responsible for one token.
   const int token_idx = blockIdx.x;
@@ -25,64 +66,72 @@ __global__ void rotary_embedding_neox_kernel(
   const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
   const int embed_dim = rot_dim / 2;
-  const int n = num_heads * embed_dim;
-  for (int i = threadIdx.x; i < n; i += blockDim.x) {
-    const int head_idx = i / embed_dim;
-    const int token_head = token_idx * stride + head_idx * head_size;
+  const scalar_t* cos_ptr = cache_ptr;
+  const scalar_t* sin_ptr = cache_ptr + embed_dim;
+
+  const int nq = num_heads * embed_dim;
+  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
+    const int head_idx = i / embed_dim;
+    const int token_head = token_idx * query_stride + head_idx * head_size;
     const int rot_offset = i % embed_dim;
-    const int x_index = rot_offset;
-    const int y_index = embed_dim + rot_offset;
-
-    const int out_x = token_idx * stride + head_idx * head_size + x_index;
-    const int out_y = token_idx * stride + head_idx * head_size + y_index;
-
-    const scalar_t cos = __ldg(cache_ptr + x_index);
-    const scalar_t sin = __ldg(cache_ptr + y_index);
-
-    const scalar_t q_x = query[token_head + x_index];
-    const scalar_t q_y = query[token_head + y_index];
-    query[out_x] = q_x * cos - q_y * sin;
-    query[out_y] = q_y * cos + q_x * sin;
+    apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
+                                              sin_ptr, rot_offset, embed_dim);
+  }

-    const scalar_t k_x = key[token_head + x_index];
-    const scalar_t k_y = key[token_head + y_index];
-    key[out_x] = k_x * cos - k_y * sin;
-    key[out_y] = k_y * cos + k_x * sin;
+  const int nk = num_kv_heads * embed_dim;
+  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
+    const int head_idx = i / embed_dim;
+    const int token_head = token_idx * key_stride + head_idx * head_size;
+    const int rot_offset = i % embed_dim;
+    apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
+                                              sin_ptr, rot_offset, embed_dim);
   }
 }

-void rotary_embedding_neox(
-  torch::Tensor& positions,         // [b, num_tokens]
-  torch::Tensor& query,             // [b, num_tokens, 1, num_heads, head_size]
-  torch::Tensor& key,               // [b, num_tokens, 1, num_heads, head_size]
+void rotary_embedding(
+  torch::Tensor& positions,         // [num_tokens]
+  torch::Tensor& query,             // [num_tokens, num_heads * head_size]
+  torch::Tensor& key,               // [num_tokens, num_kv_heads * head_size]
   int head_size,
-  torch::Tensor& cos_sin_cache)     // [max_position, rot_dim]
-{
-  int num_tokens = query.size(0) * query.size(1);
+  torch::Tensor& cos_sin_cache,     // [max_position, rot_dim]
+  bool is_neox)
+{
+  int num_tokens = query.size(0);
   int rot_dim = cos_sin_cache.size(1);
-  int num_heads = query.size(-2);
-  int stride = num_heads * head_size;
-  // TORCH_CHECK(stride == key.stride(0));
+  int num_heads = query.size(1) / head_size;
+  int num_kv_heads = key.size(1) / head_size;
+  int query_stride = query.stride(0);
+  int key_stride = key.stride(0);

   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-    at::ScalarType::Half,
-    at::ScalarType::BFloat16,
+  VLLM_DISPATCH_FLOATING_TYPES(
     query.scalar_type(),
-    "rotary_embedding_neox",
+    "rotary_embedding",
     [&] {
-      rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
-        positions.data_ptr<int64_t>(),
-        query.data_ptr<scalar_t>(),
-        key.data_ptr<scalar_t>(),
-        cos_sin_cache.data_ptr<scalar_t>(),
-        rot_dim,
-        stride,
-        num_heads,
-        head_size);
+      if (is_neox) {
+        rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
+          positions.data_ptr<int64_t>(),
+          query.data_ptr<scalar_t>(),
+          key.data_ptr<scalar_t>(),
+          cos_sin_cache.data_ptr<scalar_t>(),
+          rot_dim,
+          query_stride,
+          key_stride,
+          num_heads,
+          num_kv_heads,
+          head_size);
+      } else {
+        rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
+          positions.data_ptr<int64_t>(),
+          query.data_ptr<scalar_t>(),
+          key.data_ptr<scalar_t>(),
+          cos_sin_cache.data_ptr<scalar_t>(),
+          rot_dim,
+          query_stride,
+          key_stride,
+          num_heads,
+          num_kv_heads,
+          head_size);
+      }
     });
 }
\ No newline at end of file
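For readers checking the kernel, here is a pure-PyTorch mirror of apply_rotary_embedding above (one head and one rot_offset at a time); this is a reference sketch, not part of the repo:

import torch

def apply_rotary_embedding_ref(arr, cos, sin, rot_offset, embed_dim, is_neox):
    # `arr` is one head's flat [head_size] slice; `cos`/`sin` hold the cached
    # values for this token's position, each of length embed_dim (= rot_dim // 2).
    if is_neox:
        x_index, y_index = rot_offset, embed_dim + rot_offset          # GPT-NeoX layout
        c, s = cos[x_index], sin[x_index]
    else:
        x_index, y_index = 2 * rot_offset, 2 * rot_offset + 1          # GPT-J layout
        c, s = cos[x_index // 2], sin[x_index // 2]
    x, y = arr[x_index].clone(), arr[y_index].clone()
    arr[x_index] = x * c - y * s
    arr[y_index] = y * c + x * s

head = torch.randn(128)                      # head_size = 128 -> embed_dim = 64
cos, sin = torch.randn(64), torch.randn(64)
for rot_offset in range(64):
    apply_rotary_embedding_ref(head, cos, sin, rot_offset, 64, is_neox=True)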
awq_cuda/pybind.cpp  (view file @ 198ba2fb)

@@ -8,5 +8,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
     m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel");
     m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel.");
-    m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
+    m.def("rotary_embedding", &rotary_embedding, "Apply rotary embedding to query and key");
 }
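With the binding renamed, the extension is called from Python as sketched below. Shapes follow the comments in rotary_embedding() above; the snippet assumes the compiled extension and a CUDA device, and the sizes are arbitrary:

import torch
import awq_inference_engine   # built from the awq_cuda sources

num_tokens, num_heads, num_kv_heads, head_size, max_position = 16, 8, 8, 64, 2048
rot_dim = head_size

positions = torch.arange(num_tokens, dtype=torch.int64, device="cuda")
query = torch.randn(num_tokens, num_heads * head_size, dtype=torch.float16, device="cuda")
key = torch.randn(num_tokens, num_kv_heads * head_size, dtype=torch.float16, device="cuda")
cos_sin_cache = torch.randn(max_position, rot_dim, dtype=torch.float16, device="cuda")

# The final argument selects GPT-NeoX style (True) or GPT-J style (False) rotation.
awq_inference_engine.rotary_embedding(positions, query, key, head_size, cos_sin_cache, True)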
setup.py  (view file @ 198ba2fb)

@@ -85,9 +85,11 @@ if os.name == "nt":
         "nvcc": arch_flags
     }
 else:
+    threads = ["--threads", str(min(os.cpu_count(), 8))]
     extra_compile_args = {
         "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17"],
-        "nvcc": ["-O3", "-std=c++17"] + arch_flags
+        "nvcc": ["-O3", "-std=c++17"] + arch_flags + threads
     }

 extensions = [
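The --threads flag passed to nvcc (available since CUDA 11.2) compiles multiple architecture targets in parallel. A minimal sketch of where these flags end up, with a single example gencode flag and a subset of sources standing in for the repo's full lists:

import os
from torch.utils.cpp_extension import CUDAExtension

arch_flags = ["-gencode=arch=compute_80,code=sm_80"]     # example arch only
threads = ["--threads", str(min(os.cpu_count(), 8))]     # parallel nvcc compilation

extension = CUDAExtension(
    name="awq_inference_engine",
    sources=[
        "awq_cuda/pybind.cpp",
        "awq_cuda/position_embedding/pos_encoding_kernels.cu",
    ],
    extra_compile_args={
        "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17"],
        "nvcc": ["-O3", "-std=c++17"] + arch_flags + threads,
    },
)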