Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
text-generation-inference
Commits
dccd5c2b
Commit
dccd5c2b
authored
Nov 09, 2022
by
OlivierDehaene
Browse files
feat(server): Clarify CausalLMBatch concatenate method
parent
fa43fb71
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
88 additions
and
86 deletions
+88
-86
Cargo.lock
Cargo.lock
+26
-26
server/text_generation/models/causal_lm.py
server/text_generation/models/causal_lm.py
+62
-60
No files found.
Cargo.lock
View file @
dccd5c2b
...
...
@@ -213,9 +213,9 @@ dependencies = [
[[package]]
name = "cc"
version = "1.0.7
3
"
version = "1.0.7
6
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
2fff2a6927b3bb87f9595d67196a70493f627687a71d87a0d692242c33f58c11
"
checksum = "
76a284da2e6fe2092f2353e51713435363112dfd60030e22add80be333fb928f
"
[[package]]
name = "cfg-if"
...
...
@@ -240,9 +240,9 @@ dependencies = [
[[package]]
name = "clap"
version = "4.0.
18
"
version = "4.0.
22
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
335867764ed2de42325fafe6d18b8af74ba97ee0c590fa016f157535b42ab04b
"
checksum = "
91b9970d7505127a162fdaa9b96428d28a479ba78c9ec7550a63a5d9863db682
"
dependencies = [
"atty",
"bitflags",
...
...
@@ -255,9 +255,9 @@ dependencies = [
[[package]]
name = "clap_derive"
version = "4.0.1
8
"
version = "4.0.
2
1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
16a1b0f6422af32d5da0c58e2703320f379216ee70198241c84173a8c5ac28f3
"
checksum = "
0177313f9f02afc995627906bbd8967e2be069f5261954222dac78290c2b9014
"
dependencies = [
"heck 0.4.0",
"proc-macro-error",
...
...
@@ -790,9 +790,9 @@ checksum = "c4a1e36c821dbe04574f602848a19f742f4fb3c98d40449f11bcad18d6b17421"
[[package]]
name = "hyper"
version = "0.14.2
0
"
version = "0.14.2
3
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0
2c929dc5c39e335a03c405292728118860721b10190d98c2a0f0efd5baafba
c"
checksum = "0
34711faac9d2166cb1baf1a2fb0b60b1f277f8492fd72176c17f3515e1abd3
c"
dependencies = [
"bytes",
"futures-channel",
...
...
@@ -898,9 +898,9 @@ dependencies = [
[[package]]
name = "ipnet"
version = "2.5.
0
"
version = "2.5.
1
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
879d54834c8c76457ef4293a689b2a8c59b076067ad77b15efafbb05f92a592b
"
checksum = "
f88c5561171189e69df9d98bcf18fd5f9558300f7ea7b801eb8a0fd748bd8745
"
[[package]]
name = "itertools"
...
...
@@ -1053,9 +1053,9 @@ checksum = "e5ce46fe64a9d73be07dcbe690a38ce1b293be448fd8ce1e6c1b8062c9f72c6a"
[[package]]
name = "native-tls"
version = "0.2.1
0
"
version = "0.2.1
1
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
fd7e2f3618557f980e0b17e8856252eee3c97fa12c54dff0ca290fb6266ca4a9
"
checksum = "
07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e
"
dependencies = [
"lazy_static",
"libc",
...
...
@@ -1103,9 +1103,9 @@ dependencies = [
[[package]]
name = "num_cpus"
version = "1.1
3.1
"
version = "1.1
4.0
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
19e64526ebdee182341572e50e9ad03965aa510cd94427a4549448f285e957a1
"
checksum = "
f6058e64324c71e02bc2b150e4f3bc8286db6c83092132ffa3f6b1eab0f9def5
"
dependencies = [
"hermit-abi",
"libc",
...
...
@@ -1125,9 +1125,9 @@ checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3"
[[package]]
name = "once_cell"
version = "1.1
5
.0"
version = "1.1
6
.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
e82dad04139b71a90c080c8463fe0dc7902db5192d939bd0950f074d014339e1
"
checksum = "
86f0b0d4bf799edbc74508c1e8bf170ff5f41238e5f8225603ca7caaae2b7860
"
[[package]]
name = "onig"
...
...
@@ -1293,9 +1293,9 @@ checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160"
[[package]]
name = "ppv-lite86"
version = "0.2.1
6
"
version = "0.2.1
7
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872
"
checksum = "
5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de
"
[[package]]
name = "proc-macro-error"
...
...
@@ -1479,9 +1479,9 @@ dependencies = [
[[package]]
name = "regex"
version = "1.
6
.0"
version = "1.
7
.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
4c4eb3267174b8c6c2f654116623910a0fef09c4753f8dd83db29c48a0df988b
"
checksum = "
e076559ef8e241f2ae3479e36f97bd5741c0330689e217ad51ce2c76808b868a
"
dependencies = [
"aho-corasick",
"memchr",
...
...
@@ -1490,9 +1490,9 @@ dependencies = [
[[package]]
name = "regex-syntax"
version = "0.6.2
7
"
version = "0.6.2
8
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
a3f87b73ce11b1619a3c6332f45341e0047173771e8b8b73f87bfeefb7b56244
"
checksum = "
456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848
"
[[package]]
name = "remove_dir_all"
...
...
@@ -1802,7 +1802,7 @@ dependencies = [
name = "text-generation-launcher"
version = "0.1.0"
dependencies = [
"clap 4.0.
18
",
"clap 4.0.
22
",
"ctrlc",
"subprocess",
"tracing",
...
...
@@ -1814,7 +1814,7 @@ name = "text-generation-router"
version = "0.1.0"
dependencies = [
"axum",
"clap 4.0.
18
",
"clap 4.0.
22
",
"futures",
"parking_lot",
"serde",
...
...
@@ -1893,9 +1893,9 @@ checksum = "cda74da7e1a664f795bb1f8a87ec406fb89a02522cf6e50620d016add6dbbf5c"
[[package]]
name = "tokenizers"
version = "0.13.
1
"
version = "0.13.
2
"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "
3d7b08ede6742d7a59d58c71da8a6fa21bedc433dca2e855e439274d08df1170
"
checksum = "
f4ff2dd291eac98dcea13e8cf7a0b28c373a90dc9210ccdab0fa9e69ee0cac69
"
dependencies = [
"aho-corasick",
"cached-path",
...
...
server/text_generation/models/causal_lm.py
View file @
dccd5c2b
...
...
@@ -148,71 +148,73 @@ class CausalLMBatch:
]
=
batch
.
attention_mask
[:,
-
batch
.
max_sequence_length
:]
for
j
,
past
in
enumerate
(
batch
.
past_key_values
):
past_keys
,
past_values
=
past
# Shenanigans to get dimensions because BLOOM outputs a past with a different shape
# BLOOM: [batch_size * num_heads, ...] vs [batch_size, num_heads, ...]
head_dim
,
padded_sequence_length
=
past
[
0
].
shape
[
-
2
:]
num_heads
=
(
past
[
0
]
.
view
(
batch
.
size
,
-
1
,
head_dim
,
padded_sequence_length
)
.
shape
[
1
]
)
# BLOOM Keys: [batch_size * num_heads, head_dim, seq_length]
# BLOOM Values: [batch_size * num_heads, seq_length, head_dim]
past_keys
=
past_keys
.
view
(
batch
.
size
,
-
1
,
*
past_keys
.
shape
[
-
2
:])
past_values
=
past_values
.
view
(
batch
.
size
,
-
1
,
*
past_values
.
shape
[
-
2
:])
# This will run only once per layer
if
j
==
len
(
past_key_values
):
past_key_values
.
append
([])
# Decoder past
for
k
,
t
in
enumerate
(
past
):
# Needed because BLOOM past shapes are not the same for keys and values
# Keys: [batch_size * num_heads, head_dim, seq_length]
# Values: [batch_size * num_heads, seq_length, head_dim]
head_dim_last
=
False
if
t
.
shape
[
-
2
]
==
head_dim
:
t
=
t
.
view
(
batch
.
size
,
num_heads
,
head_dim
,
padded_sequence_length
)
padded_t_shape
=
(
_
,
num_heads
,
head_dim
,
padded_sequence_length
=
past_keys
.
shape
padded_past_keys_shape
=
(
total_batch_size
,
num_heads
,
head_dim
,
max_sequence_length
-
1
,
)
elif
t
.
shape
[
-
1
]
==
head_dim
:
head_dim_last
=
True
t
=
t
.
view
(
batch
.
size
,
num_heads
,
padded_sequence_length
,
head_dim
)
padded_t_shape
=
(
# head_dim is last for BLOOM
if
past_values
.
shape
[
-
1
]
==
head_dim
:
past_values_head_dim_last
=
True
padded_past_values_shape
=
(
total_batch_size
,
num_heads
,
max_sequence_length
-
1
,
head_dim
,
)
elif
past_values
.
shape
[
-
2
]
==
head_dim
:
past_values_head_dim_last
=
False
padded_past_values_shape
=
padded_past_keys_shape
else
:
raise
ValueError
(
f
"shape
{
t
.
shape
}
is not valid"
)
raise
ValueError
(
f
"past_values shape
{
past_values
.
shape
}
is not valid"
)
# Initialize tensors
# This will run only once per layer and per past tensor
if
k
==
len
(
past_key_values
[
j
]):
past_key_values
[
j
].
append
(
torch
.
zeros
(
padded_t_shape
,
dtype
=
t
.
dtype
,
device
=
t
.
device
)
# This will run only once per layer
if
j
==
len
(
past_key_values
):
padded_past_keys
=
torch
.
zeros
(
padded_past_keys_shape
,
dtype
=
past_keys
.
dtype
,
device
=
past_keys
.
device
,
)
padded_past_values
=
torch
.
zeros
(
padded_past_values_shape
,
dtype
=
past_values
.
dtype
,
device
=
past_values
.
device
,
)
past_key_values
.
append
((
padded_past_keys
,
padded_past_values
))
# We slice the past keys and values to remove the padding from previous batches
if
not
head_dim_last
:
past_key_values
[
j
][
k
][
past_key_values
[
j
][
0
][
start_index
:
end_index
,
:,
:,
-
(
batch
.
max_sequence_length
-
1
)
:
]
=
past_keys
[:,
:,
:,
-
(
batch
.
max_sequence_length
-
1
)
:]
if
past_values_head_dim_last
:
past_key_values
[
j
][
1
][
start_index
:
end_index
,
:,
:,
-
(
batch
.
max_sequence_length
-
1
)
:,
]
=
t
[:,
:,
:,
-
(
batch
.
max_sequence_length
-
1
)
:]
:,
]
=
past_values
[:,
:,
-
(
batch
.
max_sequence_length
-
1
)
:,
:]
else
:
past_key_values
[
j
][
k
][
past_key_values
[
j
][
1
][
start_index
:
end_index
,
:,
-
(
batch
.
max_sequence_length
-
1
)
:,
:,
]
=
t
[:,
:,
-
(
batch
.
max_sequence_length
-
1
)
:,
:]
-
(
batch
.
max_sequence_length
-
1
)
:,
]
=
past_values
[:,
:,
:,
-
(
batch
.
max_sequence_length
-
1
)
:]
start_index
+=
batch
.
size
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment