Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
AutoAWQ
Commits
97af18e4
Commit
97af18e4
authored
Sep 06, 2023
by
Casper Hansen
Browse files
Update rotary embedding kernel
parent
54b63712
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
102 additions
and
52 deletions
+102
-52
awq_cuda/position_embedding/pos_encoding.h
awq_cuda/position_embedding/pos_encoding.h
+3
-2
awq_cuda/position_embedding/pos_encoding_kernels.cu
awq_cuda/position_embedding/pos_encoding_kernels.cu
+98
-49
awq_cuda/pybind.cpp
awq_cuda/pybind.cpp
+1
-1
No files found.
awq_cuda/position_embedding/pos_encoding.h
View file @
97af18e4
#pragma once
#include <torch/extension.h>

// Applies rotary position embedding (RoPE) in-place to `query` and `key`,
// using angles precomputed in `cos_sin_cache` and indexed by `positions`.
// `is_neox == true` selects GPT-NeoX style rotation (elements paired across
// the two halves of the rotary dims); `is_neox == false` selects GPT-J style
// (adjacent even/odd element pairs). Implemented in pos_encoding_kernels.cu.
void rotary_embedding(
  torch::Tensor& positions,      // position ids used to index cos_sin_cache
  torch::Tensor& query,          // rotated in-place
  torch::Tensor& key,            // rotated in-place
  int head_size,
  torch::Tensor& cos_sin_cache,  // precomputed [cos | sin] table
  bool is_neox);
awq_cuda/position_embedding/pos_encoding_kernels.cu
View file @
97af18e4
...
@@ -9,15 +9,56 @@ https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu
...
@@ -9,15 +9,56 @@ https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu
#include <ATen/cuda/CUDAContext.h>
#include "pos_encoding.h"

// Dispatch over the floating-point dtypes the kernels support:
// float32, float16 and bfloat16.
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)              \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)      \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)       \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)       \
  AT_DISPATCH_SWITCH(                                       \
    TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

// Rotates one (x, y) element pair of `arr` in place by the angle whose
// cos/sin are cached at `rot_offset`. The pairing scheme is chosen at
// compile time via IS_NEOX so the branch costs nothing at runtime.
//   cos_ptr/sin_ptr : per-position cos / sin tables, each `embed_dim` long
//   rot_offset      : which rotary pair [0, embed_dim) this call handles
//   embed_dim       : rot_dim / 2
template<typename scalar_t, bool IS_NEOX>
inline __device__ void apply_rotary_embedding(
  scalar_t* __restrict__ arr,
  const scalar_t* __restrict__ cos_ptr,
  const scalar_t* __restrict__ sin_ptr,
  int rot_offset,
  int embed_dim)
{
  int x_index, y_index;
  scalar_t cos, sin;
  if (IS_NEOX) {
    // GPT-NeoX style rotary embedding: element i is paired with element
    // embed_dim + i (first half of the rotary dims vs. second half).
    x_index = rot_offset;
    y_index = embed_dim + rot_offset;
    cos = __ldg(cos_ptr + x_index);
    sin = __ldg(sin_ptr + x_index);
  } else {
    // GPT-J style rotary embedding: adjacent even/odd elements are paired.
    x_index = 2 * rot_offset;
    y_index = 2 * rot_offset + 1;
    cos = __ldg(cos_ptr + x_index / 2);
    sin = __ldg(sin_ptr + x_index / 2);
  }
  // Standard 2-D rotation by the cached angle.
  const scalar_t x = arr[x_index];
  const scalar_t y = arr[y_index];
  arr[x_index] = x * cos - y * sin;
  arr[y_index] = y * cos + x * sin;
}
// One thread block per token. Each thread strides over the rotary pairs of
// all query heads, then over the rotary pairs of all key heads, so query and
// key may have different head counts (grouped/multi-query attention) and
// different row strides.
template<typename scalar_t, bool IS_NEOX>
__global__ void rotary_embedding_kernel(
  const int64_t* __restrict__ positions,      // [num_tokens]
  scalar_t* __restrict__ query,               // [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key,                 // [num_tokens, num_kv_heads, head_size]
  const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
  const int rot_dim,
  const int query_stride,
  const int key_stride,
  const int num_heads,
  const int num_kv_heads,
  const int head_size)
{
  // Each thread block is responsible for one token.
  const int token_idx = blockIdx.x;
  int64_t pos = positions[token_idx];
  // Row of the cache for this position: embed_dim cos values followed by
  // embed_dim sin values.
  const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

  const int embed_dim = rot_dim / 2;
  const scalar_t* cos_ptr = cache_ptr;
  const scalar_t* sin_ptr = cache_ptr + embed_dim;

  // Rotate the query heads.
  const int nq = num_heads * embed_dim;
  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int token_head = token_idx * query_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
                                              sin_ptr, rot_offset, embed_dim);
  }

  // Rotate the key heads (may be fewer than query heads).
  const int nk = num_kv_heads * embed_dim;
  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int token_head = token_idx * key_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
                                              sin_ptr, rot_offset, embed_dim);
  }
}
void
rotary_embedding
(
void
rotary_embedding_neox
(
torch
::
Tensor
&
positions
,
// [num_tokens]
torch
::
Tensor
&
positions
,
// [b, num_tokens]
torch
::
Tensor
&
query
,
// [num_tokens, num_heads * head_size]
torch
::
Tensor
&
query
,
// [b, num_tokens, 1, num_heads, head_size]
torch
::
Tensor
&
key
,
// [num_tokens, num_kv_heads * head_size]
torch
::
Tensor
&
key
,
// [b, num_tokens, 1, num_heads, head_size]
int
head_size
,
int
head_size
,
torch
::
Tensor
&
cos_sin_cache
)
// [max_position, rot_dim]
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
{
bool
is_neox
)
{
int
num_tokens
=
query
.
size
(
0
)
*
query
.
size
(
1
);
int
num_tokens
=
query
.
size
(
0
)
*
query
.
size
(
1
);
int
rot_dim
=
cos_sin_cache
.
size
(
1
);
int
rot_dim
=
cos_sin_cache
.
size
(
1
);
int
num_heads
=
query
.
size
(
-
2
);
int
num_heads
=
query
.
size
(
-
2
);
int
stride
=
num_heads
*
head_size
;
int
num_kv_heads
=
key
.
size
(
-
2
);
// TORCH_CHECK(stride == key.stride(0));
int
query_stride
=
query
.
stride
(
0
);
int
key_stride
=
key
.
stride
(
0
);
dim3
grid
(
num_tokens
);
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
num_heads
*
rot_dim
/
2
,
512
));
dim3
block
(
std
::
min
(
num_heads
*
rot_dim
/
2
,
512
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_DISPATCH_FLOATING_TYPES_AND2
(
VLLM_DISPATCH_FLOATING_TYPES
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
query
.
scalar_type
(),
query
.
scalar_type
(),
"rotary_embedding
_neox
"
,
"rotary_embedding"
,
[
&
]
{
[
&
]
{
rotary_embedding_neox_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
if
(
is_neox
)
{
positions
.
data_ptr
<
int64_t
>
(),
rotary_embedding_kernel
<
scalar_t
,
true
><<<
grid
,
block
,
0
,
stream
>>>
(
query
.
data_ptr
<
scalar_t
>
(),
positions
.
data_ptr
<
int64_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
stride
,
rot_dim
,
num_heads
,
query_stride
,
head_size
);
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
else
{
rotary_embedding_kernel
<
scalar_t
,
false
><<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
});
});
}
}
\ No newline at end of file
awq_cuda/pybind.cpp
View file @
97af18e4
...
@@ -8,5 +8,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
...
@@ -8,5 +8,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
  // Expose the CUDA kernels to Python via the torch extension module.
  m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel");
  m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel.");
  m.def("rotary_embedding", &rotary_embedding, "Apply rotary embedding to query and key");
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment