OpenDAS / text-generation-inference · Commit 1028996f
"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "3ad49eeeddc5b3a82540bd37ac133650d02ad93d"
Unverified commit, authored Sep 28, 2024 by Daniël de Kok; committed by GitHub on Sep 28, 2024
flashinfer: pass window size and dtype (#2574)
Parent: 5b6b74e2
Showing 2 changed files with 17 additions and 6 deletions (+17 −6):

- server/text_generation_server/layers/attention/flashinfer.py (+13 −6)
- server/text_generation_server/models/flash_causal_lm.py (+4 −0)
server/text_generation_server/layers/attention/flashinfer.py (+13 −6)
```diff
@@ -50,7 +50,8 @@ def use_prefill_with_paged_kv_state(
     num_kv_heads: int,
     head_size: int,
     page_size: int,
-    query_dtype: str = "float16",
+    dtype: torch.dtype,
+    window_left: int,
 ):
     """
     Context manager to set the active flashinfer prefill state to the given
```
```diff
@@ -90,8 +91,9 @@ def use_prefill_with_paged_kv_state(
             num_qo_heads=num_heads,
             num_kv_heads=num_kv_heads,
             head_dim=head_size,
-            q_data_type=query_dtype,
+            q_data_type=dtype,
             page_size=page_size,
+            window_left=window_left,
         )
         yield
     finally:
```
```diff
@@ -119,7 +121,8 @@ def use_prefill_state(
     num_heads: int,
     num_kv_heads: int,
     head_size: int,
-    query_dtype: str = "float16",
+    dtype: torch.dtype,
+    window_left: int,
 ):
     """
     Context manager to set the active flashinfer prefill state to the given
```
```diff
@@ -135,7 +138,8 @@ def use_prefill_state(
             num_qo_heads=num_heads,
             num_kv_heads=num_kv_heads,
             head_dim=head_size,
-            q_data_type=query_dtype,
+            q_data_type=dtype,
+            window_left=window_left,
         )
         yield
     finally:
```
```diff
@@ -200,7 +204,8 @@ def use_decode_state(
     num_kv_heads: int,
     head_size: int,
     page_size: int,
-    query_dtype: str = "float16",
+    dtype: torch.dtype,
+    window_left: int,
 ):
     """
     Context manager to set the active flashinfer decoding state to the given
```
```diff
@@ -235,7 +240,9 @@ def use_decode_state(
             num_kv_heads=num_kv_heads,
             head_dim=head_size,
             page_size=page_size,
-            q_data_type=query_dtype,
+            data_type=dtype,
+            q_data_type=dtype,
+            window_left=window_left,
         )
         yield
     finally:
```
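The pattern is the same across all three context managers: the stringly-typed `query_dtype: str = "float16"` default is dropped in favour of an explicit `torch.dtype` plus a required `window_left`, and the decode path additionally passes `data_type=dtype` so the KV cache and the queries are planned with the same dtype. A toy stand-in sketch of the new call shape, using only the parameters visible in the hunks above (the real functions take further arguments that are elided here, and the concrete values below are invented for illustration):

```python
import torch
from contextlib import contextmanager

# Toy stand-in mirroring only the parameters visible in the diff above;
# the real use_decode_state takes additional arguments (the state object,
# length/block-table tensors, ...) that are elided in this sketch.
@contextmanager
def use_decode_state(
    *,
    num_kv_heads: int,
    head_size: int,
    page_size: int,
    dtype: torch.dtype,  # replaces query_dtype: str = "float16"
    window_left: int,    # new: sliding-window size forwarded to flashinfer
):
    print(
        f"plan decode: kv_heads={num_kv_heads} head={head_size} "
        f"page={page_size} dtype={dtype} window_left={window_left}"
    )
    yield

# Shaped like the FlashCausalLM call site in the second changed file
# (the numbers are placeholders, not values from the repository):
with use_decode_state(
    num_kv_heads=8,        # self.num_kv_heads
    head_size=128,         # self.head_size
    page_size=16,          # BLOCK_SIZE
    dtype=torch.float16,   # self.dtype
    window_left=4096,      # self.sliding_window
):
    pass
```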
server/text_generation_server/models/flash_causal_lm.py (+4 −0)
```diff
@@ -1960,6 +1960,8 @@ class FlashCausalLM(Model):
                 num_kv_heads=self.num_kv_heads,
                 head_size=self.head_size,
                 page_size=BLOCK_SIZE,
+                dtype=self.dtype,
+                window_left=self.sliding_window,
             )
         else:
             assert input_lengths_tensor is not None
```
```diff
@@ -1971,6 +1973,8 @@ class FlashCausalLM(Model):
                 num_kv_heads=self.num_kv_heads,
                 head_size=self.head_size,
                 page_size=BLOCK_SIZE,
+                dtype=self.dtype,
+                window_left=self.sliding_window,
             )
```
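For context on what `window_left` means downstream: flashinfer interprets it as the number of key positions to the left of each query that remain attendable, with -1 conventionally meaning an unrestricted (full causal) window. Forwarding `self.sliding_window` here lets models with sliding-window attention have the window applied inside the flashinfer kernels. A self-contained PyTorch illustration of that masking rule (the helper is mine, not code from either repository, and the exact boundary convention is my reading of the flashinfer docs):

```python
import torch

def sliding_window_mask(seq_len: int, window_left: int) -> torch.Tensor:
    """Causal attention mask where query position i may attend to key
    positions j with j <= i and i - j <= window_left.
    window_left = -1 means no restriction (full causal attention)."""
    i = torch.arange(seq_len).unsqueeze(1)  # query positions (column)
    j = torch.arange(seq_len).unsqueeze(0)  # key positions (row)
    causal = j <= i
    if window_left < 0:
        return causal
    return causal & (i - j <= window_left)

print(sliding_window_mask(5, window_left=2).int())
# tensor([[1, 0, 0, 0, 0],
#         [1, 1, 0, 0, 0],
#         [1, 1, 1, 0, 0],
#         [0, 1, 1, 1, 0],
#         [0, 0, 1, 1, 1]])
```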