OpenDAS / text-generation-inference / Commits

Commit cfaa8580 (unverified), authored May 23, 2023 by OlivierDehaene, committed by GitHub on May 23, 2023

feat(server): support fp16 for t5 (#360)

Fixes #349
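In short: on CUDA the sharded T5 model now always loads in float16 (previously it picked bfloat16 when supported, else float32), T5's feed-forward output projections (module names ending in "wo") are kept in float32 and excluded from quantization, and sharded-T5 integration tests with snapshots are added. A minimal sketch of the new device/dtype selection, condensed from the t5.py hunk below (rank hardcoded here for illustration):

import torch

# Sketch of the dtype policy after this commit (see server/text_generation_server/models/t5.py):
if torch.cuda.is_available():
    device = torch.device("cuda:0")   # the real code uses f"cuda:{rank}"
    dtype = torch.float16             # was: bfloat16 if supported, else float32
else:
    device = torch.device("cpu")
    dtype = torch.float32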
parent 94377efa

Showing 6 changed files with 357 additions and 6 deletions (+357 −6):

  integration-tests/models/__snapshots__/test_t5_sharded/test_t5_sharded.json       (+60  −0)
  integration-tests/models/__snapshots__/test_t5_sharded/test_t5_sharded_load.json  (+242 −0)
  integration-tests/models/test_flash_neox.py                                       (+3   −1)
  integration-tests/models/test_t5_sharded.py                                       (+38  −0)
  server/text_generation_server/models/bloom.py                                     (+4   −1)
  server/text_generation_server/models/t5.py                                        (+10  −4)
integration-tests/models/__snapshots__/test_t5_sharded/test_t5_sharded.json (new file, mode 100644)
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "eos_token",
    "generated_tokens": 7,
    "prefill": [
      { "id": 0, "logprob": null, "text": "<pad>" }
    ],
    "seed": null,
    "tokens": [
      { "id": 3,     "logprob": -0.7001953,   "special": false, "text": " " },
      { "id": 18,    "logprob": -1.1943359,   "special": false, "text": "-" },
      { "id": 26937, "logprob": -1.2099609,   "special": false, "text": "196" },
      { "id": 3,     "logprob": -1.2451172,   "special": false, "text": " " },
      { "id": 1956,  "logprob": -0.3322754,   "special": false, "text": "°" },
      { "id": 254,   "logprob": -0.19213867,  "special": false, "text": "C" },
      { "id": 1,     "logprob": -0.030151367, "special": true,  "text": "</s>" }
    ]
  },
  "generated_text": "-196 °C"
}
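This snapshot pins the greedy decode of the boiling-point prompt: seven generated tokens ending in </s>, detokenized to "-196 °C". A hedged sanity check one could run (path assumed relative to the repo root; the leading space on the joined pieces reflects T5's sentencepiece tokens, which the server apparently strips from generated_text):

import json

with open(
    "integration-tests/models/__snapshots__/test_t5_sharded/test_t5_sharded.json"
) as f:
    snap = json.load(f)

pieces = "".join(t["text"] for t in snap["details"]["tokens"] if not t["special"])
print(repr(pieces))            # ' -196 °C' (note the leading space)
print(snap["generated_text"])  # '-196 °C'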
integration-tests/models/__snapshots__/test_t5_sharded/test_t5_sharded_load.json (new file, mode 100644)
[
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "eos_token",
      "generated_tokens": 7,
      "prefill": [
        { "id": 0, "logprob": null, "text": "<pad>" }
      ],
      "seed": null,
      "tokens": [
        { "id": 3,     "logprob": -0.7001953,   "special": false, "text": " " },
        { "id": 18,    "logprob": -1.1943359,   "special": false, "text": "-" },
        { "id": 26937, "logprob": -1.2119141,   "special": false, "text": "196" },
        { "id": 3,     "logprob": -1.2480469,   "special": false, "text": " " },
        { "id": 1956,  "logprob": -0.33203125,  "special": false, "text": "°" },
        { "id": 254,   "logprob": -0.19250488,  "special": false, "text": "C" },
        { "id": 1,     "logprob": -0.030166626, "special": true,  "text": "</s>" }
      ]
    },
    "generated_text": "-196 °C"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "eos_token",
      "generated_tokens": 7,
      "prefill": [
        { "id": 0, "logprob": null, "text": "<pad>" }
      ],
      "seed": null,
      "tokens": [
        { "id": 3,     "logprob": -0.7001953,   "special": false, "text": " " },
        { "id": 18,    "logprob": -1.1943359,   "special": false, "text": "-" },
        { "id": 26937, "logprob": -1.2119141,   "special": false, "text": "196" },
        { "id": 3,     "logprob": -1.2480469,   "special": false, "text": " " },
        { "id": 1956,  "logprob": -0.33203125,  "special": false, "text": "°" },
        { "id": 254,   "logprob": -0.19250488,  "special": false, "text": "C" },
        { "id": 1,     "logprob": -0.030166626, "special": true,  "text": "</s>" }
      ]
    },
    "generated_text": "-196 °C"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "eos_token",
      "generated_tokens": 7,
      "prefill": [
        { "id": 0, "logprob": null, "text": "<pad>" }
      ],
      "seed": null,
      "tokens": [
        { "id": 3,     "logprob": -0.7001953,   "special": false, "text": " " },
        { "id": 18,    "logprob": -1.1943359,   "special": false, "text": "-" },
        { "id": 26937, "logprob": -1.2119141,   "special": false, "text": "196" },
        { "id": 3,     "logprob": -1.2480469,   "special": false, "text": " " },
        { "id": 1956,  "logprob": -0.33203125,  "special": false, "text": "°" },
        { "id": 254,   "logprob": -0.19250488,  "special": false, "text": "C" },
        { "id": 1,     "logprob": -0.030166626, "special": true,  "text": "</s>" }
      ]
    },
    "generated_text": "-196 °C"
  },
  {
    "details": {
      "best_of_sequences": null,
      "finish_reason": "eos_token",
      "generated_tokens": 7,
      "prefill": [
        { "id": 0, "logprob": null, "text": "<pad>" }
      ],
      "seed": null,
      "tokens": [
        { "id": 3,     "logprob": -0.7001953,   "special": false, "text": " " },
        { "id": 18,    "logprob": -1.1943359,   "special": false, "text": "-" },
        { "id": 26937, "logprob": -1.2099609,   "special": false, "text": "196" },
        { "id": 3,     "logprob": -1.2451172,   "special": false, "text": " " },
        { "id": 1956,  "logprob": -0.3322754,   "special": false, "text": "°" },
        { "id": 254,   "logprob": -0.19213867,  "special": false, "text": "C" },
        { "id": 1,     "logprob": -0.030151367, "special": true,  "text": "</s>" }
      ]
    },
    "generated_text": "-196 °C"
  }
]
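All four load-test responses decode to the same string; the first three share one set of per-token logprobs and the fourth matches the single-request snapshot above, presumably an artifact of batched versus unbatched execution. A quick hedged check:

import json

with open(
    "integration-tests/models/__snapshots__/test_t5_sharded/test_t5_sharded_load.json"
) as f:
    responses = json.load(f)

texts = [r["generated_text"] for r in responses]
assert len(texts) == 4 and all(t == texts[0] for t in texts)  # all "-196 °C"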
integration-tests/models/test_flash_neox.py
@@ -36,6 +36,8 @@ async def test_flash_neox_load(flash_neox, generate_load, response_snapshot):
     generated_texts = [r.generated_text for r in responses]

     assert len(generated_texts) == 4
-    assert generated_texts, all([text == generated_texts[0] for text in generated_texts])
+    assert generated_texts, all(
+        [text == generated_texts[0] for text in generated_texts]
+    )

     assert responses == response_snapshot
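Worth noting: in Python, `assert x, y` evaluates the expression after the comma only as the failure message, so the reformatted assertion above still never enforces the equality; the new test_t5_sharded.py below uses the stricter `assert all(...)` form. A self-contained illustration:

generated_texts = ["-196 °C"] * 4

# `assert x, y`: the part after the comma is only the failure message, so this
# passes whenever generated_texts is non-empty, checking nothing else:
assert generated_texts, all(t == generated_texts[0] for t in generated_texts)

# The stricter form, as used in the new T5 test below:
assert all(t == generated_texts[0] for t in generated_texts)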
integration-tests/models/test_t5_sharded.py (new file, mode 100644)
import pytest


@pytest.fixture(scope="module")
def t5_sharded_handle(launcher):
    with launcher("google/flan-t5-xxl", num_shard=2) as handle:
        yield handle


@pytest.fixture(scope="module")
async def t5_sharded(t5_sharded_handle):
    await t5_sharded_handle.health(240)
    return t5_sharded_handle.client


@pytest.mark.asyncio
async def test_t5_sharded(t5_sharded, response_snapshot):
    response = await t5_sharded.generate(
        "Please answer the following question. What is the boiling point of Nitrogen?",
        max_new_tokens=10,
    )

    assert response == response_snapshot


@pytest.mark.asyncio
async def test_t5_sharded_load(t5_sharded, generate_load, response_snapshot):
    responses = await generate_load(
        t5_sharded,
        "Please answer the following question. What is the boiling point of Nitrogen?",
        max_new_tokens=10,
        n=4,
    )

    assert len(responses) == 4
    assert all([r.generated_text == responses[0].generated_text for r in responses])

    assert responses == response_snapshot
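The fixtures here come from the integration-test harness: launcher spins up a text-generation-inference server for google/flan-t5-xxl with two shards, and health(240) presumably waits up to 240 seconds for it to come up. One hedged way to run just this module locally (assuming the harness's conftest.py fixtures and two visible GPUs for num_shard=2):

import sys

import pytest

# Equivalent to `pytest integration-tests/models/test_t5_sharded.py -sv`
# from the repository root.
sys.exit(pytest.main(["integration-tests/models/test_t5_sharded.py", "-sv"]))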
server/text_generation_server/models/bloom.py
@@ -160,7 +160,10 @@ class BLOOMSharded(BLOOM):
                     # XXX: Hack for Rowlinear to add the bias only once.
                     if rank != 0:
                         tensor = torch.zeros_like(tensor)
-                elif isinstance(module, TensorParallelEmbedding) or name == "lm_head.weight":
+                elif (
+                    isinstance(module, TensorParallelEmbedding)
+                    or name == "lm_head.weight"
+                ):
                     size = slice_.get_shape()[0]
                     block_size = size // world_size
                     start = rank * block_size
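For context, this branch shards embedding-style tensors and lm_head.weight row-wise across ranks; the change here only rewraps the condition. A worked example with assumed numbers:

# Row sharding as in the context lines above, with illustrative numbers:
# BLOOM's 250880-row embedding split across two ranks.
size = 250880
world_size = 2
rank = 1

block_size = size // world_size   # 125440 rows per rank
start = rank * block_size         # 125440
stop = (rank + 1) * block_size    # 250880; this rank loads rows [start, stop)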
server/text_generation_server/models/t5.py
@@ -40,7 +40,7 @@ class T5Sharded(Seq2SeqLM):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
             device = torch.device(f"cuda:{rank}")
-            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
+            dtype = torch.float16
         else:
             device = torch.device("cpu")
             dtype = torch.float32

@@ -154,9 +154,15 @@ class T5Sharded(Seq2SeqLM):
                         f"Name {name} -- Current {current_parameter_tensor.shape} and got {tensor.shape}"
                     )

-            tensor = tensor.contiguous().to(dtype)
+            tensor = tensor.contiguous()
+
+            # See: https://github.com/huggingface/transformers/blob/1fe1e3caa44617047f149bcc0c0b566343b714a7/src/transformers/models/t5/modeling_t5.py#LL316C15-L316C71
+            if module_name.endswith("wo"):
+                tensor = tensor.to(torch.float32)
+            else:
+                tensor = tensor.to(dtype)

-            if quantize == "bitsandbytes":
+            if quantize == "bitsandbytes" and not module_name.endswith("wo"):
                 if not HAS_BITS_AND_BYTES:
                     raise ImportError(
                         "bitsandbytes is not available on your machine either because it is not installed "

@@ -207,7 +213,7 @@ class T5Sharded(Seq2SeqLM):
                 module.linear = replace_linear(state)

-            elif quantize == "gptq":
+            elif quantize == "gptq" and not module_name.endswith("wo"):
                 raise NotImplementedError("`gptq` is not implemented for now")
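The transformers permalink in the new comment references T5's feed-forward block, which casts activations to match the wo weight dtype; T5's wo outputs are known to overflow float16, which is presumably why they are pinned to float32 here and skipped by both quantization paths. A condensed sketch of the combined policy (an assumed simplification of the real load_weights loop, not its actual signature):

from typing import Optional

import torch

def cast_for_t5(
    tensor: torch.Tensor,
    module_name: str,
    dtype: torch.dtype,
    quantize: Optional[str],
):
    """Condensed, assumed simplification of T5Sharded.load_weights after this
    commit: "wo" projections stay float32 and bypass quantization."""
    tensor = tensor.contiguous()
    if module_name.endswith("wo"):
        return tensor.to(torch.float32), False      # never quantized
    return tensor.to(dtype), quantize == "bitsandbytes"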