OpenDAS / AutoAWQ

Commit bbe1d46a, authored Sep 05, 2023 by Casper Hansen

Revert "Update neox kernel"

This reverts commit 4fe9974a.

parent 4fe9974a
Showing 3 changed files with 142 additions and 106 deletions.
awq/models/llama.py  (+2 / -2)
awq/modules/fused_attn.py  (+125 / -60)
awq_cuda/position_embedding/pos_encoding_kernels.cu  (+15 / -44)
awq/models/llama.py

@@ -71,7 +71,7 @@ from awq.quantize.qmodule import WQLinear
 from awq.utils.utils import set_module_name
 from awq.modules.fused_mlp import QuantLlamaMLP
 from awq.modules.fused_norm import FTLlamaRMSNorm
-from awq.modules.fused_attn import QuantLlamaAttention
+from awq.modules.fused_attn import QuantLlamaAttention, CustomQuantLlamaAttention
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP
 
 class LlamaFuser:
@@ -96,7 +96,7 @@ class LlamaFuser:
     def fuse_attention(self):
         for name, module in self.attention_modules:
             qkv_layer: WQLinear = self._fuse_qkv(module)
-            attn = QuantLlamaAttention(
+            attn = CustomQuantLlamaAttention(
                 module.hidden_size,
                 module.num_heads,
                 qkv_layer,
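The fuser above hands CustomQuantLlamaAttention a single fused projection produced by _fuse_qkv. As a point of reference only (this sketch is not part of the commit and substitutes plain nn.Linear for the quantized WQLinear layers), fusing the three projections amounts to stacking their weights so one matmul yields query, key and value together:

import torch
import torch.nn as nn

def fuse_qkv_linear(q_proj: nn.Linear, k_proj: nn.Linear, v_proj: nn.Linear) -> nn.Linear:
    # Stack the three projection weights row-wise so a single matmul produces
    # the concatenated [query | key | value] activations.
    out_features = q_proj.out_features + k_proj.out_features + v_proj.out_features
    fused = nn.Linear(q_proj.in_features, out_features, bias=q_proj.bias is not None)
    with torch.no_grad():
        fused.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))
        if fused.bias is not None:
            fused.bias.copy_(torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0))
    return fused

The fused output is later split back into its three parts, which is what the torch.split(qkv_states, self.hidden_size, dim=2) call restored in fused_attn.py below does.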
awq/modules/fused_attn.py

@@ -2,23 +2,22 @@ import torch
 import torch.nn as nn
 import awq_inference_engine
 from torch.nn import functional as F
-from transformers.models.llama.modeling_llama import LlamaLinearScalingRotaryEmbedding
 
-class RotaryEmbeddingNeox(nn.Module):
-    def __init__(self, head_dim, seq_len, device):
+class QuantLlamaRotaryEmbedding(nn.Module):
+    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
-        self.head_dim = head_dim
-        self.seq_len = seq_len
-        self.base = 10000
-
-        # create inv_frequency
-        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.head_dim, 2).float().to(device) / self.head_dim))
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
         self.register_buffer("inv_freq", inv_freq)
 
-        # set cache
-        self._set_cos_sin_cache(seq_len=self.seq_len, device=self.inv_freq.device)
+        # Build here to make `torch.jit.trace` work.
+        self._set_cos_sin_cache(
+            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+        )
 
-    def _set_cos_sin_cache(self, seq_len, device):
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.max_seq_len_cached = seq_len
         t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
@@ -32,89 +31,122 @@ class RotaryEmbeddingNeox(nn.Module):
         self.register_buffer("cos_sin_cache", cache.half(), persistent=False)
 
-    def forward(self, positions, query, key):
-        batch_size, seq_len, _ = query.shape
-        query = query.view(batch_size * seq_len, -1)
-        key = key.view(batch_size * seq_len, -1)
-        positions = positions.view(-1).to(query.device)
-
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        positions: torch.Tensor,
+    ):
+        # Apply rotary embedding to the query and key before passing them
+        # to the attention op.
         query = query.contiguous()
         key = key.contiguous()
         awq_inference_engine.rotary_embedding_neox(
             positions,
             query,
             key,
-            self.head_dim,
+            self.dim,
             self.cos_sin_cache,
         )
-        query = query.view(batch_size, seq_len, -1)
-        key = key.view(batch_size, seq_len, -1)
         return query, key
 
 
 class QuantLlamaAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
     def __init__(
         self,
         hidden_size,
         num_heads,
         qkv_proj,
         o_proj,
-        device,
+        dev,
         max_new_tokens
     ):
         super().__init__()
         self.hidden_size = hidden_size
         self.num_heads = num_heads
         self.head_dim = hidden_size // num_heads
-        self.seq_len = max_new_tokens
-        self.qkv_proj = qkv_proj
-        self.o_proj = o_proj
-        self.rotary_embedding_neox = RotaryEmbeddingNeox(self.head_dim, self.seq_len, device)
 
         if (self.head_dim * num_heads) != self.hidden_size:
             raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                              f" and `num_heads`: {num_heads}).")
+        self.qkv_proj = qkv_proj
+        self.o_proj = o_proj
+        self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_new_tokens, device=dev)
 
-    def attn(self, query, key, value, past_key_value, use_cache, attention_mask):
-        batch_size, seq_len, _ = query.shape
-        query = query.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key = key.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
-        value = value.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-        # cache ops
-        is_causal = past_key_value is None
-        value = value.to(key.device)
-
-        if past_key_value is not None:
-            # reuse k, v, self_attention
-            key = torch.cat([past_key_value[0], key], dim=2)
-            value = torch.cat([past_key_value[1], value], dim=2)
-
-        if use_cache:
-            # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv tensor
-            # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
-            query = query.contiguous()
-            key = key.contiguous()
-            value = value.contiguous()
-
-        past_key_value = (key, value) if use_cache else None
-
-        # multi-head masked attention
-        attn_output = F.scaled_dot_product_attention(
-            query,
-            key,
-            value,
-            attn_mask=None if is_causal else attention_mask,
-            is_causal=is_causal
-        )
-
-        # reshape output
-        attn_output = attn_output.transpose(1, 2).contiguous()
-        attn_output = attn_output.reshape(batch_size, seq_len, self.hidden_size)
-
-        return attn_output, past_key_value
+    def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
+        """Input shape: Batch x Time x Channel"""
+        bsz, q_len, _ = hidden_states.size()
+
+        qkv_states = self.qkv_proj(hidden_states)
+        qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim)
+
+        # This updates the query and key states in-place, saving VRAM.
+        query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2)
+        query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)
+        del qkv_states
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        is_causal = past_key_value is None
+
+        kv_seq_len = q_len
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+
+        value_states = value_states.to(key_states.device)
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        if use_cache:
+            # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor
+            # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
+            query_states = query_states.contiguous()
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        # with torch.backends.cuda.sdp_kernel(enable_math=False):
+        attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal)
+        del query_states, key_states, value_states
+
+        attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+
+class CustomQuantLlamaAttention(nn.Module):
+    def __init__(
+        self,
+        hidden_size,
+        num_heads,
+        qkv_proj,
+        o_proj,
+        dev,
+        max_new_tokens
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.head_dim = hidden_size // num_heads
+
+        if (self.head_dim * num_heads) != self.hidden_size:
+            raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                             f" and `num_heads`: {num_heads}).")
+
+        self.qkv_proj = qkv_proj
+        self.o_proj = o_proj
+        self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_new_tokens, device=dev)
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
 
     def forward(
         self,
@@ -126,13 +158,46 @@ class QuantLlamaAttention(nn.Module):
         use_cache: bool = False,
     ):
         # qkv proj
-        query, key, value = self.qkv_proj(hidden_states).chunk(chunks=3, dim=-1)
+        qkv_states = self.qkv_proj(hidden_states)
 
-        # rotary embeddings
-        query, key = self.rotary_embedding_neox(position_ids, query, key)
+        # extract q,k,v
+        bsz, q_len, _ = hidden_states.size()
+        query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
-        # attention
-        attn_output, past_key_value = self.attn(query, key, value, past_key_value, use_cache, attention_mask)
+        # rotary embedding
+        query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)
+
+        # cache ops
+        is_causal = past_key_value is None
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        if use_cache:
+            # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor
+            # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
+            query_states = query_states.contiguous()
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        # multi-head masked attention
+        attn_output = F.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=None if is_causal else attention_mask,
+            is_causal=is_causal
+        )
+
+        # reshape output
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
 
         # out projection
         attn_output = self.o_proj(attn_output)
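The first hunk above elides the body of _set_cos_sin_cache, so only the final register_buffer("cos_sin_cache", cache.half(), persistent=False) line is visible. For orientation, here is a minimal sketch of how a [max_position, rot_dim] cos/sin cache is commonly derived from inv_freq; this is illustrative only and is not the elided code from the diff:

import torch

def build_cos_sin_cache(dim, max_position, base=10000.0, dtype=torch.float16):
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # [dim // 2]
    t = torch.arange(max_position, dtype=inv_freq.dtype)                # [max_position]
    freqs = torch.einsum("i,j->ij", t, inv_freq)                        # [max_position, dim // 2]
    # cos in the first half of each row, sin in the second half; this matches the layout
    # the neox kernel indexes with x_index = rot_offset and y_index = embed_dim + rot_offset.
    return torch.cat([freqs.cos(), freqs.sin()], dim=-1).to(dtype)      # [max_position, dim]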
awq_cuda/position_embedding/pos_encoding_kernels.cu

@@ -9,26 +9,15 @@ https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu
 #include <ATen/cuda/CUDAContext.h>
 #include "pos_encoding.h"
 
-#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)           \
-  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)   \
-  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)    \
-  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
-
-#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)    \
-  AT_DISPATCH_SWITCH(                                    \
-    TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
-
 template <typename scalar_t>
 __global__ void rotary_embedding_neox_kernel(
   const int64_t* __restrict__ positions,        // [num_tokens]
   scalar_t* __restrict__ query,                 // [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key,                   // [num_tokens, num_kv_heads, head_size]
+  scalar_t* __restrict__ key,                   // [num_tokens, num_heads, head_size]
   const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
   const int rot_dim,
-  const int query_stride,
-  const int key_stride,
+  const int stride,
   const int num_heads,
-  const int num_kv_heads,
   const int head_size) {
   // Each thread block is responsible for one token.
   const int token_idx = blockIdx.x;
@@ -36,17 +25,17 @@ __global__ void rotary_embedding_neox_kernel(
   const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
   const int embed_dim = rot_dim / 2;
 
-  const int nq = num_heads * embed_dim;
-  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
+  const int n = num_heads * embed_dim;
+  for (int i = threadIdx.x; i < n; i += blockDim.x) {
     const int head_idx = i / embed_dim;
-    const int token_head = token_idx * query_stride + head_idx * head_size;
+    const int token_head = token_idx * stride + head_idx * head_size;
 
     const int rot_offset = i % embed_dim;
     const int x_index = rot_offset;
     const int y_index = embed_dim + rot_offset;
 
-    const int out_x = token_idx * query_stride + head_idx * head_size + x_index;
-    const int out_y = token_idx * query_stride + head_idx * head_size + y_index;
+    const int out_x = token_idx * stride + head_idx * head_size + x_index;
+    const int out_y = token_idx * stride + head_idx * head_size + y_index;
 
     const scalar_t cos = __ldg(cache_ptr + x_index);
     const scalar_t sin = __ldg(cache_ptr + y_index);
@@ -55,22 +44,6 @@ __global__ void rotary_embedding_neox_kernel(
     const scalar_t q_y = query[token_head + y_index];
     query[out_x] = q_x * cos - q_y * sin;
     query[out_y] = q_y * cos + q_x * sin;
-  }
-
-  const int nk = num_kv_heads * embed_dim;
-  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
-    const int head_idx = i / embed_dim;
-    const int token_head = token_idx * key_stride + head_idx * head_size;
-
-    const int rot_offset = i % embed_dim;
-    const int x_index = rot_offset;
-    const int y_index = embed_dim + rot_offset;
-
-    const int out_x = token_idx * key_stride + head_idx * head_size + x_index;
-    const int out_y = token_idx * key_stride + head_idx * head_size + y_index;
-
-    const scalar_t cos = __ldg(cache_ptr + x_index);
-    const scalar_t sin = __ldg(cache_ptr + y_index);
 
     const scalar_t k_x = key[token_head + x_index];
     const scalar_t k_y = key[token_head + y_index];
@@ -86,17 +59,18 @@ void rotary_embedding_neox(
   int head_size,
   torch::Tensor& cos_sin_cache)   // [max_position, rot_dim]
 {
-  int num_tokens = query.size(0);
+  int num_tokens = query.size(0) * query.size(1);
   int rot_dim = cos_sin_cache.size(1);
-  int num_heads = query.size(1) / head_size;
-  int num_kv_heads = key.size(1) / head_size;
-  int query_stride = query.stride(0);
-  int key_stride = key.stride(0);
+  int num_heads = query.size(-2);
+  int stride = num_heads * head_size;
+  // TORCH_CHECK(stride == key.stride(0));
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
+  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
     query.scalar_type(),
     "rotary_embedding_neox",
     [&] {
@@ -106,12 +80,9 @@ void rotary_embedding_neox(
       key.data_ptr<scalar_t>(),
       cos_sin_cache.data_ptr<scalar_t>(),
       rot_dim,
-      query_stride,
-      key_stride,
+      stride,
      num_heads,
-      num_kv_heads,
      head_size);
  });
 }
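For readers not fluent in CUDA, the per-token math of rotary_embedding_neox_kernel can be expressed as a small PyTorch reference. This is an illustrative sketch, not code from the repository; it assumes rot_dim equals head_dim and the cache layout described by the kernel (cos in the first half of each cache row, sin in the second half):

import torch

def rotary_embedding_neox_reference(positions, query, key, head_dim, cos_sin_cache):
    # positions: [num_tokens] (long); query/key: [num_tokens, num_heads, head_dim]
    # cos_sin_cache: [max_position, head_dim] with cos in [:, :head_dim//2], sin in [:, head_dim//2:]
    cos, sin = cos_sin_cache[positions].chunk(2, dim=-1)  # each [num_tokens, head_dim // 2]
    cos = cos.unsqueeze(1)  # broadcast over the heads dimension
    sin = sin.unsqueeze(1)

    def rotate(x):
        x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2:]
        # mirrors query[out_x] = q_x*cos - q_y*sin and query[out_y] = q_y*cos + q_x*sin
        return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)

    return rotate(query), rotate(key)

Each (x, y) pair at offsets rot_offset and embed_dim + rot_offset is rotated by the angle cached for that token's position, which is exactly what the in-place query and key updates in the kernel above compute.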