Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a22dea54
Unverified
Commit
a22dea54
authored
May 31, 2024
by
SnowDist
Committed by
GitHub
May 30, 2024
Browse files
[Model] Support MAP-NEO model (#5081)
Co-authored-by:
Zhuohan Li
<
zhuohan123@gmail.com
>
parent
533c2177
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
18 additions
and
6 deletions
+18
-6
benchmarks/kernels/benchmark_paged_attention.py
benchmarks/kernels/benchmark_paged_attention.py
+1
-1
benchmarks/kernels/benchmark_rope.py
benchmarks/kernels/benchmark_rope.py
+1
-1
csrc/attention/attention_kernels.cu
csrc/attention/attention_kernels.cu
+6
-0
csrc/cpu/attention.cpp
csrc/cpu/attention.cpp
+6
-0
tests/kernels/test_attention.py
tests/kernels/test_attention.py
+1
-1
tests/kernels/test_cache.py
tests/kernels/test_cache.py
+1
-1
tests/kernels/test_pos_encoding.py
tests/kernels/test_pos_encoding.py
+1
-1
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+1
-1
No files found.
benchmarks/kernels/benchmark_paged_attention.py
View file @
a22dea54
...
@@ -170,7 +170,7 @@ if __name__ == '__main__':
...
@@ -170,7 +170,7 @@ if __name__ == '__main__':
parser
.
add_argument
(
"--num-kv-heads"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--num-kv-heads"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--head-size"
,
parser
.
add_argument
(
"--head-size"
,
type
=
int
,
type
=
int
,
choices
=
[
64
,
80
,
96
,
112
,
128
,
256
],
choices
=
[
64
,
80
,
96
,
112
,
128
,
192
,
256
],
default
=
128
)
default
=
128
)
parser
.
add_argument
(
"--block-size"
,
type
=
int
,
choices
=
[
16
,
32
],
default
=
16
)
parser
.
add_argument
(
"--block-size"
,
type
=
int
,
choices
=
[
16
,
32
],
default
=
16
)
parser
.
add_argument
(
"--use-alibi"
,
action
=
"store_true"
)
parser
.
add_argument
(
"--use-alibi"
,
action
=
"store_true"
)
...
...
benchmarks/kernels/benchmark_rope.py
View file @
a22dea54
...
@@ -93,7 +93,7 @@ if __name__ == '__main__':
...
@@ -93,7 +93,7 @@ if __name__ == '__main__':
parser
.
add_argument
(
"--num-heads"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--num-heads"
,
type
=
int
,
default
=
8
)
parser
.
add_argument
(
"--head-size"
,
parser
.
add_argument
(
"--head-size"
,
type
=
int
,
type
=
int
,
choices
=
[
64
,
80
,
96
,
112
,
128
,
256
],
choices
=
[
64
,
80
,
96
,
112
,
128
,
192
,
256
],
default
=
128
)
default
=
128
)
parser
.
add_argument
(
"--rotary-dim"
,
type
=
int
,
choices
=
[
16
,
32
],
default
=
32
)
parser
.
add_argument
(
"--rotary-dim"
,
type
=
int
,
choices
=
[
16
,
32
],
default
=
32
)
parser
.
add_argument
(
"--dtype"
,
parser
.
add_argument
(
"--dtype"
,
...
...
csrc/attention/attention_kernels.cu
View file @
a22dea54
...
@@ -754,6 +754,9 @@ void paged_attention_v1_launcher(
...
@@ -754,6 +754,9 @@ void paged_attention_v1_launcher(
case
128
:
case
128
:
LAUNCH_PAGED_ATTENTION_V1
(
128
);
LAUNCH_PAGED_ATTENTION_V1
(
128
);
break
;
break
;
case
192
:
LAUNCH_PAGED_ATTENTION_V1
(
192
);
break
;
case
256
:
case
256
:
LAUNCH_PAGED_ATTENTION_V1
(
256
);
LAUNCH_PAGED_ATTENTION_V1
(
256
);
break
;
break
;
...
@@ -911,6 +914,9 @@ void paged_attention_v2_launcher(
...
@@ -911,6 +914,9 @@ void paged_attention_v2_launcher(
case
128
:
case
128
:
LAUNCH_PAGED_ATTENTION_V2
(
128
);
LAUNCH_PAGED_ATTENTION_V2
(
128
);
break
;
break
;
case
192
:
LAUNCH_PAGED_ATTENTION_V2
(
192
);
break
;
case
256
:
case
256
:
LAUNCH_PAGED_ATTENTION_V2
(
256
);
LAUNCH_PAGED_ATTENTION_V2
(
256
);
break
;
break
;
...
...
csrc/cpu/attention.cpp
View file @
a22dea54
...
@@ -390,6 +390,9 @@ void paged_attention_v1_impl_launcher(
...
@@ -390,6 +390,9 @@ void paged_attention_v1_impl_launcher(
case
128
:
case
128
:
LAUNCH_V1_ATTENTION_KERNEL
(
T
,
128
,
BLOCK_SIZE
);
LAUNCH_V1_ATTENTION_KERNEL
(
T
,
128
,
BLOCK_SIZE
);
break
;
break
;
case
192
:
LAUNCH_V1_ATTENTION_KERNEL
(
T
,
192
,
BLOCK_SIZE
);
break
;
case
256
:
case
256
:
LAUNCH_V1_ATTENTION_KERNEL
(
T
,
256
,
BLOCK_SIZE
);
LAUNCH_V1_ATTENTION_KERNEL
(
T
,
256
,
BLOCK_SIZE
);
break
;
break
;
...
@@ -703,6 +706,9 @@ void paged_attention_v2_impl_launcher(
...
@@ -703,6 +706,9 @@ void paged_attention_v2_impl_launcher(
case
128
:
case
128
:
LAUNCH_V2_ATTENTION_KERNEL
(
T
,
128
,
BLOCK_SIZE
);
LAUNCH_V2_ATTENTION_KERNEL
(
T
,
128
,
BLOCK_SIZE
);
break
;
break
;
case
192
:
LAUNCH_V2_ATTENTION_KERNEL
(
T
,
192
,
BLOCK_SIZE
);
break
;
case
256
:
case
256
:
LAUNCH_V2_ATTENTION_KERNEL
(
T
,
256
,
BLOCK_SIZE
);
LAUNCH_V2_ATTENTION_KERNEL
(
T
,
256
,
BLOCK_SIZE
);
break
;
break
;
...
...
tests/kernels/test_attention.py
View file @
a22dea54
...
@@ -28,7 +28,7 @@ NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
...
@@ -28,7 +28,7 @@ NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing
# FlashAttention forward only supports head dimension at most 128
# FlashAttention forward only supports head dimension at most 128
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
HEAD_SIZES
=
[
64
,
80
,
96
,
112
,
128
,
256
HEAD_SIZES
=
[
64
,
80
,
96
,
112
,
128
,
192
,
256
]
if
not
is_hip
()
else
[
64
,
80
,
96
,
112
,
128
]
]
if
not
is_hip
()
else
[
64
,
80
,
96
,
112
,
128
]
BLOCK_SIZES
=
[
16
,
32
]
BLOCK_SIZES
=
[
16
,
32
]
...
...
tests/kernels/test_cache.py
View file @
a22dea54
...
@@ -11,7 +11,7 @@ DTYPES = [torch.half, torch.bfloat16, torch.float]
...
@@ -11,7 +11,7 @@ DTYPES = [torch.half, torch.bfloat16, torch.float]
NUM_TOKENS
=
[
42
]
# Arbitrary values for testing
NUM_TOKENS
=
[
42
]
# Arbitrary values for testing
NUM_LAYERS
=
[
1
]
# Arbitrary values for testing
NUM_LAYERS
=
[
1
]
# Arbitrary values for testing
NUM_HEADS
=
[
8
]
# Arbitrary values for testing
NUM_HEADS
=
[
8
]
# Arbitrary values for testing
HEAD_SIZES
=
[
64
,
80
,
96
,
112
,
128
,
256
]
HEAD_SIZES
=
[
64
,
80
,
96
,
112
,
128
,
192
,
256
]
BLOCK_SIZES
=
[
8
,
16
,
32
]
BLOCK_SIZES
=
[
8
,
16
,
32
]
# Arbitrary values for testing
# Arbitrary values for testing
...
...
tests/kernels/test_pos_encoding.py
View file @
a22dea54
...
@@ -10,7 +10,7 @@ from .allclose_default import get_default_atol, get_default_rtol
...
@@ -10,7 +10,7 @@ from .allclose_default import get_default_atol, get_default_rtol
IS_NEOX_STYLE
=
[
True
,
False
]
IS_NEOX_STYLE
=
[
True
,
False
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
HEAD_SIZES
=
[
64
,
80
,
96
,
112
,
128
,
256
]
HEAD_SIZES
=
[
64
,
80
,
96
,
112
,
128
,
192
,
256
]
ROTARY_DIMS
=
[
None
,
32
]
# None means rotary dim == head size
ROTARY_DIMS
=
[
None
,
32
]
# None means rotary dim == head size
NUM_HEADS
=
[
7
,
17
]
# Arbitrary values for testing
NUM_HEADS
=
[
7
,
17
]
# Arbitrary values for testing
BATCH_SIZES
=
[
1
,
5
]
# Arbitrary values for testing
BATCH_SIZES
=
[
1
,
5
]
# Arbitrary values for testing
...
...
vllm/attention/ops/paged_attn.py
View file @
a22dea54
...
@@ -31,7 +31,7 @@ class PagedAttention:
...
@@ -31,7 +31,7 @@ class PagedAttention:
@
staticmethod
@
staticmethod
def
get_supported_head_sizes
()
->
List
[
int
]:
def
get_supported_head_sizes
()
->
List
[
int
]:
return
[
64
,
80
,
96
,
112
,
128
,
256
]
return
[
64
,
80
,
96
,
112
,
128
,
192
,
256
]
@
staticmethod
@
staticmethod
def
get_kv_cache_shape
(
def
get_kv_cache_shape
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment