OpenDAS / text-generation-inference

Commit 0ac38d33 (unverified), authored Mar 08, 2023 by OlivierDehaene, committed by GitHub on Mar 08, 2023

feat(launcher): allow parsing num_shard from CUDA_VISIBLE_DEVICES (#107)

Parent: b1485e18

Showing 5 changed files with 190 additions and 18 deletions.
Files changed:

clients/python/README.md (+128, -2)
clients/python/pyproject.toml (+1, -1)
clients/python/text_generation/__init__.py (+1, -1)
clients/python/text_generation/client.py (+12, -12)
launcher/src/main.rs (+48, -2)
clients/python/README.md (view file @ 0ac38d33)
# Text Generation

The Hugging Face Text Generation Python library provides a convenient way of interfacing with a
-`text-generation-inference` instance running on your own infrastructure or on the Hugging Face Hub.
+`text-generation-inference` instance running on [Hugging Face Inference Endpoints](https://huggingface.co/inference-endpoints) or on the Hugging Face Hub.
## Get Started
...

@@ -11,7 +12,7 @@ The Hugging Face Text Generation Python library provides a convenient way of int

pip install text-generation
```

-### Usage
+### Inference API Usage

```python
from text_generation import InferenceAPIClient

...

@@ -50,3 +51,128 @@ async for response in client.generate_stream("Why is the sky blue?"):
print(text)
# ' Rayleigh scattering'
```
### Hugging Face Inference Endpoint usage
```python
from text_generation import Client

endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud"

client = Client(endpoint_url)
text = client.generate("Why is the sky blue?").generated_text
print(text)
# ' Rayleigh scattering'

# Token Streaming
text = ""
for response in client.generate_stream("Why is the sky blue?"):
    if not response.token.special:
        text += response.token.text

print(text)
# ' Rayleigh scattering'
```
or with the asynchronous client:
```python
from text_generation import AsyncClient

endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud"

client = AsyncClient(endpoint_url)
response = await client.generate("Why is the sky blue?")
print(response.generated_text)
# ' Rayleigh scattering'

# Token Streaming
text = ""
async for response in client.generate_stream("Why is the sky blue?"):
    if not response.token.special:
        text += response.token.text

print(text)
# ' Rayleigh scattering'
```
### Types
```python
# Prompt tokens
class PrefillToken:
    # Token ID from the model tokenizer
    id: int
    # Token text
    text: str
    # Logprob
    # Optional since the logprob of the first token cannot be computed
    logprob: Optional[float]

# Generated tokens
class Token:
    # Token ID from the model tokenizer
    id: int
    # Token text
    text: str
    # Logprob
    logprob: float
    # Is the token a special token
    # Can be used to ignore tokens when concatenating
    special: bool

# Generation finish reason
class FinishReason(Enum):
    # number of generated tokens == `max_new_tokens`
    Length = "length"
    # the model generated its end of sequence token
    EndOfSequenceToken = "eos_token"
    # the model generated a text included in `stop_sequences`
    StopSequence = "stop_sequence"

# `generate` details
class Details:
    # Generation finish reason
    finish_reason: FinishReason
    # Number of generated tokens
    generated_tokens: int
    # Sampling seed if sampling was activated
    seed: Optional[int]
    # Prompt tokens
    prefill: List[PrefillToken]
    # Generated tokens
    tokens: List[Token]

# `generate` return value
class Response:
    # Generated text
    generated_text: str
    # Generation details
    details: Details

# `generate_stream` details
class StreamDetails:
    # Generation finish reason
    finish_reason: FinishReason
    # Number of generated tokens
    generated_tokens: int
    # Sampling seed if sampling was activated
    seed: Optional[int]

# `generate_stream` return value
class StreamResponse:
    # Generated token
    token: Token
    # Complete generated text
    # Only available when the generation is finished
    generated_text: Optional[str]
    # Generation details
    # Only available when the generation is finished
    details: Optional[StreamDetails]
```
\ No newline at end of file
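For illustration (not part of this commit): a minimal sketch of how these types surface through the client documented above, assuming a reachable endpoint; the URL is a placeholder, as in the README examples.

```python
from text_generation import Client

# Placeholder endpoint URL, as in the README examples above.
endpoint_url = "https://YOUR_ENDPOINT.endpoints.huggingface.cloud"
client = Client(endpoint_url)

# `generate` returns a Response; `details` carries the fields listed above.
response = client.generate("Why is the sky blue?")
print(response.generated_text)
print(response.details.finish_reason)      # a FinishReason value
print(response.details.generated_tokens)   # number of generated tokens
for token in response.details.tokens:      # each entry is a Token
    print(token.id, repr(token.text), token.logprob, token.special)

# `generate_stream` yields StreamResponse objects; `generated_text` and
# `details` are only populated on the final item.
for stream_response in client.generate_stream("Why is the sky blue?"):
    if stream_response.details is not None:
        print(stream_response.details.finish_reason)
```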
clients/python/pyproject.toml (view file @ 0ac38d33)
[tool.poetry]
name = "text-generation"
-version = "0.1.0"
+version = "0.2.0"
description = "Hugging Face Text Generation Python Client"
license = "Apache-2.0"
authors = ["Olivier Dehaene <olivier@huggingface.co>"]

...
clients/python/text_generation/__init__.py (view file @ 0ac38d33)
...

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-__version__ = "0.3.2"
+__version__ = "0.2.0"

from text_generation.client import Client, AsyncClient
from text_generation.inference_api import InferenceAPIClient, InferenceAPIAsyncClient
clients/python/text_generation/client.py (view file @ 0ac38d33)
...

@@ -63,7 +63,7 @@ class Client:
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
-        watermarking: bool = False,
+        watermark: bool = False,
    ) -> Response:
        """
        Given a prompt, generate the following text

...

@@ -91,7 +91,7 @@ class Client:
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation.
-            watermarking (`bool`):
+            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)

        Returns:

...

@@ -109,7 +109,7 @@ class Client:
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
-            watermark=watermarking,
+            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)

...

@@ -136,7 +136,7 @@ class Client:
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
-        watermarking: bool = False,
+        watermark: bool = False,
    ) -> Iterator[StreamResponse]:
        """
        Given a prompt, generate the following stream of tokens

...

@@ -164,7 +164,7 @@ class Client:
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation.
-            watermarking (`bool`):
+            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)

        Returns:

...

@@ -182,7 +182,7 @@ class Client:
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
-            watermark=watermarking,
+            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=True, parameters=parameters)

...

@@ -268,7 +268,7 @@ class AsyncClient:
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
-        watermarking: bool = False,
+        watermark: bool = False,
    ) -> Response:
        """
        Given a prompt, generate the following text asynchronously

...

@@ -296,7 +296,7 @@ class AsyncClient:
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation.
-            watermarking (`bool`):
+            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)

        Returns:

...

@@ -314,7 +314,7 @@ class AsyncClient:
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
-            watermark=watermarking,
+            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=False, parameters=parameters)

...

@@ -338,7 +338,7 @@ class AsyncClient:
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
-        watermarking: bool = False,
+        watermark: bool = False,
    ) -> AsyncIterator[StreamResponse]:
        """
        Given a prompt, generate the following stream of tokens asynchronously

...

@@ -366,7 +366,7 @@ class AsyncClient:
            top_p (`float`):
                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
                higher are kept for generation.
-            watermarking (`bool`):
+            watermark (`bool`):
                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226)

        Returns:

...

@@ -384,7 +384,7 @@ class AsyncClient:
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
-            watermark=watermarking,
+            watermark=watermark,
        )
        request = Request(inputs=prompt, stream=True, parameters=parameters)

...
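For illustration (not part of this commit): after this rename, callers pass `watermark` (previously `watermarking`) to `generate` and `generate_stream`. A minimal sketch, assuming a reachable endpoint; the URL is a placeholder.

```python
from text_generation import Client

# Placeholder endpoint URL, as in the README examples.
client = Client("https://YOUR_ENDPOINT.endpoints.huggingface.cloud")

# The keyword argument now matches the `watermark` field sent to the server.
text = client.generate("Why is the sky blue?", watermark=True).generated_text
print(text)
```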
launcher/src/main.rs (view file @ 0ac38d33)
...

@@ -23,8 +23,10 @@ struct Args {
    model_id: String,
    #[clap(long, env)]
    revision: Option<String>,
-    #[clap(default_value = "1", long, env)]
-    num_shard: usize,
+    #[clap(long, env)]
+    sharded: Option<bool>,
+    #[clap(long, env)]
+    num_shard: Option<usize>,
    #[clap(long, env)]
    quantize: bool,
    #[clap(default_value = "128", long, env)]

...

@@ -80,6 +82,7 @@ fn main() -> ExitCode {
    let Args {
        model_id,
        revision,
+        sharded,
        num_shard,
        quantize,
        max_concurrent_requests,

...

@@ -102,6 +105,49 @@ fn main() -> ExitCode {
        watermark_delta,
    } = args;

    // get the number of shards given `sharded` and `num_shard`
    let num_shard = if let Some(sharded) = sharded {
        // sharded is set
        match sharded {
            // sharded is set and true
            true => {
                match num_shard {
                    None => {
                        // try to default to the number of available GPUs
                        tracing::info!("Parsing num_shard from CUDA_VISIBLE_DEVICES");
                        let cuda_visible_devices = env::var("CUDA_VISIBLE_DEVICES")
                            .expect("--num-shard and CUDA_VISIBLE_DEVICES are not set");
                        let n_devices = cuda_visible_devices.split(",").count();
                        if n_devices <= 1 {
                            panic!("`sharded` is true but only found {n_devices} CUDA devices");
                        }
                        tracing::info!("Sharding on {n_devices} found CUDA devices");
                        n_devices
                    }
                    Some(num_shard) => {
                        // we can't have only one shard while sharded
                        if num_shard <= 1 {
                            panic!("`sharded` is true but `num_shard` <= 1");
                        }
                        num_shard
                    }
                }
            }
            // sharded is set and false
            false => {
                let num_shard = num_shard.unwrap_or(1);
                // we can't have more than one shard while not sharded
                if num_shard != 1 {
                    panic!("`sharded` is false but `num_shard` != 1");
                }
                num_shard
            }
        }
    } else {
        // default to a single shard
        num_shard.unwrap_or(1)
    };

    // Signal handler
    let running = Arc::new(AtomicBool::new(true));
    let r = running.clone();

...
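For illustration (not part of this commit): a Python paraphrase of the shard-count resolution above, to make the decision table explicit. The function name `resolve_num_shard` is hypothetical and does not exist in the launcher.

```python
import os
from typing import Optional

def resolve_num_shard(sharded: Optional[bool], num_shard: Optional[int]) -> int:
    """Mirror the launcher's logic for deciding how many shards to start."""
    if sharded is None:
        # Neither flag forces sharding: default to a single shard.
        return num_shard if num_shard is not None else 1
    if sharded:
        if num_shard is None:
            # Fall back to counting the devices listed in CUDA_VISIBLE_DEVICES
            # (the Rust code panics via expect() when the variable is unset).
            n_devices = len(os.environ["CUDA_VISIBLE_DEVICES"].split(","))
            if n_devices <= 1:
                raise ValueError(f"`sharded` is true but only found {n_devices} CUDA devices")
            return n_devices
        if num_shard <= 1:
            raise ValueError("`sharded` is true but `num_shard` <= 1")
        return num_shard
    # sharded is explicitly false: more than one shard is inconsistent.
    n = num_shard if num_shard is not None else 1
    if n != 1:
        raise ValueError("`sharded` is false but `num_shard` != 1")
    return n

# For example, with CUDA_VISIBLE_DEVICES="0,1" set:
# resolve_num_shard(True, None) -> 2
# resolve_num_shard(None, None) -> 1
# resolve_num_shard(False, 2)   -> raises ValueError
```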