Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
dd8691b7
Unverified
Commit
dd8691b7
authored
Sep 24, 2024
by
Nicolas Patry
Committed by
GitHub
Sep 24, 2024
Browse files
More tensor cores. (#2558)
* More tensor cores. * Fixing the logic. * Gemma is modified by this.
parent
c032280b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
46 additions
and
42 deletions
+46
-42
integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json
...dels/__snapshots__/test_flash_gemma/test_flash_gemma.json
+16
-16
integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json
...st_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json
+24
-24
server/text_generation_server/layers/attention/flashinfer.py
server/text_generation_server/layers/attention/flashinfer.py
+6
-2
No files found.
integration-tests/models/__snapshots__/test_flash_gemma/test_flash_gemma.json
View file @
dd8691b7
...
@@ -24,13 +24,13 @@
...
@@ -24,13 +24,13 @@
"tokens"
:
[
"tokens"
:
[
{
{
"id"
:
1736
,
"id"
:
1736
,
"logprob"
:
-2.
0312
5
,
"logprob"
:
-2.
10937
5
,
"special"
:
false
,
"special"
:
false
,
"text"
:
" form"
"text"
:
" form"
},
},
{
{
"id"
:
109
,
"id"
:
109
,
"logprob"
:
-1.
867187
5
,
"logprob"
:
-1.
9062
5
,
"special"
:
false
,
"special"
:
false
,
"text"
:
"
\n\n
"
"text"
:
"
\n\n
"
},
},
...
@@ -42,48 +42,48 @@
...
@@ -42,48 +42,48 @@
},
},
{
{
"id"
:
2121
,
"id"
:
2121
,
"logprob"
:
-1.
812
5
,
"logprob"
:
-1.
79687
5
,
"special"
:
false
,
"special"
:
false
,
"text"
:
" test"
"text"
:
" test"
},
},
{
{
"id"
:
3853
,
"id"
:
3853
,
"logprob"
:
-0.24
121094
,
"logprob"
:
-0.24
511719
,
"special"
:
false
,
"special"
:
false
,
"text"
:
" request"
"text"
:
" request"
},
},
{
{
"id"
:
1736
,
"id"
:
1736
,
"logprob"
:
-0.
100097656
,
"logprob"
:
-0.
09326172
,
"special"
:
false
,
"special"
:
false
,
"text"
:
" form"
"text"
:
" form"
},
},
{
{
"id"
:
603
,
"id"
:
603
,
"logprob"
:
-0.9
4
53125
,
"logprob"
:
-0.95
70
3125
,
"special"
:
false
,
"special"
:
false
,
"text"
:
" is"
"text"
:
" is"
},
},
{
{
"id"
:
476
,
"id"
:
1671
,
"logprob"
:
-1.
70312
5
,
"logprob"
:
-1.
585937
5
,
"special"
:
false
,
"special"
:
false
,
"text"
:
"
a
"
"text"
:
"
used
"
},
},
{
{
"id"
:
4551
,
"id"
:
577
,
"logprob"
:
-
2.453125
,
"logprob"
:
-
0.39257812
,
"special"
:
false
,
"special"
:
false
,
"text"
:
"
documen
t"
"text"
:
" t
o
"
},
},
{
{
"id"
:
674
,
"id"
:
3853
,
"logprob"
:
-
0.79687
5
,
"logprob"
:
-
1.2
5
,
"special"
:
false
,
"special"
:
false
,
"text"
:
"
tha
t"
"text"
:
"
reques
t"
}
}
],
],
"top_tokens"
:
null
"top_tokens"
:
null
},
},
"generated_text"
:
" form
\n\n
The test request form is
a document tha
t"
"generated_text"
:
" form
\n\n
The test request form is
used to reques
t"
}
}
integration-tests/models/__snapshots__/test_flash_gemma_gptq/test_flash_gemma_gptq_all_params.json
View file @
dd8691b7
...
@@ -11,12 +11,12 @@
...
@@ -11,12 +11,12 @@
},
},
{
{
"id"
:
2015
,
"id"
:
2015
,
"logprob"
:
-9.64
062
5
,
"logprob"
:
-9.64
8437
5
,
"text"
:
"Test"
"text"
:
"Test"
},
},
{
{
"id"
:
3853
,
"id"
:
3853
,
"logprob"
:
-10.375
,
"logprob"
:
-10.3
6718
75
,
"text"
:
" request"
"text"
:
" request"
}
}
],
],
...
@@ -24,19 +24,19 @@
...
@@ -24,19 +24,19 @@
"tokens"
:
[
"tokens"
:
[
{
{
"id"
:
604
,
"id"
:
604
,
"logprob"
:
-0.282
4707
,
"logprob"
:
-0.282
71484
,
"special"
:
false
,
"special"
:
false
,
"text"
:
" for"
"text"
:
" for"
},
},
{
{
"id"
:
573
,
"id"
:
573
,
"logprob"
:
-0.1
903076
2
,
"logprob"
:
-0.1
849365
2
,
"special"
:
false
,
"special"
:
false
,
"text"
:
" the"
"text"
:
" the"
},
},
{
{
"id"
:
16819
,
"id"
:
16819
,
"logprob"
:
-1.48
9257
8
,
"logprob"
:
-1.48
0468
8
,
"special"
:
false
,
"special"
:
false
,
"text"
:
" detection"
"text"
:
" detection"
},
},
...
@@ -47,43 +47,43 @@
...
@@ -47,43 +47,43 @@
"text"
:
" of"
"text"
:
" of"
},
},
{
{
"id"
:
573
,
"id"
:
671
,
"logprob"
:
-2.
0195312
,
"logprob"
:
-2.
1738281
,
"special"
:
false
,
"special"
:
false
,
"text"
:
"
the
"
"text"
:
"
an
"
},
},
{
{
"id"
:
856
6
,
"id"
:
2464
6
,
"logprob"
:
0.0
,
"logprob"
:
-3.0449219
,
"special"
:
false
,
"special"
:
false
,
"text"
:
"
presence
"
"text"
:
"
RNA
"
},
},
{
{
"id"
:
6
8
9
,
"id"
:
123
69
,
"logprob"
:
-0.1
6491699
,
"logprob"
:
-0.1
9299316
,
"special"
:
false
,
"special"
:
false
,
"text"
:
"
or
"
"text"
:
"
virus
"
},
},
{
{
"id"
:
14862
,
"id"
:
575
,
"logprob"
:
0.
0
,
"logprob"
:
-
0.
10632324
,
"special"
:
false
,
"special"
:
false
,
"text"
:
"
absence
"
"text"
:
"
in
"
},
},
{
{
"id"
:
576
,
"id"
:
6022
,
"logprob"
:
-0.9
946289
,
"logprob"
:
-0.9
8095703
,
"special"
:
false
,
"special"
:
false
,
"text"
:
"
of
"
"text"
:
"
patients
"
},
},
{
{
"id"
:
671
,
"id"
:
1064
,
"logprob"
:
-
0.5263672
,
"logprob"
:
-
1.3095703
,
"special"
:
false
,
"special"
:
false
,
"text"
:
"
an
"
"text"
:
"
who
"
}
}
],
],
"top_tokens"
:
null
"top_tokens"
:
null
},
},
"generated_text"
:
"Test request for the detection of
the presence or absence of an
"
"generated_text"
:
"Test request for the detection of
an RNA virus in patients who
"
}
}
server/text_generation_server/layers/attention/flashinfer.py
View file @
dd8691b7
...
@@ -152,11 +152,13 @@ def create_decode_state(
...
@@ -152,11 +152,13 @@ def create_decode_state(
):
):
"""Create a decode state."""
"""Create a decode state."""
workspace_buffer
=
get_workspace
(
device
)
workspace_buffer
=
get_workspace
(
device
)
num_groups
=
num_heads
//
num_kv_heads
return
flashinfer
.
BatchDecodeWithPagedKVCacheWrapper
(
return
flashinfer
.
BatchDecodeWithPagedKVCacheWrapper
(
workspace_buffer
,
workspace_buffer
,
kv_layout
=
"NHD"
,
kv_layout
=
"NHD"
,
use_cuda_graph
=
False
,
use_cuda_graph
=
False
,
use_tensor_cores
=
num_heads
//
num_kv_heads
>
4
,
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
use_tensor_cores
=
num_groups
not
in
[
1
,
2
,
4
,
8
],
)
)
...
@@ -175,6 +177,7 @@ def create_decode_state_cuda_graphs(
...
@@ -175,6 +177,7 @@ def create_decode_state_cuda_graphs(
therefore stored as part of the state.
therefore stored as part of the state.
"""
"""
workspace_buffer
=
get_workspace
(
device
)
workspace_buffer
=
get_workspace
(
device
)
num_groups
=
num_heads
//
num_kv_heads
return
flashinfer
.
BatchDecodeWithPagedKVCacheWrapper
(
return
flashinfer
.
BatchDecodeWithPagedKVCacheWrapper
(
workspace_buffer
,
workspace_buffer
,
kv_layout
=
"NHD"
,
kv_layout
=
"NHD"
,
...
@@ -182,7 +185,8 @@ def create_decode_state_cuda_graphs(
...
@@ -182,7 +185,8 @@ def create_decode_state_cuda_graphs(
paged_kv_indices_buffer
=
block_tables
,
paged_kv_indices_buffer
=
block_tables
,
paged_kv_indptr_buffer
=
block_tables_ptr
,
paged_kv_indptr_buffer
=
block_tables_ptr
,
paged_kv_last_page_len_buffer
=
last_page_len
,
paged_kv_last_page_len_buffer
=
last_page_len
,
use_tensor_cores
=
num_heads
//
num_kv_heads
>
4
,
# Taken from https://github.com/flashinfer-ai/flashinfer/blob/33ef95700981ba70f4cab63b8931e562bc795b21/python/flashinfer/decode.py#L57-L60
use_tensor_cores
=
num_groups
not
in
[
1
,
2
,
4
,
8
],
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment