OpenDAS / text-generation-inference

Commit ad66f6ef (unverified), authored May 09, 2023 by OlivierDehaene, committed by GitHub on May 09, 2023

feat(server): optim flash causal lm decode_token (#285)
Parent: bc5c0723

Showing 7 changed files with 258 additions and 140 deletions (+258, -140).
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py (+3, -4)
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py (+3, -4)
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py (+3, -4)
server/text_generation_server/models/flash_causal_lm.py (+244, -123)
server/text_generation_server/models/flash_llama.py (+2, -2)
server/text_generation_server/models/flash_neox.py (+1, -1)
server/text_generation_server/models/flash_santacoder.py (+2, -2)
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -554,6 +554,7 @@ class FlashLlamaModel(torch.nn.Module):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -575,15 +576,11 @@ class FlashLlamaModel(torch.nn.Module):
                 )
             )
             layer_past_present_indices = None
-            cu_seqlens_q = None
             slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
             layer_past_present_indices = cu_seqlens[1:] - 1
-            cu_seqlens_q = torch.arange(
-                cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
-            )
             slice_past_index = None

         # Get rotary cos and sin for this forward
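As a side note, here is a minimal sketch (not part of this commit) of what the decode-path indices above evaluate to; the three sequence lengths are made up purely for illustration:

import torch

# Hypothetical batch: three sequences of lengths 5, 2 and 7 flattened together,
# so the cumulative sequence lengths (with a leading 0) are:
cu_seqlens = torch.tensor([0, 5, 7, 14], dtype=torch.int32)

# Index of each sequence's last token in the flattened batch, matching
# `layer_past_present_indices = cu_seqlens[1:] - 1` above.
layer_past_present_indices = cu_seqlens[1:] - 1
print(layer_past_present_indices)  # tensor([ 4,  6, 13], dtype=torch.int32)

# During decode each sequence contributes exactly one query token, so the query
# cumulative lengths are simply 0..batch_size. This is the torch.arange call
# removed here; after this commit the value arrives via the new cu_seqlens_q argument.
cu_seqlens_q = torch.arange(cu_seqlens.shape[0], dtype=torch.int32)
print(cu_seqlens_q)  # tensor([0, 1, 2, 3], dtype=torch.int32)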
@@ -650,6 +647,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -658,6 +656,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             input_ids,
             position_ids,
             cu_seqlens,
+            cu_seqlens_q,
             max_s,
             past_key_values,
             pre_allocate_past_size,
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -617,6 +617,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values=None,
         pre_allocate_past_size: Optional[int] = None,
@@ -638,15 +639,11 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 )
             )
             layer_past_present_indices = None
-            cu_seqlens_q = None
             slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
             layer_past_present_indices = cu_seqlens[1:] - 1
-            cu_seqlens_q = torch.arange(
-                cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
-            )
             slice_past_index = None

         # Get rotary cos and sin for this forward
@@ -726,6 +723,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -734,6 +732,7 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
             input_ids,
             position_ids,
             cu_seqlens,
+            cu_seqlens_q,
             max_s,
             past_key_values,
             pre_allocate_past_size,
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py
@@ -484,6 +484,7 @@ class FlashSantacoderModel(nn.Module):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -507,15 +508,11 @@ class FlashSantacoderModel(nn.Module):
                 )
             )
             layer_past_present_indices = None
-            cu_seqlens_q = None
             slice_past_index = len(hidden_states)
         # Decode
         else:
             # Create indices from cumulative sequence lengths
             layer_past_present_indices = cu_seqlens[1:] - 1
-            cu_seqlens_q = torch.arange(
-                cu_seqlens.shape[0], dtype=torch.int32, device=hidden_states.device
-            )
             slice_past_index = None

         residual = None
@@ -566,6 +563,7 @@ class FlashSantacoderForCausalLM(nn.Module):
         input_ids,
         position_ids,
         cu_seqlens,
+        cu_seqlens_q,
         max_s,
         past_key_values: Optional[torch.Tensor] = None,
         pre_allocate_past_size: Optional[int] = None,
@@ -574,6 +572,7 @@ class FlashSantacoderForCausalLM(nn.Module):
             input_ids,
             position_ids,
             cu_seqlens,
+            cu_seqlens_q,
             max_s,
             past_key_values,
             pre_allocate_past_size,
server/text_generation_server/models/flash_causal_lm.py
This diff is collapsed.
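The bulk of this commit (+244, -123) is in this collapsed diff and is not reproduced here. Purely as an illustration of the calling convention implied by the signature changes above, and not of the contents of this file, a hypothetical decode step could build the query cumulative lengths outside the model forward and pass them in; the helper name and its arguments below are stand-ins:

import torch

def decode_step(model, input_ids, position_ids, cu_seqlens, max_s, past_key_values):
    # Hypothetical helper, not from this commit: compute the decode-step query
    # cumulative lengths outside the model instead of inside every forward call.
    cu_seqlens_q = torch.arange(
        cu_seqlens.shape[0], dtype=torch.int32, device=input_ids.device
    )
    # The modeling forwards above now take cu_seqlens_q as an extra positional argument.
    return model(
        input_ids,
        position_ids,
        cu_seqlens,
        cu_seqlens_q,
        max_s,
        past_key_values,
    )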
server/text_generation_server/models/flash_llama.py
@@ -32,7 +32,7 @@ class FlashLlama(FlashCausalLM):
         self.past_pad = None
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
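The same one-line dtype change is applied to FlashLlamaSharded below and to flash_neox.py and flash_santacoder.py. As a small illustration (not from the commit) of what the replaced expression did, with variable names invented here:

import torch

if torch.cuda.is_available():
    # Old behaviour: prefer bfloat16 on GPUs that report bfloat16 support,
    # otherwise fall back to float16.
    old_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    # New behaviour in this commit: always load the Flash models in float16.
    new_dtype = torch.float16
    print(old_dtype, new_dtype)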
@@ -161,7 +161,7 @@ class FlashLlamaSharded(FlashLlama):
         self.master = self.rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{self.rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashLlama is only available on GPU")
server/text_generation_server/models/flash_neox.py
@@ -38,7 +38,7 @@ class FlashNeoXSharded(FlashNeoX):
         self.master = self.rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{self.rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashNeoX is only available on GPU")
server/text_generation_server/models/flash_santacoder.py
@@ -31,7 +31,7 @@ class FlashSantacoder(FlashCausalLM):
         self.past_pad = None
         if torch.cuda.is_available():
             device = torch.device("cuda")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashSantacoder is only available on GPU")
@@ -178,7 +178,7 @@ class FlashSantacoderSharded(FlashSantacoder):
         self.master = self.rank == 0
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{self.rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+            dtype = torch.float16
         else:
             raise NotImplementedError("FlashSantacoderSharded is only available on GPU")